
SD3 Transformer Model

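The Transformer model introduced in Stable Diffusion 3. Its novelty lies in the MMDiT (Multimodal Diffusion Transformer) block, which keeps separate weights for the image and text token streams but lets both streams attend to each other in a single joint attention operation.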
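To make the MMDiT idea concrete, here is a minimal single-head NumPy sketch of "separate weights, shared attention". It deliberately omits normalization, timestep modulation, output projections and the MLP, and the helper and weight names are hypothetical; it is an illustration of the concept, not the `JointTransformerBlock` implementation used by this class. Placing image tokens before text tokens in the joint sequence is an assumption of the sketch.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def joint_attention(img, txt, w_img, w_txt):
    """Single-head joint attention (hypothetical helper).

    img: (N_img, D) image tokens, txt: (N_txt, D) text tokens.
    w_img / w_txt: dicts with separate (D, D) "q", "k", "v" weights per modality.
    """
    # Each modality is projected with its own weights...
    q = np.concatenate([img @ w_img["q"], txt @ w_txt["q"]], axis=0)
    k = np.concatenate([img @ w_img["k"], txt @ w_txt["k"]], axis=0)
    v = np.concatenate([img @ w_img["v"], txt @ w_txt["v"]], axis=0)
    # ...but attention runs once over the concatenated sequence,
    # so image and text tokens attend to each other.
    weights = softmax(q @ k.T / np.sqrt(q.shape[-1]))
    out = weights @ v
    # Split the joint sequence back into the image and text streams.
    n_img = img.shape[0]
    return out[:n_img], out[n_img:]


rng = np.random.default_rng(0)
d = 8
w_img = {name: rng.standard_normal((d, d)) for name in "qkv"}
w_txt = {name: rng.standard_normal((d, d)) for name in "qkv"}
img = rng.standard_normal((4, d))
txt = rng.standard_normal((3, d))
img_out, txt_out = joint_attention(img, txt, w_img, w_txt)
print(img_out.shape, txt_out.shape)  # (4, 8) (3, 8)
```

Because the softmax spans both modalities, changing the text tokens changes the image outputs, which is what distinguishes joint attention from two independent self-attention layers.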

mindone.diffusers.SD3Transformer2DModel

Bases: ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin

The Transformer model introduced in Stable Diffusion 3.

Reference: https://arxiv.org/abs/2403.03206

PARAMETER DESCRIPTION
sample_size

The height and width of the latent images. This is fixed during training since it is used to learn a number of position embeddings.

TYPE: `int` DEFAULT: 128

patch_size

Patch size used to split the input latents into small patches.

TYPE: `int` DEFAULT: 2

in_channels

The number of channels in the input.

TYPE: `int` DEFAULT: 16

num_layers

The number of Transformer blocks to use.

TYPE: `int` DEFAULT: 18

attention_head_dim

The number of channels in each head.

TYPE: `int` DEFAULT: 64

num_attention_heads

The number of heads to use for multi-head attention.

TYPE: `int` DEFAULT: 18

joint_attention_dim

The number of encoder_hidden_states dimensions, i.e. the input size of the context embedder.

TYPE: `int` DEFAULT: 4096

caption_projection_dim

Number of dimensions to project the encoder_hidden_states to. With the defaults this equals num_attention_heads * attention_head_dim (18 * 64 = 1152).

TYPE: `int` DEFAULT: 1152

pooled_projection_dim

Number of dimensions of the incoming pooled_projections, which are projected to the model's inner dimension.

TYPE: `int` DEFAULT: 2048

out_channels

Number of output channels. If `None`, in_channels is used.

TYPE: `int` DEFAULT: 16

pos_embed_max_size

Maximum size, in patches, of the positional embedding grid passed to `PatchEmbed`.

TYPE: `int` DEFAULT: 96
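The `__init__` defaults fix several derived sizes. The pure-Python sketch below recomputes them; the `SD3ConfigSketch` dataclass and the `derived_sizes` helper are hypothetical and not part of mindone, and the `fits_pos_embed` check is an assumption about how `PatchEmbed` bounds the patch grid.

```python
from dataclasses import dataclass


@dataclass
class SD3ConfigSketch:
    # Hypothetical plain-Python mirror of the __init__ defaults; not part of mindone.
    sample_size: int = 128
    patch_size: int = 2
    in_channels: int = 16
    num_layers: int = 18
    attention_head_dim: int = 64
    num_attention_heads: int = 18
    joint_attention_dim: int = 4096
    caption_projection_dim: int = 1152
    pooled_projection_dim: int = 2048
    out_channels: int = 16
    pos_embed_max_size: int = 96


def derived_sizes(cfg: SD3ConfigSketch, latent_h: int, latent_w: int) -> dict:
    # Hidden width of every JointTransformerBlock (self.inner_dim).
    inner_dim = cfg.num_attention_heads * cfg.attention_head_dim
    # PatchEmbed turns each patch_size x patch_size latent patch into one token.
    grid_h, grid_w = latent_h // cfg.patch_size, latent_w // cfg.patch_size
    return {
        "inner_dim": inner_dim,
        "image_tokens": grid_h * grid_w,
        # proj_out emits patch_size * patch_size * out_channels values per token before unpatchifying.
        "proj_out_features": cfg.patch_size * cfg.patch_size * cfg.out_channels,
        # Assumed: the patch grid must not exceed the positional embedding grid.
        "fits_pos_embed": max(grid_h, grid_w) <= cfg.pos_embed_max_size,
        # The blocks are built with dim=inner_dim, so projected text tokens need the same width.
        "caption_dim_matches": cfg.caption_projection_dim == inner_dim,
    }


print(derived_sizes(SD3ConfigSketch(), 128, 128))
# {'inner_dim': 1152, 'image_tokens': 4096, 'proj_out_features': 64, 'fits_pos_embed': True, 'caption_dim_matches': True}
```

Changing `num_attention_heads` or `attention_head_dim` without also changing `caption_projection_dim` makes the last check fail, which is a quick way to catch an inconsistent custom config.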

Source code in mindone/diffusers/models/transformers/transformer_sd3.py
class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    """
    The Transformer model introduced in Stable Diffusion 3.

    Reference: https://arxiv.org/abs/2403.03206

    Parameters:
        sample_size (`int`, defaults to 128): The height and width of the latent images. This is fixed during
            training since it is used to learn a number of position embeddings.
        patch_size (`int`, defaults to 2): Patch size to turn the input data into small patches.
        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
        num_layers (`int`, *optional*, defaults to 18): The number of Transformer blocks to use.
        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
        joint_attention_dim (`int`, *optional*, defaults to 4096): The number of `encoder_hidden_states` dimensions
            to use.
        caption_projection_dim (`int`, defaults to 1152): Number of dimensions to use when projecting the
            `encoder_hidden_states`.
        pooled_projection_dim (`int`, defaults to 2048): Number of dimensions of the input `pooled_projections`.
        out_channels (`int`, defaults to 16): Number of output channels.
        pos_embed_max_size (`int`, defaults to 96): Maximum size, in patches, of the positional embedding grid.

    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: int = 2,
        in_channels: int = 16,
        num_layers: int = 18,
        attention_head_dim: int = 64,
        num_attention_heads: int = 18,
        joint_attention_dim: int = 4096,
        caption_projection_dim: int = 1152,
        pooled_projection_dim: int = 2048,
        out_channels: int = 16,
        pos_embed_max_size: int = 96,
    ):
        super().__init__()
        default_out_channels = in_channels
        self.out_channels = out_channels if out_channels is not None else default_out_channels
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

        self.pos_embed = PatchEmbed(
            height=self.config.sample_size,
            width=self.config.sample_size,
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            embed_dim=self.inner_dim,
            pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
        )
        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
        )
        self.context_embedder = nn.Dense(self.config.joint_attention_dim, self.config.caption_projection_dim)

        # `attention_head_dim` is doubled to account for the mixing.
        # It needs to be crafted when we get the actual checkpoints.
        self.transformer_blocks = nn.CellList(
            [
                JointTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                    context_pre_only=i == num_layers - 1,
                )
                for i in range(self.config.num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Dense(self.inner_dim, patch_size * patch_size * self.out_channels, has_bias=True)

        self._gradient_checkpointing = False

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: nn.Cell, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.name_cells().items():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.name_cells().items():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: nn.Cell, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.name_cells().items():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.name_cells().items():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for _, module in self.cells_and_names():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedJointAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    @property
    def gradient_checkpointing(self):
        return self._gradient_checkpointing

    @gradient_checkpointing.setter
    def gradient_checkpointing(self, value):
        self._gradient_checkpointing = value
        for block in self.transformer_blocks:
            block._recompute(value)

    def construct(
        self,
        hidden_states: ms.Tensor,
        encoder_hidden_states: ms.Tensor = None,
        pooled_projections: ms.Tensor = None,
        timestep: ms.Tensor = None,
        block_controlnet_hidden_states: List = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = False,
    ) -> Union[ms.Tensor, Transformer2DModelOutput]:
        """
        The [`SD3Transformer2DModel`] forward method.

        Args:
            hidden_states (`ms.Tensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
            encoder_hidden_states (`ms.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`ms.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            timestep (`ms.Tensor`):
                Used to indicate denoising step.
            block_controlnet_hidden_states (`list` of `ms.Tensor`, *optional*):
                A list of tensors that, if specified, are added to the residuals of transformer blocks.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        if joint_attention_kwargs is not None and "scale" in joint_attention_kwargs:
            # weight the lora layers by setting `lora_scale` for each PEFT layer here
            # and remove `lora_scale` from each PEFT layer at the end.
            # scale_lora_layers & unscale_lora_layers may contain operations that are forbidden in graph mode
            raise RuntimeError(
                f"You are trying to set scaling of lora layer by passing {joint_attention_kwargs['scale']=}. "
                f"However it's not allowed in on-the-fly model forwarding. "
                f"Please manually call `scale_lora_layers(model, lora_scale)` before model forwarding and "
                f"`unscale_lora_layers(model, lora_scale)` after model forwarding. "
                f"For example, it can be done in a pipeline call like `StableDiffusionPipeline.__call__`."
            )

        height, width = hidden_states.shape[-2:]

        hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
        temb = self.time_text_embed(timestep, pooled_projections)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        for index_block, block in enumerate(self.transformer_blocks):
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
            )

            # controlnet residual
            if block_controlnet_hidden_states is not None and block.context_pre_only is False:
                interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
                hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]

        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

        # unpatchify
        patch_size = self.config["patch_size"]
        height = height // patch_size
        width = width // patch_size

        hidden_states = hidden_states.reshape(
            hidden_states.shape[0],
            height,
            width,
            patch_size,
            patch_size,
            self.out_channels,
        )
        # equivalent to torch.einsum("nhwpqc->nchpwq", hidden_states)
        hidden_states = hidden_states.transpose(0, 5, 1, 3, 2, 4)
        output = hidden_states.reshape(
            hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size
        )

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
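`construct` refuses a `scale` entry in `joint_attention_kwargs` and asks callers to scale the LoRA layers around the forward call instead. A minimal sketch of that pattern as a context manager follows. `scale_fn` and `unscale_fn` are parameters standing in for the `scale_lora_layers` / `unscale_lora_layers` helpers named in the error message (their import path is not shown on this page), and `lora_scaled` plus the recording functions are hypothetical.

```python
from contextlib import contextmanager


@contextmanager
def lora_scaled(model, lora_scale, scale_fn, unscale_fn):
    """Scale LoRA layers for the duration of a forward call (hypothetical helper).

    scale_fn / unscale_fn are assumed to take (model, lora_scale),
    matching the calls suggested by the RuntimeError raised in construct.
    """
    scale_fn(model, lora_scale)
    try:
        yield model
    finally:
        # Always undo the scaling, even if the forward pass raises.
        unscale_fn(model, lora_scale)


# Illustrative stand-ins that only record the calls.
calls = []


def record_scale(model, scale):
    calls.append(("scale", scale))


def record_unscale(model, scale):
    calls.append(("unscale", scale))


with lora_scaled("transformer", 0.8, record_scale, record_unscale) as m:
    calls.append(("forward", m))  # the model call would go here
print(calls)  # [('scale', 0.8), ('forward', 'transformer'), ('unscale', 0.8)]
```

Using `try`/`finally` keeps the LoRA weights from staying scaled if the forward pass fails partway through.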

mindone.diffusers.SD3Transformer2DModel.attn_processors: Dict[str, AttentionProcessor] property

RETURNS DESCRIPTION
Dict[str, AttentionProcessor]

A dictionary containing all attention processors used in the model, indexed by their weight names.
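The keys of this dictionary are dotted module paths with a `.processor` suffix, built by the recursive walk in the source above. The sketch below reproduces that walk on a hypothetical `ToyCell` stand-in for `mindspore.nn.Cell`; the child names (`transformer_blocks`, `0`, `attn`) are assumptions chosen to resemble the real module tree.

```python
class ToyCell:
    """Hypothetical stand-in for mindspore.nn.Cell with named children."""

    def __init__(self, children=None, processor=None):
        self._children = dict(children or {})
        if processor is not None:
            # Only attention-like cells expose get_processor, mirroring the hasattr() check.
            self.get_processor = lambda: processor

    def name_cells(self):
        return self._children


def collect_processors(model):
    # Same recursion as attn_processors: key = dotted path + ".processor".
    processors = {}

    def visit(name, module):
        if hasattr(module, "get_processor"):
            processors[f"{name}.processor"] = module.get_processor()
        for sub_name, child in module.name_cells().items():
            visit(f"{name}.{sub_name}", child)

    for name, module in model.name_cells().items():
        visit(name, module)
    return processors


model = ToyCell({
    "transformer_blocks": ToyCell({
        "0": ToyCell({"attn": ToyCell(processor="JointAttnProcessor")}),
        "1": ToyCell({"attn": ToyCell(processor="JointAttnProcessor")}),
    }),
    "proj_out": ToyCell(),
})
print(sorted(collect_processors(model)))
# ['transformer_blocks.0.attn.processor', 'transformer_blocks.1.attn.processor']
```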

mindone.diffusers.SD3Transformer2DModel.construct(hidden_states, encoder_hidden_states=None, pooled_projections=None, timestep=None, block_controlnet_hidden_states=None, joint_attention_kwargs=None, return_dict=False)

The [SD3Transformer2DModel] forward method.

PARAMETER DESCRIPTION
hidden_states

Input hidden_states.

TYPE: `ms.Tensor` of shape `(batch size, channel, height, width)`

encoder_hidden_states

Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.

TYPE: `ms.Tensor` of shape `(batch size, sequence_len, embed_dims)` DEFAULT: None

pooled_projections

Embeddings projected from the embeddings of input conditions.

TYPE: `ms.Tensor` of shape `(batch_size, projection_dim)` DEFAULT: None

timestep

Used to indicate denoising step.

TYPE: `ms.Tensor` DEFAULT: None

block_controlnet_hidden_states

A list of tensors that, if specified, are added to the residuals of the transformer blocks. The last block, which is built with `context_pre_only=True`, receives no residual.

TYPE: `list` of `ms.Tensor` DEFAULT: None

joint_attention_kwargs

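A kwargs dictionary that, if specified, is passed along to the AttentionProcessor as defined under self.processor in diffusers.models.attention_processor. In the source shown here, the argument is only checked for a `scale` key, which raises a `RuntimeError`; it is not forwarded to the transformer blocks.

TYPE: `dict` DEFAULT: None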

return_dict

Whether or not to return a [~models.transformer_2d.Transformer2DModelOutput] instead of a plain tuple.

TYPE: `bool` DEFAULT: False

RETURNS DESCRIPTION
Union[Tensor, Transformer2DModelOutput]

If return_dict is True, a [~models.transformer_2d.Transformer2DModelOutput] is returned; otherwise a tuple whose first element is the sample tensor.
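The residual indexing in `construct` spreads `len(block_controlnet_hidden_states)` residuals evenly across the blocks using integer division. The pure-Python sketch below reproduces only that index arithmetic (the `controlnet_schedule` helper is hypothetical), which makes it easy to check a ControlNet's output count against `num_layers` before running the model.

```python
def controlnet_schedule(num_blocks: int, num_controls: int) -> list:
    """Return (block_index, control_index) pairs using the arithmetic from construct (hypothetical helper)."""
    # If num_controls > num_blocks this is 0 and the floor division below
    # raises ZeroDivisionError, as the same arithmetic in construct would.
    interval_control = num_blocks // num_controls
    pairs = []
    for index_block in range(num_blocks):
        # Only the last block is built with context_pre_only=True; it gets no residual.
        if index_block == num_blocks - 1:
            continue
        pairs.append((index_block, index_block // interval_control))
    return pairs


print(controlnet_schedule(18, 6)[:4])  # [(0, 0), (1, 0), (2, 0), (3, 1)]
# With 4 residuals the interval is 18 // 4 = 4, so block 16 asks for index 4,
# which is out of range for a list of length 4.
print(max(c for _, c in controlnet_schedule(18, 4)))  # 4
```

With the default 18 layers, residual counts that divide 18 (such as 1, 2, 3, 6, 9 or 18) keep every index in range.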

Source code in mindone/diffusers/models/transformers/transformer_sd3.py
def construct(
    self,
    hidden_states: ms.Tensor,
    encoder_hidden_states: ms.Tensor = None,
    pooled_projections: ms.Tensor = None,
    timestep: ms.Tensor = None,
    block_controlnet_hidden_states: List = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = False,
) -> Union[ms.Tensor, Transformer2DModelOutput]:
    """
    The [`SD3Transformer2DModel`] forward method.

    Args:
        hidden_states (`ms.Tensor` of shape `(batch size, channel, height, width)`):
            Input `hidden_states`.
        encoder_hidden_states (`ms.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
            Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
        pooled_projections (`ms.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
            from the embeddings of input conditions.
        timestep (`ms.Tensor`):
            Used to indicate denoising step.
        block_controlnet_hidden_states (`list` of `ms.Tensor`, *optional*):
            A list of tensors that, if specified, are added to the residuals of transformer blocks.
        joint_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        return_dict (`bool`, *optional*, defaults to `False`):
            Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
            tuple.

    Returns:
        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
        `tuple` where the first element is the sample tensor.
    """
    if joint_attention_kwargs is not None and "scale" in joint_attention_kwargs:
        # weight the lora layers by setting `lora_scale` for each PEFT layer here
        # and remove `lora_scale` from each PEFT layer at the end.
        # scale_lora_layers & unscale_lora_layers may contain operations that are forbidden in graph mode
        raise RuntimeError(
            f"You are trying to set scaling of lora layer by passing {joint_attention_kwargs['scale']=}. "
            f"However it's not allowed in on-the-fly model forwarding. "
            f"Please manually call `scale_lora_layers(model, lora_scale)` before model forwarding and "
            f"`unscale_lora_layers(model, lora_scale)` after model forwarding. "
            f"For example, it can be done in a pipeline call like `StableDiffusionPipeline.__call__`."
        )

    height, width = hidden_states.shape[-2:]

    hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
    temb = self.time_text_embed(timestep, pooled_projections)
    encoder_hidden_states = self.context_embedder(encoder_hidden_states)

    for index_block, block in enumerate(self.transformer_blocks):
        encoder_hidden_states, hidden_states = block(
            hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
        )

        # controlnet residual
        if block_controlnet_hidden_states is not None and block.context_pre_only is False:
            interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
            hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]

    hidden_states = self.norm_out(hidden_states, temb)
    hidden_states = self.proj_out(hidden_states)

    # unpatchify
    patch_size = self.config["patch_size"]
    height = height // patch_size
    width = width // patch_size

    hidden_states = hidden_states.reshape(
        hidden_states.shape[0],
        height,
        width,
        patch_size,
        patch_size,
        self.out_channels,
    )
    # equivalent to torch.einsum("nhwpqc->nchpwq", hidden_states)
    hidden_states = hidden_states.transpose(0, 5, 1, 3, 2, 4)
    output = hidden_states.reshape(
        hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size
    )

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)
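The unpatchify step at the end of `construct` is a reshape, an axis permutation and a second reshape. The NumPy sketch below mirrors those three steps (NumPy's `transpose` takes the same axis permutation as the MindSpore call above) and checks them against a hypothetical `patchify` helper written as their exact inverse. `patchify` is for illustration only; the model itself tokenizes its input with `PatchEmbed`.

```python
import numpy as np


def patchify(x: np.ndarray, p: int) -> np.ndarray:
    """Hypothetical inverse helper: (N, C, H, W) -> (N, H//p * W//p, p*p*C)."""
    n, c, hgt, wid = x.shape
    h, w = hgt // p, wid // p
    x = x.reshape(n, c, h, p, w, p)
    x = x.transpose(0, 2, 4, 3, 5, 1)  # nchpwq -> nhwpqc, inverse of the permutation in construct
    return x.reshape(n, h * w, p * p * c)


def unpatchify(tokens: np.ndarray, height: int, width: int, p: int, out_channels: int) -> np.ndarray:
    """Mirror of the final reshape/transpose steps in construct, using NumPy instead of MindSpore."""
    n = tokens.shape[0]
    h, w = height // p, width // p
    x = tokens.reshape(n, h, w, p, p, out_channels)
    x = x.transpose(0, 5, 1, 3, 2, 4)  # nhwpqc -> nchpwq
    return x.reshape(n, out_channels, h * p, w * p)


x = np.arange(2 * 16 * 8 * 6, dtype=np.float32).reshape(2, 16, 8, 6)
tokens = patchify(x, p=2)
print(tokens.shape)  # (2, 12, 64)
restored = unpatchify(tokens, height=8, width=6, p=2, out_channels=16)
print(np.array_equal(restored, x))  # True
```

The round trip is exact because both functions only move elements around, so a mismatch in either permutation would show up immediately as `False`.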

mindone.diffusers.SD3Transformer2DModel.fuse_qkv_projections()

Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) are fused. For cross-attention modules, key and value projection matrices are fused.

This API is 🧪 experimental.

Source code in mindone/diffusers/models/transformers/transformer_sd3.py
def fuse_qkv_projections(self):
    """
    Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
    are fused. For cross-attention modules, key and value projection matrices are fused.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>
    """
    self.original_attn_processors = None

    for _, attn_processor in self.attn_processors.items():
        if "Added" in str(attn_processor.__class__.__name__):
            raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

    self.original_attn_processors = self.attn_processors

    for _, module in self.cells_and_names():
        if isinstance(module, Attention):
            module.fuse_projections(fuse=True)

    self.set_attn_processor(FusedJointAttnProcessor2_0())
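Fusing QKV projections relies on the fact that three linear maps applied to the same input can be computed as one larger matrix multiply and split afterwards. The NumPy sketch below checks that identity, assuming weights stored in an `(out_features, in_features)` layout as in `nn.Dense`; it illustrates the idea only and does not reproduce `Attention.fuse_projections` or `FusedJointAttnProcessor2_0`.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, n_tokens = 16, 16, 5
x = rng.standard_normal((n_tokens, d_in))
# Separate projection weights, shape (out_features, in_features), and biases.
w_q, w_k, w_v = (rng.standard_normal((d_out, d_in)) for _ in range(3))
b_q, b_k, b_v = (rng.standard_normal(d_out) for _ in range(3))

# Unfused: three matmuls.
q, k, v = x @ w_q.T + b_q, x @ w_k.T + b_k, x @ w_v.T + b_v

# Fused: stack the weights along the output axis, do one matmul, then split.
w_qkv = np.concatenate([w_q, w_k, w_v], axis=0)  # (3 * d_out, d_in)
b_qkv = np.concatenate([b_q, b_k, b_v])
q_f, k_f, v_f = np.split(x @ w_qkv.T + b_qkv, 3, axis=-1)

print(all(np.allclose(a, b) for a, b in [(q, q_f), (k, k_f), (v, v_f)]))  # True
```

The fused form trades three small kernels for one larger one, which is the usual motivation for this experimental API.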

mindone.diffusers.SD3Transformer2DModel.set_attn_processor(processor)

Sets the attention processor to use to compute attention.

PARAMETER DESCRIPTION
processor

The instantiated processor class or a dictionary of processor classes that will be set as the processor for all Attention layers.

If processor is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.

TYPE: `dict` of `AttentionProcessor` or only `AttentionProcessor`

Source code in mindone/diffusers/models/transformers/transformer_sd3.py
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors.

    """
    count = len(self.attn_processors.keys())

    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def fn_recursive_attn_processor(name: str, module: nn.Cell, processor):
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                module.set_processor(processor.pop(f"{name}.processor"))

        for sub_name, child in module.name_cells().items():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.name_cells().items():
        fn_recursive_attn_processor(name, module, processor)
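`set_attn_processor` accepts either one processor instance shared by every attention layer or a dict keyed by the paths returned by `attn_processors`. In dict mode the entries are consumed with `pop()`, so the caller's dict ends up empty. The sketch below reproduces that logic on hypothetical toy classes (`ToyAttention`, `ToyContainer`) rather than real MindSpore cells.

```python
class ToyAttention:
    """Hypothetical leaf with get_processor/set_processor, like an Attention cell."""

    def __init__(self, processor):
        self.processor = processor

    def get_processor(self):
        return self.processor

    def set_processor(self, processor):
        self.processor = processor

    def name_cells(self):
        return {}


class ToyContainer:
    """Hypothetical container cell exposing named children."""

    def __init__(self, children):
        self.children = children

    def name_cells(self):
        return self.children


def set_processors(model, processor):
    # Count attention layers the same way attn_processors does.
    paths = []

    def collect(name, module):
        if hasattr(module, "get_processor"):
            paths.append(f"{name}.processor")
        for sub_name, child in module.name_cells().items():
            collect(f"{name}.{sub_name}", child)

    for name, module in model.name_cells().items():
        collect(name, module)

    if isinstance(processor, dict) and len(processor) != len(paths):
        raise ValueError(f"Expected {len(paths)} processors, got {len(processor)}.")

    def assign(name, module):
        if hasattr(module, "set_processor"):
            if isinstance(processor, dict):
                # pop() removes the entry from the caller's dict.
                module.set_processor(processor.pop(f"{name}.processor"))
            else:
                # A single instance is shared by every attention layer.
                module.set_processor(processor)
        for sub_name, child in module.name_cells().items():
            assign(f"{name}.{sub_name}", child)

    for name, module in model.name_cells().items():
        assign(name, module)


attn0, attn1 = ToyAttention("old"), ToyAttention("old")
model = ToyContainer({
    "transformer_blocks": ToyContainer({"0": ToyContainer({"attn": attn0}), "1": ToyContainer({"attn": attn1})}),
})
mapping = {"transformer_blocks.0.attn.processor": "A", "transformer_blocks.1.attn.processor": "B"}
set_processors(model, mapping)
print(attn0.processor, attn1.processor, mapping)  # A B {}
```

Pass a copy (for example `dict(mapping)`) if the mapping is needed again after the call.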

mindone.diffusers.SD3Transformer2DModel.unfuse_qkv_projections()

Disables the fused QKV projection if enabled.

This API is 🧪 experimental.

Source code in mindone/diffusers/models/transformers/transformer_sd3.py
def unfuse_qkv_projections(self):
    """Disables the fused QKV projection if enabled.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>

    """
    if self.original_attn_processors is not None:
        self.set_attn_processor(self.original_attn_processors)
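`unfuse_qkv_projections` only restores the attention processors saved by `fuse_qkv_projections`: as the source shows, it does not call `fuse_projections(fuse=False)`, and `original_attn_processors` is first assigned inside `fuse_qkv_projections` rather than in `__init__`. The hypothetical, pure-Python `ToyFusable` class below sketches the same save-and-restore bookkeeping with two defensive choices that are assumptions of this sketch: the snapshot attribute is initialized up front, and the snapshot is copied before it is handed to a setter that pops entries, so restoring twice still works.

```python
class ToyFusable:
    """Hypothetical model that mimics the fuse/unfuse bookkeeping; no real attention layers."""

    def __init__(self):
        self.processors = {
            "transformer_blocks.0.attn.processor": "JointAttnProcessor",
            "transformer_blocks.1.attn.processor": "JointAttnProcessor",
        }
        # Defensive choice: define the snapshot up front so unfuse can be called first.
        self.original_attn_processors = None

    def _set_processors(self, processor):
        # Like set_attn_processor: a dict is consumed entry by entry with pop(),
        # anything else is shared by every layer.
        for key in self.processors:
            self.processors[key] = processor.pop(key) if isinstance(processor, dict) else processor

    def fuse_qkv_projections(self):
        self.original_attn_processors = dict(self.processors)  # snapshot before swapping
        self._set_processors("FusedJointAttnProcessor")

    def unfuse_qkv_projections(self):
        if self.original_attn_processors is not None:
            # Pass a copy so the snapshot survives the pop() calls and unfuse stays repeatable.
            self._set_processors(dict(self.original_attn_processors))


model = ToyFusable()
model.unfuse_qkv_projections()  # no-op, because the snapshot exists but is None
model.fuse_qkv_projections()
model.unfuse_qkv_projections()
model.unfuse_qkv_projections()  # still works because the snapshot was not emptied
print(set(model.processors.values()))  # {'JointAttnProcessor'}
```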