
Zero Redundancy Optimizer (ZeRO) on MindOne

Zero Redundancy Optimizer (ZeRO) is a method for reducing memory usage under the data parallelism strategy, proposed in the paper ZeRO: Memory Optimizations Toward Training Trillion Parameter Models.

ZeRO eliminates memory redundancies in data and model parallel training while retaining low communication volume and high computational granularity, allowing us to scale the model size proportional to the number of devices with sustained high efficiency.

This tutorial walks you through how to train large models with ZeRO on MindOne.

Build Train Network With ZeRO

Build a train network with ZeRO.

import mindspore as ms
from mindspore import nn
from mindspore.communication import init
from mindspore.communication.management import GlobalComm
from mindone.trainers.zero import prepare_train_network

# Initialize the distributed environment
def init_env(mode, distribute):
    ms.set_context(mode=mode)
    if distribute:
        init()
        # ZeRO only takes effect under the DATA_PARALLEL mode
        ms.set_auto_parallel_context(
            parallel_mode=ms.ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
        )

init_env(ms.GRAPH_MODE, True)

# Net is your train network
net = Net()
# opt must be a subclass of MindSpore Optimizer
opt = nn.AdamWeightDecay(net.trainable_params(), learning_rate=1e-3)

# Build a train network with ZeRO
train_net = prepare_train_network(net, opt, zero_stage=2, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP)
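
You can then drive training by calling train_net step by step. Below is a minimal sketch; the data pipeline (dataset here) is a placeholder, and the exact return values depend on TrainOneStepWrapper:

# `dataset` is a placeholder for your own mindspore.dataset pipeline.
for step, data in enumerate(dataset.create_tuple_iterator()):
    # One forward/backward/optimizer step with ZeRO sharding applied.
    outputs = train_net(*data)
    if step % 100 == 0:
        print(f"step {step}: {outputs}")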

Tip

optimizer_parallel_group does not have to be GlobalComm.WORLD_COMM_GROUP. Use mindspore.communication.create_group to create your own optimizer_parallel_group.
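
For example, here is a hedged sketch of creating custom groups with mindspore.communication.create_group; the 8-card topology, group names, and the split into optimizer-parallel groups of 4 are illustrative assumptions, and the meaning of dp_group follows the parameter description below:

from mindspore.communication import get_rank, create_group

# Illustrative layout only: 8 cards, two optimizer-parallel groups of 4,
# and data parallelism of 2 across the groups. Adjust to your topology.
rank_id = get_rank()
op_size = 4                      # cards per optimizer-parallel group
num_op_groups = 2                # 8 // op_size
op_index, dp_index = rank_id // op_size, rank_id % op_size

optimizer_parallel_group = f"op_group_{op_index}"
create_group(optimizer_parallel_group, [op_index * op_size + i for i in range(op_size)])

dp_group = f"dp_group_{dp_index}"
create_group(dp_group, [dp_index + g * op_size for g in range(num_op_groups)])

train_net = prepare_train_network(
    net, opt, zero_stage=2,
    optimizer_parallel_group=optimizer_parallel_group,
    dp_group=dp_group,
)

Note that when optimizer_parallel_group is not GlobalComm.WORLD_COMM_GROUP, prepare_train_network also requires a dp_group so that the two groups together cover the whole cluster.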

More details:

mindone.trainers.zero.prepare_train_network(network, optimizer, scale_sense=1.0, ema=None, updates=0, drop_overflow_update=True, gradient_accumulation_steps=1, clip_grad=False, clip_norm=1.0, verbose=False, zero_stage=0, optimizer_offload=False, optimizer_parallel_group=None, dp_group=None, comm_fusion=None, parallel_modules=None)

Prepare network and optimizer for distributed training.

PARAMETER DESCRIPTION
network

The train network. It must not include the grad function; the grad function is built after the train network is rewritten.

TYPE: `nn.Cell`

optimizer

Must be a subclass of MindSpore Optimizer.

TYPE: `nn.Optimizer`

scale_sense

If this value is a Cell, it will be called to update the loss scale. If this value is a Tensor, the loss scale can be modified by set_sense_scale; the shape should be () or (1,).

TYPE: Union[Tensor, Cell] DEFAULT: 1.0

zero_stage

Stage setting of ZeRO, default is 0.

TYPE: `int`, *optional* DEFAULT: 0

optimizer_offload

Only takes effect when the optimizer is AdamWeightDecay, default is False.

TYPE: `bool`, *optional* DEFAULT: False

optimizer_parallel_group

The name of the optimizer parallel communication group, default is None.

TYPE: `str`, *optional* DEFAULT: None

dp_group

The name of the data parallel communication group, default is None.

TYPE: `str`, *optional* DEFAULT: None

comm_fusion

A dict that contains the types and configurations for setting communication fusion. Default is None, which turns communication fusion off; passing a dict turns it on. Example: {"allreduce": {"openstate": True, "bucket_size": 5e8}, "reduce_scatter": {"openstate": True, "bucket_size": 5e8}, "allgather": {"openstate": False, "bucket_size": 5e8}}

TYPE: `dict`, *optional* DEFAULT: None

parallel_modules

A dict of Cells whose parameters can be split under ZeRO stage 3, default is None. If None, PARALLEL_MODULES from mindone.models.modules.parallel is used.

TYPE: `dict`, *optional* DEFAULT: None
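
As a usage illustration, the sketch below passes a comm_fusion dict to prepare_train_network; the bucket sizes are illustrative values taken from the example above, and net/opt are the objects built in the earlier snippet:

comm_fusion = {
    "allreduce": {"openstate": True, "bucket_size": 5e8},
    "reduce_scatter": {"openstate": True, "bucket_size": 5e8},
    "allgather": {"openstate": False, "bucket_size": 5e8},
}
train_net = prepare_train_network(
    net,
    opt,
    zero_stage=2,
    optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP,
    comm_fusion=comm_fusion,  # fuse small communication ops into larger buckets
)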

Source code in mindone/trainers/zero.py
def prepare_train_network(
    network: nn.Cell,
    optimizer: nn.Optimizer,
    scale_sense: float = 1.0,
    ema: nn.Cell = None,
    updates: int = 0,
    drop_overflow_update: bool = True,
    gradient_accumulation_steps: int = 1,
    clip_grad: bool = False,
    clip_norm: float = 1.0,
    verbose: bool = False,
    zero_stage: Literal[0, 1, 2, 3] = 0,
    optimizer_offload: bool = False,
    optimizer_parallel_group: str = None,
    dp_group: str = None,
    comm_fusion: dict = None,
    parallel_modules=None,
) -> TrainOneStepWrapper:
    """
    Prepare network and optimizer for distributed training.

    Args:
        network (`nn.Cell`): train network, not include grad function,
            grad function must be built after rewrite train network.
        optimizer (`nn.Optimizer`): Must be the subclass of MindSpore Optimizer.
        scale_sense (Union[Tensor, Cell]): If this value is a Cell, it will be called
            to update loss scale. If this value is a Tensor, the loss scale can be modified by `set_sense_scale`,
            the shape should be :math:`()` or :math:`(1,)`.
        zero_stage (`int`, *optional*): Stage setting of ZeRO, default is 0.
        optimizer_offload (`bool`, *optional*): Only take effect when optimizer is AdamWeightDecay, default is False.
        optimizer_parallel_group (`str`, *optional*): The name of the optimizer parallel communication group, default is None.
        dp_group (`str`, *optional*): The name of the data parallel communication group, default is None.
        comm_fusion (`dict`, *optional*): A dict contains the types and configurations
            for setting the communication fusion, default is None, turn off the communication fusion. If set a dict,
            turn on the communication fusion.
            Examples: {"allreduce": {"openstate": True, "bucket_size": 5e8},
                       "reduce_scatter": {"openstate": True, "bucket_size": 5e8},
                       "allgather": {"openstate": False, "bucket_size": 5e8},}
        parallel_modules (`dict`, *optional*): A dict of Cells could split parameters in zero3, default is None.
            If None, use `PARALLEL_MODULES` from `mindone.models.modules.parallel`.
    """
    if zero_stage not in [0, 1, 2, 3]:
        raise ValueError("Not support zero_stage {zero_stage}")
    if optimizer_parallel_group is None:
        _logger.warning("Not set zero group, set it WORLD_COMM_GROUP.")
        optimizer_parallel_group = GlobalComm.WORLD_COMM_GROUP
    if optimizer_parallel_group != GlobalComm.WORLD_COMM_GROUP and dp_group is None:
        raise ValueError(
            "optimizer_parallel_group {optimizer_parallel_group} and dp_group {dp_group} not full network hccl group coverage"
        )

    is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL
    if not is_parallel and zero_stage == 0:
        _logger.info("No need prepare train_network with zero.")
        zero_helper = None
    else:
        network = prepare_network(network, zero_stage, optimizer_parallel_group, parallel_modules=parallel_modules)
        zero_helper = ZeroHelper(
            optimizer, zero_stage, optimizer_parallel_group, dp_group, optimizer_offload, comm_fusion
        )

    if ema is not None:
        ema = prepare_ema(ema, zero_stage, optimizer_parallel_group)
    if isinstance(scale_sense, float):
        scale_sense = ms.Tensor(scale_sense, ms.float32)
    train_network = TrainOneStepWrapper(
        network,
        optimizer,
        scale_sense=scale_sense,
        ema=ema,
        updates=updates,
        drop_overflow_update=drop_overflow_update,
        gradient_accumulation_steps=gradient_accumulation_steps,
        clip_grad=clip_grad,
        clip_norm=clip_norm,
        verbose=verbose,
        zero_helper=zero_helper,
    )
    return train_network

Here is an example.

Memory Analysis

The memory consumption during training can be divided into two main parts:

  • Residual states. Mainly activations, temporary buffers, and unusable memory fragments.
  • Model states. Mainly three parts: optimizer states (AdamW, fp32), gradients (fp16), and parameters (fp16), abbreviated as OPG. Assuming the number of model parameters is Φ, the total model states are 2Φ (parameters) + 2Φ (gradients) + (4Φ + 4Φ + 4Φ) (optimizer states) = 16Φ, with the AdamW optimizer states accounting for 75%.

Residual states can be greatly reduced through recomputation and model parallelism; the ZeRO algorithm can then be used to reduce the model states.

To optimize the model states (i.e. remove the redundancy), ZeRO partitions them so that each card only stores 1/N of the data.

ZeRO has three main optimization stages (as depicted in ZeRO paper Figure 1), which correspond to the partitioning of optimizer states, gradients, and parameters. When enabled cumulatively:

1) Optimizer State Partitioning (Pos): the optimizer states are partitioned so that each card holds 1/N of them, while the model parameters and gradients are still kept in full on each card. The model states of each card are 4Φ + 12Φ/N; when N is very large, this tends to 4Φ, i.e. ¼ of the original memory.
2) Add Gradient Partitioning (Pos+g): the gradients are additionally partitioned to 1/N. The model states of each card are 2Φ + (2Φ + 12Φ)/N; when N is very large, this tends to 2Φ, i.e. ⅛ of the original memory.
3) Add Parameter Partitioning (Pos+g+p): the parameters are additionally partitioned to 1/N. The model states of each card are 16Φ/N; when N is very large, this tends to 0.

Pos corresponds to ZeRO-1, Pos+g corresponds to ZeRO-2, and Pos+g+p corresponds to ZeRO-3.
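
As a quick sanity check on these formulas, here is a minimal sketch, assuming fp16 parameters/gradients and fp32 AdamW states as above; the 7.5B-parameter, 64-card example is illustrative:

def model_state_bytes_per_card(num_params, n_cards, zero_stage):
    """Per-card model-state memory in bytes, following the 16Phi breakdown above."""
    params = 2 * num_params   # fp16 parameters
    grads = 2 * num_params    # fp16 gradients
    optim = 12 * num_params   # fp32 parameter copy + momentum + variance (AdamW)
    if zero_stage >= 1:
        optim /= n_cards      # Pos
    if zero_stage >= 2:
        grads /= n_cards      # Pos+g
    if zero_stage >= 3:
        params /= n_cards     # Pos+g+p
    return params + grads + optim

phi, n = 7.5e9, 64
for stage in range(4):
    print(f"stage {stage}: {model_state_bytes_per_card(phi, n, stage) / 2**30:.1f} GiB per card")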

Communication Analysis

Currently, the commonly used AllReduce implementation is Ring AllReduce, which is divided into two steps: ReduceScatter and AllGather. The communication data volume (send + receive) of each card is approximately 2Φ.

| zero stage | forward + backward | gradient | optimizer update | communication |
|---|---|---|---|---|
| 0 | NA | AllReduce | NA | 2Φ |
| 1 | NA | 1/N ReduceScatter | 1/N AllGather | 2Φ |
| 2 | NA | 1/N ReduceScatter | 1/N AllGather | 2Φ |
| 3 | 2 AllGather | ReduceScatter | NA | 3Φ |

It can be concluded that ZeRO-3 introduces one additional communication step. However, computation and communication run on parallel streams in MindSpore, so when the computation following the communication is relatively large, ZeRO-3 may still be faster.
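
For intuition, here is a minimal sketch of the per-card communication volume, treating the 2Φ/3Φ entries in the table as element counts and assuming 2 bytes per element (fp16); the 7.5B-parameter example is illustrative:

def comm_bytes_per_card(num_params, zero_stage, bytes_per_elem=2):
    """Approximate send+receive volume per card per step, following the table above."""
    factor = 3 if zero_stage == 3 else 2  # ZeRO-3 adds an extra parameter AllGather
    return factor * num_params * bytes_per_elem

print(f"stage 2: {comm_bytes_per_card(7.5e9, 2) / 2**30:.1f} GiB")
print(f"stage 3: {comm_bytes_per_card(7.5e9, 3) / 2**30:.1f} GiB")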

Checkpoint Saving & Loading

Because the model states are split across cards, each card needs to save its own shard of the checkpoint.

Resume

checkpoint save:

| zero stage | parameters | optimizer states | ema |
|---|---|---|---|
| 0 | one card | one card | one card |
| 1 | one card | each card | each card |
| 2 | one card | each card | each card |
| 3 | each card | each card | each card |

Tip

💡 We recommend using rank_id to distinguish checkpoints saved on different cards.

rank_id = get_rank_id()  # rank index of the current card, e.g. from mindspore.communication.get_rank()
zero_stage = 2
train_net = prepare_train_network(net, opt, zero_stage=zero_stage, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP)
if resume:
    # Network parameters are split per card only under ZeRO stage 3
    network_ckpt = "network.ckpt" if zero_stage != 3 else f"network_{rank_id}.ckpt"
    ms.load_checkpoint(network_ckpt, net=train_net.network)
    # Optimizer states and EMA are split per card under ZeRO stages 1-3
    optimizer_ckpt = "optimizer.ckpt" if zero_stage == 0 else f"optimizer_{rank_id}.ckpt"
    ms.load_checkpoint(optimizer_ckpt, net=train_net.optimizer)
    ema_ckpt = "ema.ckpt" if zero_stage == 0 else f"ema_{rank_id}.ckpt"
    ms.load_checkpoint(ema_ckpt, net=train_net.ema)
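
The save side mirrors the table above. Here is a hedged sketch for zero_stage == 2; file names follow the convention used in the loading snippet:

# With zero_stage == 2 the full network is identical on every card,
# so only rank 0 saves it; optimizer states and EMA are sharded per card.
if rank_id == 0:
    ms.save_checkpoint(train_net.network, "network.ckpt")
ms.save_checkpoint(train_net.optimizer, f"optimizer_{rank_id}.ckpt")
if train_net.ema is not None:
    ms.save_checkpoint(train_net.ema, f"ema_{rank_id}.ckpt")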

Inference

Inference requires the complete model parameters when ZeRO-3 is used. There are two ways (online and offline) to obtain them.

Online Checkpoint Combine

import mindspore as ms
from mindspore import ops


def do_ckpt_combine_online(net_to_save, optimizer_parallel_group):
    new_net_to_save = []
    all_gather_op = ops.AllGather(optimizer_parallel_group)
    for item in net_to_save:
        param = item["data"]
        if param.parallel_optimizer:
            # Gather the parameter slices from all cards in the group
            new_data = ms.Tensor(all_gather_op(param).asnumpy())
        else:
            new_data = ms.Tensor(param.asnumpy())
        new_net_to_save.append({"name": param.name, "data": new_data})
    return new_net_to_save


net_to_save = [{"name": p.name, "data": p} for p in network.trainable_params()]
net_to_save = net_to_save if zero_stage != 3 else do_ckpt_combine_online(net_to_save, optimizer_parallel_group)
ms.save_checkpoint(net_to_save, "network.ckpt")

Add this code wherever you need to save the complete model parameters.

Offline Checkpoint Combine

Parameter split information is saved when ZeroHelper is used; it can be used to combine the checkpoints offline.

from mindone.trainers.zero import convert_checkpoints

# "{}" is filled with the rank id of each card
src_checkpoint = "save_checkpoint_dir/ckpt_{}.ckpt"
src_param_split_info_json = "params_info/params_split_info_{}.json"
group_size = 2
convert_checkpoints(src_checkpoint, src_param_split_info_json, group_size)

The complete model parameters checkpoint is then available at save_checkpoint_dir/ckpt_all_2.ckpt.
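
The combined checkpoint can then be loaded for inference as usual. A minimal sketch, assuming Net is the same network definition used in training:

net = Net()
ms.load_checkpoint("save_checkpoint_dir/ckpt_all_2.ckpt", net=net)
net.set_train(False)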