Parallelism

Parallelism strategies distribute computation across multiple devices, speeding up inference and training of diffusion transformers.

mindone.diffusers.ParallelConfig dataclass

Configuration for applying different parallelisms.

PARAMETER DESCRIPTION
context_parallel_config

Configuration for context parallelism.

TYPE: `ContextParallelConfig`, *optional* DEFAULT: None

Source code in mindone/diffusers/models/_modeling_parallel.py
@dataclass
class ParallelConfig:
    """
    Configuration for applying different parallelisms.

    Args:
        context_parallel_config (`ContextParallelConfig`, *optional*):
            Configuration for context parallelism.
    """

    context_parallel_config: Optional[ContextParallelConfig] = None

    _rank: int = None
    _world_size: int = None
    _device: str = None
    _cp_mesh: dict = None

    def setup(
        self,
        rank: int,
        world_size: int,
        device: str,
        *,
        cp_mesh: Optional[dict] = None,
    ):
        self._rank = rank
        self._world_size = world_size
        self._device = device
        self._cp_mesh = cp_mesh
        if self.context_parallel_config is not None:
            self.context_parallel_config.setup(rank, world_size, device, cp_mesh)

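Example

The following is a minimal, illustrative sketch of combining the two dataclasses. The chosen degrees are placeholders; `setup()` is normally invoked for you (with the process rank, world size, device, and context-parallel device mesh) once the distributed environment has been initialized, rather than being called by hand.

from mindone.diffusers import ContextParallelConfig, ParallelConfig

# Illustrative degrees: ring_degree * ulysses_degree is assumed to match the
# number of devices reserved for context parallelism.
cp_config = ContextParallelConfig(ring_degree=2, ulysses_degree=1)
parallel_config = ParallelConfig(context_parallel_config=cp_config)
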
mindone.diffusers.ContextParallelConfig dataclass

Configuration for context parallelism.

PARAMETER DESCRIPTION
ring_degree

Number of devices to use for ring attention within a context parallel region. Must be a divisor of the total number of devices in the context parallel mesh.

TYPE: `int`, *optional*, defaults to `1` DEFAULT: None

ulysses_degree

Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the total number of devices in the context parallel mesh.

TYPE: `int`, *optional*, defaults to `1` DEFAULT: None

convert_to_fp32

Whether to convert output and LSE to float32 for ring attention numerical stability.

TYPE: `bool`, *optional*, defaults to `True` DEFAULT: True

rotate_method

Method to use for rotating key/value states across devices in ring attention. Currently, only "allgather" is supported.

TYPE: `str`, *optional*, defaults to `"allgather"` DEFAULT: 'allgather'

Source code in mindone/diffusers/models/_modeling_parallel.py
@dataclass
class ContextParallelConfig:
    """
    Configuration for context parallelism.

    Args:
        ring_degree (`int`, *optional*, defaults to `1`):
            Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
            total number of devices in the context parallel mesh.
        ulysses_degree (`int`, *optional*, defaults to `1`):
            Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
            total number of devices in the context parallel mesh.
        convert_to_fp32 (`bool`, *optional*, defaults to `True`):
            Whether to convert output and LSE to float32 for ring attention numerical stability.
        rotate_method (`str`, *optional*, defaults to `"allgather"`):
            Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
            is supported.

    """

    ring_degree: Optional[int] = None
    ulysses_degree: Optional[int] = None
    convert_to_fp32: bool = True
    # TODO: support alltoall
    rotate_method: Literal["allgather", "alltoall"] = "allgather"

    _rank: int = None
    _world_size: int = None
    _device: str = None
    _mesh: dict = None
    _flattened_mesh: str = None
    _ring_mesh: str = None
    _ulysses_mesh: str = None
    _ring_local_rank: int = None
    _ulysses_local_rank: int = None

    def __post_init__(self):
        if self.ring_degree is None:
            self.ring_degree = 1
        if self.ulysses_degree is None:
            self.ulysses_degree = 1

    def setup(self, rank: int, world_size: int, device, mesh):
        self._rank = rank
        self._world_size = world_size
        self._device = device
        self._mesh = mesh
        if self.ring_degree is None:
            self.ring_degree = 1
        if self.ulysses_degree is None:
            self.ulysses_degree = 1
        if self.rotate_method != "allgather":
            raise NotImplementedError(
                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
            )
        if self._flattened_mesh is None:
            self._flattened_mesh = self._mesh._flatten()
        if self._ring_mesh is None:
            self._ring_mesh = self._mesh["ring"]
        if self._ulysses_mesh is None:
            self._ulysses_mesh = self._mesh["ulysses"]
        if self._ring_local_rank is None:
            self._ring_local_rank = mint.distributed.get_rank(self._ring_mesh)
        if self._ulysses_local_rank is None:
            self._ulysses_local_rank = mint.distributed.get_rank(self._ulysses_mesh)

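Example

A short, hedged sketch of how the degrees might be chosen. It assumes a context-parallel mesh of 8 devices and that `ring_degree * ulysses_degree` equals the mesh size; adjust the numbers to your own topology.

from mindone.diffusers import ContextParallelConfig

ring_only = ContextParallelConfig(ring_degree=8)                  # pure ring attention
ulysses_only = ContextParallelConfig(ulysses_degree=8)            # pure Ulysses attention
hybrid = ContextParallelConfig(ring_degree=2, ulysses_degree=4)   # 2-way ring x 4-way Ulysses

# convert_to_fp32=True (the default) keeps the ring-attention output and LSE in
# float32 for numerical stability; rotate_method must currently stay "allgather".
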
mindone.diffusers.hooks.apply_context_parallel(module, parallel_config, plan)

Apply context parallelism to a model.

Source code in mindone/diffusers/hooks/context_parallel.py
def apply_context_parallel(
    module: ms.nn.Cell,
    parallel_config: ContextParallelConfig,
    plan: Dict[str, ContextParallelModelPlan],
) -> None:
    """Apply context parallel on a model."""
    logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")

    for module_id, cp_model_plan in plan.items():
        submodule = _get_submodule_by_name(module, module_id)
        if not isinstance(submodule, list):
            submodule = [submodule]

        logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules")

        for m in submodule:
            if isinstance(cp_model_plan, dict):
                hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
                hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
            elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
                if isinstance(cp_model_plan, ContextParallelOutput):
                    cp_model_plan = [cp_model_plan]
                if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan):
                    raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}")
                hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
                hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
            else:
                raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
            registry = HookRegistry.check_if_exists_or_initialize(m)
            registry.register_hook(hook, hook_name)
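
Example

The plan below is a hypothetical sketch: the module name, the import path for `ContextParallelOutput`, and its `gather_dim` argument are assumptions made for illustration (mirroring the upstream diffusers API) and are not taken from this page. `apply_context_parallel` itself is called with the signature documented above and expects the config's device mesh to have been set up (directly or via `ParallelConfig.setup`) before the hooks run.

from mindone.diffusers import ContextParallelConfig
from mindone.diffusers.hooks import apply_context_parallel
from mindone.diffusers.models._modeling_parallel import ContextParallelOutput  # assumed path

# `transformer` stands in for an already-constructed diffusion transformer (ms.nn.Cell).
# Hypothetical plan: gather the sequence dimension of the final projection's output.
plan = {
    "proj_out": ContextParallelOutput(gather_dim=1),  # argument name is an assumption
}

cp_config = ContextParallelConfig(ring_degree=2)
# cp_config.setup(rank, world_size, device, mesh) must have populated the CP mesh first.

apply_context_parallel(module=transformer, parallel_config=cp_config, plan=plan)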