Parallelism strategies speed up diffusion transformers by distributing computation across multiple devices, reducing inference and training times.
`ParallelConfig` groups the individual parallelism configurations (currently only context parallelism) and, when `setup()` is called, stores the distributed context and forwards it to the nested `ContextParallelConfig`:

```python
@dataclass
class ParallelConfig:
    """
    Configuration for applying different parallelisms.

    Args:
        context_parallel_config (`ContextParallelConfig`, *optional*):
            Configuration for context parallelism.
    """

    context_parallel_config: Optional[ContextParallelConfig] = None

    _rank: int = None
    _world_size: int = None
    _device: str = None
    _cp_mesh: dict = None

    def setup(
        self,
        rank: int,
        world_size: int,
        device: str,
        *,
        cp_mesh: Optional[dict] = None,
    ):
        self._rank = rank
        self._world_size = world_size
        self._device = device
        self._cp_mesh = cp_mesh
        if self.context_parallel_config is not None:
            self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
```
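A minimal usage sketch, assuming a launched multi-device MindSpore job in which `rank`, `world_size`, `device`, and the context-parallel mesh are supplied by the surrounding framework (the commented `setup()` call only marks where that wiring would happen; it is not a prescribed call site):

```python
# Minimal sketch: build a ParallelConfig that enables context parallelism.
# The setup() call below is commented out because rank/world_size/device/cp_mesh
# come from the distributed launcher and are not defined in this snippet.
from mindone.diffusers.models._modeling_parallel import ContextParallelConfig, ParallelConfig

parallel_config = ParallelConfig(
    context_parallel_config=ContextParallelConfig(ring_degree=2, ulysses_degree=2)
)

# parallel_config.setup(rank, world_size, device, cp_mesh=cp_mesh)
# setup() stores the distributed context and forwards it to the nested
# ContextParallelConfig.setup(), as shown in the source above.
```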
`ContextParallelConfig` controls how a context parallel region is split between ring attention and Ulysses attention. Its parameters are:

- **ring_degree** (`int`, *optional*, defaults to `1`): Number of devices to use for ring attention within a context parallel region. Must be a divisor of the total number of devices in the context parallel mesh.
- **ulysses_degree** (`int`, *optional*, defaults to `1`): Number of devices to use for Ulysses attention within a context parallel region. Must be a divisor of the total number of devices in the context parallel mesh.
- **convert_to_fp32** (`bool`, *optional*, defaults to `True`): Whether to convert the attention output and log-sum-exp (LSE) to float32 for ring attention numerical stability.
- **rotate_method** (`str`, *optional*, defaults to `"allgather"`): Method used to rotate key/value states across devices in ring attention. Currently, only `"allgather"` is supported.
Source code in `mindone/diffusers/models/_modeling_parallel.py`:

```python
@dataclass
class ContextParallelConfig:
    """
    Configuration for context parallelism.

    Args:
        ring_degree (`int`, *optional*, defaults to `1`):
            Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
            total number of devices in the context parallel mesh.
        ulysses_degree (`int`, *optional*, defaults to `1`):
            Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of
            the total number of devices in the context parallel mesh.
        convert_to_fp32 (`bool`, *optional*, defaults to `True`):
            Whether to convert output and LSE to float32 for ring attention numerical stability.
        rotate_method (`str`, *optional*, defaults to `"allgather"`):
            Method to use for rotating key/value states across devices in ring attention. Currently, only
            `"allgather"` is supported.
    """

    ring_degree: Optional[int] = None
    ulysses_degree: Optional[int] = None
    convert_to_fp32: bool = True
    # TODO: support alltoall
    rotate_method: Literal["allgather", "alltoall"] = "allgather"

    _rank: int = None
    _world_size: int = None
    _device: str = None
    _mesh: dict = None
    _flattened_mesh: str = None
    _ring_mesh: str = None
    _ulysses_mesh: str = None
    _ring_local_rank: int = None
    _ulysses_local_rank: int = None

    def __post_init__(self):
        if self.ring_degree is None:
            self.ring_degree = 1
        if self.ulysses_degree is None:
            self.ulysses_degree = 1

    def setup(self, rank: int, world_size: int, device, mesh):
        self._rank = rank
        self._world_size = world_size
        self._device = device
        self._mesh = mesh
        if self.ring_degree is None:
            self.ring_degree = 1
        if self.ulysses_degree is None:
            self.ulysses_degree = 1
        if self.rotate_method != "allgather":
            raise NotImplementedError(
                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
            )
        if self._flattened_mesh is None:
            self._flattened_mesh = self._mesh._flatten()
        if self._ring_mesh is None:
            self._ring_mesh = self._mesh["ring"]
        if self._ulysses_mesh is None:
            self._ulysses_mesh = self._mesh["ulysses"]
        if self._ring_local_rank is None:
            self._ring_local_rank = mint.distributed.get_rank(self._ring_mesh)
        if self._ulysses_local_rank is None:
            self._ulysses_local_rank = mint.distributed.get_rank(self._ulysses_mesh)
```
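As a concrete illustration of the divisor constraint, the sketch below configures an assumed 8-device context-parallel mesh; the specific degrees are arbitrary choices for the example, not recommended values:

```python
# Illustrative values for an assumed 8-device context-parallel mesh:
# both degrees (2 and 4) divide the mesh size of 8.
from mindone.diffusers.models._modeling_parallel import ContextParallelConfig

cp_config = ContextParallelConfig(
    ring_degree=2,         # ring attention across 2 devices
    ulysses_degree=4,      # Ulysses attention across 4 devices
    convert_to_fp32=True,  # keep ring-attention output/LSE in float32
)
```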
`apply_context_parallel` attaches the context parallel hooks described by a plan: for each named submodule it registers either a `ContextParallelSplitHook` (dict entries, splitting inputs) or a `ContextParallelGatherHook` (`ContextParallelOutput` entries, gathering outputs):

```python
def apply_context_parallel(
    module: ms.nn.Cell,
    parallel_config: ContextParallelConfig,
    plan: Dict[str, ContextParallelModelPlan],
) -> None:
    """Apply context parallel on a model."""
    logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")

    for module_id, cp_model_plan in plan.items():
        submodule = _get_submodule_by_name(module, module_id)
        if not isinstance(submodule, list):
            submodule = [submodule]

        logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules")

        for m in submodule:
            if isinstance(cp_model_plan, dict):
                hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
                hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
            elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
                if isinstance(cp_model_plan, ContextParallelOutput):
                    cp_model_plan = [cp_model_plan]
                if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan):
                    raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}")
                hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
                hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
            else:
                raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
            registry = HookRegistry.check_if_exists_or_initialize(m)
            registry.register_hook(hook, hook_name)
```
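A hypothetical plan to illustrate the two hook paths: the module name `"proj_out"`, the `gather_dim` argument, and the `transformer` variable are assumptions made for this sketch, not part of any real model plan:

```python
# Hypothetical sketch of a context-parallel plan. The module name ("proj_out")
# and the ContextParallelOutput argument (gather_dim) are assumed for illustration.
from mindone.diffusers.models._modeling_parallel import (
    ContextParallelConfig,
    ContextParallelOutput,
    apply_context_parallel,
)

cp_config = ContextParallelConfig(ring_degree=2, ulysses_degree=2)
# cp_config.setup(...) is assumed to have been called already (e.g. via
# ParallelConfig.setup) so that the CP mesh is available to the hooks.

plan = {
    # ContextParallelOutput values take the gather path above
    # (ContextParallelGatherHook); dict values would take the split path instead.
    "proj_out": ContextParallelOutput(gather_dim=1),
}

# transformer is an already-constructed diffusion transformer (ms.nn.Cell),
# assumed to exist in the surrounding code:
# apply_context_parallel(transformer, cp_config, plan)
```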