Zero Redundancy Optimizer (ZeRO) on MindOne¶
Zero Redundancy Optimizer (ZeRO) is a method for reducing memory usage under the data parallel strategy, introduced in the paper ZeRO: Memory Optimizations Toward Training Trillion Parameter Models.
ZeRO eliminates memory redundancies in data- and model-parallel training while retaining low communication volume and high computational granularity, allowing the model size to scale in proportion to the number of devices with sustained high efficiency.
This tutorial walks you through how to train larger models faster and with less memory using ZeRO on MindOne.
Build Train Network With ZeRO¶
Build a train network with ZeRO.
import mindspore as ms
from mindspore import nn
from mindspore.communication import init
from mindspore.communication.management import GlobalComm
from mindone.trainers.zero import prepare_train_network

# Initialize the distributed environment
def init_env(mode, distribute):
    ms.set_context(mode=mode)
    if distribute:
        init()
        # ZeRO only takes effect under the DATA_PARALLEL mode
        ms.set_auto_parallel_context(
            parallel_mode=ms.ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
        )

init_env(ms.GRAPH_MODE, True)

# Net is your train network
net = Net()
# opt must be a subclass of MindSpore Optimizer
opt = nn.AdamWeightDecay(net.trainable_params(), learning_rate=1e-3)
# Build a train network with ZeRO
train_net = prepare_train_network(net, opt, zero_stage=2, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP)
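After the wrapper is built, training proceeds as with any MindSpore Cell. Below is a minimal loop sketch; dataset is a hypothetical mindspore.dataset pipeline whose batches match Net's inputs, and the wrapper's exact return values (typically the loss, an overflow flag, and the loss scale) depend on the wrapper implementation.
# Hypothetical training loop; each call runs forward, backward and the sharded optimizer update.
for step, batch in enumerate(dataset.create_tuple_iterator()):
    outputs = train_net(*batch)
    if step % 100 == 0:
        print(f"step {step}: {outputs}")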
Tip
optimizer_parallel_group does not have to be GlobalComm.WORLD_COMM_GROUP. Use mindspore.communication.create_group to create your own optimizer_parallel_group, as in the sketch below.
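For example, a hedged sketch that splits 8 devices into two optimizer-parallel groups of 4 (the group name and the 4-card split are illustrative assumptions, not fixed by the API):
from mindspore.communication import create_group, get_rank

# Illustrative: devices 0-3 form one optimizer-parallel group, devices 4-7 another.
rank_id = get_rank()
group_index = rank_id // 4
optimizer_parallel_group = f"optimizer_parallel_group_{group_index}"
create_group(optimizer_parallel_group, list(range(group_index * 4, group_index * 4 + 4)))
train_net = prepare_train_network(net, opt, zero_stage=2, optimizer_parallel_group=optimizer_parallel_group)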
More details:
mindone.trainers.zero.prepare_train_network(network, optimizer, scale_sense=1.0, ema=None, updates=0, drop_overflow_update=True, gradient_accumulation_steps=1, clip_grad=False, clip_norm=1.0, verbose=False, zero_stage=0, optimizer_offload=False, optimizer_parallel_group=None, dp_group=None, comm_fusion=None, parallel_modules=None)
Prepare network and optimizer for distributed training.
PARAMETER | DESCRIPTION |
---|---|
network | The train network. It should not include the grad function; the grad function is built after the train network is rewritten. |
optimizer | Must be a subclass of MindSpore Optimizer. |
scale_sense | If this value is a Cell, it will be called to update the loss scale. If this value is a Tensor, it is used as the loss scale value. Default: 1.0. |
zero_stage | Stage setting of ZeRO (0, 1, 2 or 3). Default: 0. |
optimizer_offload | Only takes effect when the optimizer is AdamWeightDecay. Default: False. |
optimizer_parallel_group | The name of the optimizer parallel communication group. Default: None. |
dp_group | The name of the data parallel communication group. Default: None. |
comm_fusion | A dict containing the types and configurations for communication fusion. Default: None, which turns communication fusion off; setting a dict turns it on. Example: {"allreduce": {"openstate": True, "bucket_size": 5e8}, "reduce_scatter": {"openstate": True, "bucket_size": 5e8}, "allgather": {"openstate": False, "bucket_size": 5e8}}. |
parallel_modules | A dict of Cells whose parameters can be split in ZeRO-3. Default: None, in which case the built-in default parallel modules are used. |
Source code in mindone/trainers/zero.py
Here is an example.
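The sketch below enables ZeRO-3 with gradient clipping and communication fusion, reusing net and opt from above; the bucket sizes are only illustrative, not recommendations.
comm_fusion = {
    "allreduce": {"openstate": True, "bucket_size": 5e8},
    "reduce_scatter": {"openstate": True, "bucket_size": 5e8},
    "allgather": {"openstate": False, "bucket_size": 5e8},
}
train_net = prepare_train_network(
    net,
    opt,
    zero_stage=3,
    optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP,
    comm_fusion=comm_fusion,
    clip_grad=True,
    clip_norm=1.0,
)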
Memory Analysis¶
The memory consumption during the training can be divided into two main parts:
- Residual states. Mainly includes activations, temporary buffers, and unusable memory fragments.
- Model states. Mainly includes three parts: optimizer states (AdamW fp32), gradients (fp16), and parameters (fp16), abbreviated as OPG. Assuming the number of model parameters is Φ, the total model states are 2Φ (parameters) + 2Φ (gradients) + (4Φ + 4Φ + 4Φ) (optimizer states) = 16Φ, with the AdamW states accounting for 75%.
Residual states can be greatly reduced through recomputation and model parallelism. The ZeRO algorithm can then be used to reduce the model states.
To optimize the model states (i.e. remove redundancy), ZeRO partitions them so that each card only stores 1/N of the data.
ZeRO has three main optimization stages (as depicted in ZeRO paper Figure 1), which correspond to the partitioning of optimizer states, gradients, and parameters. When enabled cumulatively:
1) Optimizer State Partitioning (Pos): only 1/N of the optimizer states is kept on each card, while the model parameters and gradients are still kept in full. The model states per card are 4Φ + 12Φ/N; when N is very large this tends to 4Φ, i.e. about ¼ of the original memory.
2) Add Gradient Partitioning (Pos+g): the gradients are additionally partitioned to 1/N. The model states per card are 2Φ + (2Φ + 12Φ)/N; when N is very large this tends to 2Φ, i.e. about ⅛ of the original memory.
3) Add Parameter Partitioning (Pos+g+p): the parameters are additionally partitioned to 1/N. The model states per card are 16Φ/N; when N is very large this tends to 0.
Pos corresponds to ZeRO-1, Pos+g corresponds to ZeRO-2, and Pos+g+p corresponds to ZeRO-3.
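To make the numbers concrete, here is a small back-of-the-envelope sketch using the 7.5B-parameter, 64-card example from the ZeRO paper (figures are approximate):
# Per-card model-state memory for phi = 7.5e9 parameters on N = 64 cards (GB = 1e9 bytes).
phi, N, GB = 7.5e9, 64, 1e9
print("baseline:", 16 * phi / GB)                  # 120.0 GB
print("ZeRO-1  :", (4 * phi + 12 * phi / N) / GB)  # ~31.4 GB
print("ZeRO-2  :", (2 * phi + 14 * phi / N) / GB)  # ~16.6 GB
print("ZeRO-3  :", 16 * phi / N / GB)              # ~1.9 GB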
Communication Analysis¶
The commonly used AllReduce implementation is Ring AllReduce, which is divided into two steps: ReduceScatter and AllGather. The communication data volume (send + receive) per card is approximately 2Φ.
ZeRO stage | forward + backward | gradient | optimizer update | communication |
---|---|---|---|---|
0 | NA | AllReduce | NA | 2Φ |
1 | NA | 1/N ReduceScatter | 1/N AllGather | 2Φ |
2 | NA | 1/N ReduceScatter | 1/N AllGather | 2Φ |
3 | 2 AllGather | ReduceScatter | NA | 3Φ |
It can be concluded that ZeRO-3 requires an additional communication step (the extra parameter AllGather). However, computation and communication run on separate parallel streams in MindSpore, so when the computation that follows the communication is large enough, ZeRO-3 may even be faster.
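As a rough sketch of what this means in bytes, assuming fp16 gradients/parameters and the same 7.5B-parameter model as above (numbers are approximate):
# Approximate per-card send+receive volume per step (GB = 1e9 bytes, fp16 = 2 bytes/element).
phi, bytes_per_elem, GB = 7.5e9, 2, 1e9
print("ZeRO-0/1/2:", 2 * phi * bytes_per_elem / GB, "GB")  # ReduceScatter + AllGather ~ 2*phi elements
print("ZeRO-3    :", 3 * phi * bytes_per_elem / GB, "GB")  # extra parameter AllGather ~ 3*phi elements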
Checkpoint Saving & Loading¶
Because the model states are split across cards, each card needs to save its own checkpoint.
Resume¶
checkpoint save:
ZeRO stage | parameters | optimizer states | ema |
---|---|---|---|
0 | one card | one card | one card |
1 | one card | each card | each card |
2 | one card | each card | each card |
3 | each card | each card | each card |
Tip
💡 It is recommended to use rank_id to distinguish the checkpoints saved on different cards.
from mindspore.communication import get_rank

rank_id = get_rank()  # this card's rank id
zero_stage = 2
train_net = prepare_train_network(net, opt, zero_stage=zero_stage, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP)
if resume:
    network_ckpt = "network.ckpt" if zero_stage != 3 else f"network_{rank_id}.ckpt"
    ms.load_checkpoint(network_ckpt, net=train_net.network)
    optimizer_ckpt = "optimizer.ckpt" if zero_stage == 0 else f"optimizer_{rank_id}.ckpt"
    ms.load_checkpoint(optimizer_ckpt, net=train_net.optimizer)
    ema_ckpt = "ema.ckpt" if zero_stage == 0 else f"ema_{rank_id}.ckpt"
    ms.load_checkpoint(ema_ckpt, net=train_net.ema)
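Conversely, here is a minimal sketch of saving the per-card checkpoints with a rank_id suffix, following the table above; the file names are illustrative and it assumes an EMA was passed to prepare_train_network.
# Save checkpoints; only sharded states need a per-card (rank_id) file name.
if zero_stage == 3 or rank_id == 0:
    network_ckpt = "network.ckpt" if zero_stage != 3 else f"network_{rank_id}.ckpt"
    ms.save_checkpoint(train_net.network, network_ckpt)
if zero_stage != 0 or rank_id == 0:
    optimizer_ckpt = "optimizer.ckpt" if zero_stage == 0 else f"optimizer_{rank_id}.ckpt"
    ms.save_checkpoint(train_net.optimizer, optimizer_ckpt)
    ema_ckpt = "ema.ckpt" if zero_stage == 0 else f"ema_{rank_id}.ckpt"
    ms.save_checkpoint(train_net.ema, ema_ckpt)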
Inference¶
Inference needs the complete model parameters when ZeRO-3 is used. There are two ways (online and offline) to obtain them.
Online Checkpoint Combine¶
import mindspore as ms
from mindspore import ops

def do_ckpt_combine_online(net_to_save, optimizer_parallel_group):
    """AllGather the sharded parameters so each card holds the full weights."""
    new_net_to_save = []
    all_gather_op = ops.AllGather(optimizer_parallel_group)
    for item in net_to_save:
        param = item["data"]
        if param.parallel_optimizer:
            # The parameter was split for optimizer parallelism: gather the full tensor
            new_data = ms.Tensor(all_gather_op(param).asnumpy())
        else:
            new_data = ms.Tensor(param.asnumpy())
        new_net_to_save.append({"name": param.name, "data": new_data})
    return new_net_to_save

net_to_save = [{"name": p.name, "data": p} for p in network.trainable_params()]
net_to_save = net_to_save if zero_stage != 3 else do_ckpt_combine_online(net_to_save, optimizer_parallel_group)
ms.save_checkpoint(net_to_save, "network.ckpt")
Add this code wherever you need to save the model parameters.
Offline Checkpoint Combine¶
The parameter split information is saved when using ZeroHelper; it can be used to combine the checkpoints offline.
from mindone.trainers.zero import convert_checkpoints

# Per-card checkpoint and split-info paths; "{}" is the per-card index placeholder
src_checkpoint = "save_checkpoint_dir/ckpt_{}.ckpt"
src_param_split_info_json = "params_info/params_split_info_{}.json"
group_size = 2
convert_checkpoints(src_checkpoint, src_param_split_info_json, group_size)
The complete model checkpoint is then generated at save_checkpoint_dir/ckpt_all_2.ckpt.
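The combined checkpoint can then be loaded as usual for single-card inference; a minimal sketch, where Net is a placeholder for your model class:
# Load the combined full checkpoint into a freshly built network for inference.
import mindspore as ms
net = Net()
ms.load_checkpoint("save_checkpoint_dir/ckpt_all_2.ckpt", net=net)
net.set_train(False)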