Skip to content

FluxTransformer2DModel

A Transformer model for image-like data from Flux.

mindone.diffusers.models.transformers.transformer_flux.FluxTransformer2DModel

Bases: ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin

The Transformer model introduced in Flux.

Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

PARAMETER DESCRIPTION
patch_size

Patch size to turn the input data into small patches.

TYPE: `int` DEFAULT: 1

in_channels

The number of channels in the input.

TYPE: `int`, *optional*, defaults to 64 DEFAULT: 64

num_layers

The number of layers of MMDiT blocks to use.

TYPE: `int`, *optional*, defaults to 19 DEFAULT: 19

num_single_layers

The number of layers of single DiT blocks to use.

TYPE: `int`, *optional*, defaults to 38 DEFAULT: 38

attention_head_dim

The number of channels in each head.

TYPE: `int`, *optional*, defaults to 128 DEFAULT: 128

num_attention_heads

The number of heads to use for multi-head attention.

TYPE: `int`, *optional*, defaults to 24 DEFAULT: 24

joint_attention_dim

The number of encoder_hidden_states dimensions to use.

TYPE: `int`, *optional* DEFAULT: 4096

pooled_projection_dim

Number of dimensions to use when projecting the pooled_projections.

TYPE: `int` DEFAULT: 768

guidance_embeds

Whether to use guidance embeddings.

TYPE: `bool`, defaults to False DEFAULT: False

Source code in mindone/diffusers/models/transformers/transformer_flux.py
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    """
    The Transformer model introduced in Flux.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    The model works as follows:
    - It embeds the image tokens (`x_embedder`) and the text tokens (`context_embedder`).
    - It runs `num_layers` dual-stream `FluxTransformerBlock`s over both streams.
    - It concatenates the two streams ([text, image]) and runs `num_single_layers` `FluxSingleTransformerBlock`s.
    - It drops the text tokens and projects the image tokens to `patch_size * patch_size * in_channels` features.

    Parameters:
        patch_size (`int`, *optional*, defaults to 1): Patch size to turn the input data into small patches.
        in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
        num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
        num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
        attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
        joint_attention_dim (`int`, *optional*, defaults to 4096): The number of `encoder_hidden_states`
            dimensions to use.
        pooled_projection_dim (`int`, *optional*, defaults to 768): Number of dimensions to use when projecting
            the `pooled_projections`.
        guidance_embeds (`bool`, *optional*, defaults to False): Whether to use guidance embeddings. When True,
            `CombinedTimestepGuidanceTextProjEmbeddings` is used instead of `CombinedTimestepTextProjEmbeddings`.
        axes_dims_rope (`List[int]`, *optional*, defaults to `[16, 56, 56]`): Per-axis dimensions passed as
            `axes_dim` to the `EmbedND` rotary position embedding.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        # NOTE: the list default is shared across calls; it is only read in this method (handed to EmbedND),
        # and is kept as a list so the registered config value keeps its original type.
        axes_dims_rope: List[int] = [16, 56, 56],
    ):
        super().__init__()
        # The model predicts in the same channel space as its input.
        self.out_channels = in_channels
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

        self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
        # The guidance variant of the embedder takes an extra `guidance` input (see `construct`).
        text_time_guidance_cls = (
            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
        )
        self.time_text_embed = text_time_guidance_cls(
            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
        )

        # Project text features and image tokens into the shared transformer width.
        self.context_embedder = nn.Dense(self.config.joint_attention_dim, self.inner_dim)
        self.x_embedder = nn.Dense(self.config.in_channels, self.inner_dim)

        # Dual-stream blocks: text and image tokens are kept as separate streams.
        self.transformer_blocks = nn.CellList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for _ in range(self.config.num_layers)
            ]
        )

        # Single-stream blocks: operate on the concatenated [text, image] sequence.
        self.single_transformer_blocks = nn.CellList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for _ in range(self.config.num_single_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Dense(self.inner_dim, patch_size * patch_size * self.out_channels, has_bias=True)

        self._gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        """Set `gradient_checkpointing` on `module` if it exposes that attribute."""
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    @property
    def gradient_checkpointing(self):
        """Whether gradient checkpointing (block recomputation) has been requested."""
        return self._gradient_checkpointing

    @gradient_checkpointing.setter
    def gradient_checkpointing(self, value):
        # Only act on an actual change, so repeated assignments do not re-call `recompute()`.
        if self._gradient_checkpointing != value:
            self._gradient_checkpointing = value
            # NOTE(review): `recompute()` is called without arguments on both the True and False transitions.
            # Setting this back to False therefore presumably does NOT undo recomputation on the blocks.
            # Confirm against the MindSpore `Cell.recompute` API.
            for block in self.transformer_blocks:
                block.recompute()
            for block in self.single_transformer_blocks:
                block.recompute()

    def construct(
        self,
        hidden_states: ms.Tensor,
        encoder_hidden_states: ms.Tensor = None,
        pooled_projections: ms.Tensor = None,
        timestep: ms.Tensor = None,
        img_ids: ms.Tensor = None,
        txt_ids: ms.Tensor = None,
        guidance: ms.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = False,
    ) -> Union[ms.Tensor, Transformer2DModelOutput]:
        """
        The [`FluxTransformer2DModel`] forward method.

        Args:
            hidden_states (`ms.Tensor` of shape `(batch size, image_sequence_length, in_channels)`):
                Input image tokens. They are fed to a `Dense` layer over `in_channels`, and later concatenated with
                the text tokens along axis 1, so this is a token sequence rather than a 4-D image.
            encoder_hidden_states (`ms.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`ms.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            timestep ( `ms.Tensor`):
                Used to indicate denoising step. It is cast to the activation dtype and multiplied by 1000
                before embedding.
            img_ids (`ms.Tensor`):
                Position ids for the image tokens. They are concatenated after `txt_ids` along axis 1 and passed
                to `pos_embed` to build the rotary embedding.
            txt_ids (`ms.Tensor`):
                Position ids for the text tokens (see `img_ids`).
            guidance (`ms.Tensor`, *optional*):
                Guidance value. It is cast and scaled by 1000 like `timestep`, and passed to `time_text_embed`.
                Presumably it must be given only when the model was built with `guidance_embeds=True`.
            joint_attention_kwargs (`dict`, *optional*):
                Only checked for a `scale` key, which is rejected. It is not forwarded to the transformer blocks.
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.

        Raises:
            RuntimeError: If `joint_attention_kwargs` contains a non-None `scale`.
        """
        if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
            raise RuntimeError(
                "Passing `scale` to `FluxTransformer2DModel` via `joint_attention_kwargs` is not supported "
                "for limitation of static graph syntax. Do it in LoRA inference/training scripts instead."
            )
        hidden_states = self.x_embedder(hidden_states)

        # Timestep (and optional guidance) are cast to the activation dtype and scaled by 1000 before embedding.
        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000
        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        # Text ids come first, matching the [text, image] token order used by the single-stream blocks.
        ids = ops.cat((txt_ids, img_ids), axis=1)
        image_rotary_emb = self.pos_embed(ids)

        for block in self.transformer_blocks:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
            )

        hidden_states = ops.cat([encoder_hidden_states, hidden_states], axis=1)

        for block in self.single_transformer_blocks:
            hidden_states = block(
                hidden_states=hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
            )

        # Drop the leading text tokens; only the image tokens are projected to the output.
        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)

mindone.diffusers.models.transformers.transformer_flux.FluxTransformer2DModel.construct(hidden_states, encoder_hidden_states=None, pooled_projections=None, timestep=None, img_ids=None, txt_ids=None, guidance=None, joint_attention_kwargs=None, return_dict=False)

The [FluxTransformer2DModel] forward method.

PARAMETER DESCRIPTION
hidden_states

Input hidden_states.

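TYPE: `ms.Tensor` of shape `(batch size, image_sequence_length, in_channels)` (a token sequence, not a 4-D image: it is fed directly to a `Dense` layer over `in_channels` and later concatenated with the text tokens along axis 1)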

encoder_hidden_states

Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.

TYPE: `ms.Tensor` of shape `(batch size, sequence_len, embed_dims)` DEFAULT: None

pooled_projections

Embeddings projected from the embeddings of input conditions.

TYPE: `ms.Tensor` of shape `(batch_size, projection_dim)` DEFAULT: None

timestep

Used to indicate denoising step.

TYPE: `ms.Tensor` DEFAULT: None

block_controlnet_hidden_states

(list of ms.Tensor): A list of tensors that if specified are added to the residuals of transformer blocks. Note: this argument is not part of the `construct` signature and is not accepted by this method.

joint_attention_kwargs

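A kwargs dictionary that, in upstream diffusers, is passed along to the AttentionProcessor as defined under self.processor in diffusers.models.attention_processor. In this implementation it is only checked for a `scale` key (which raises a `RuntimeError`) and is not forwarded to the transformer blocks.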

TYPE: `dict`, *optional* DEFAULT: None

return_dict

Whether or not to return a [~models.transformer_2d.Transformer2DModelOutput] instead of a plain tuple.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

RETURNS DESCRIPTION
Union[Tensor, Transformer2DModelOutput]

If return_dict is True, an [~models.transformer_2d.Transformer2DModelOutput] is returned, otherwise a tuple where the first element is the sample tensor.

Source code in mindone/diffusers/models/transformers/transformer_flux.py
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
def construct(
    self,
    hidden_states: ms.Tensor,
    encoder_hidden_states: ms.Tensor = None,
    pooled_projections: ms.Tensor = None,
    timestep: ms.Tensor = None,
    img_ids: ms.Tensor = None,
    txt_ids: ms.Tensor = None,
    guidance: ms.Tensor = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = False,
) -> Union[ms.Tensor, Transformer2DModelOutput]:
    """
    The [`FluxTransformer2DModel`] forward method.

    Args:
        hidden_states (`ms.Tensor` of shape `(batch size, image_sequence_length, in_channels)`):
            Input image tokens. They are fed to a `Dense` layer over `in_channels`, and later concatenated with
            the text tokens along axis 1, so this is a token sequence rather than a 4-D image.
        encoder_hidden_states (`ms.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
            Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
        pooled_projections (`ms.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
            from the embeddings of input conditions.
        timestep ( `ms.Tensor`):
            Used to indicate denoising step. It is cast to the activation dtype and multiplied by 1000
            before embedding.
        img_ids (`ms.Tensor`):
            Position ids for the image tokens. They are concatenated after `txt_ids` along axis 1 and passed
            to `pos_embed` to build the rotary embedding.
        txt_ids (`ms.Tensor`):
            Position ids for the text tokens (see `img_ids`).
        guidance (`ms.Tensor`, *optional*):
            Guidance value. It is cast and scaled by 1000 like `timestep`, and passed to `time_text_embed`.
            Presumably it must be given only when the model was built with `guidance_embeds=True`.
        joint_attention_kwargs (`dict`, *optional*):
            Only checked for a `scale` key, which is rejected. It is not forwarded to the transformer blocks.
        return_dict (`bool`, *optional*, defaults to `False`):
            Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
            tuple.

    Returns:
        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
        `tuple` where the first element is the sample tensor.

    Raises:
        RuntimeError: If `joint_attention_kwargs` contains a non-None `scale`.
    """
    if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
        raise RuntimeError(
            "Passing `scale` to `FluxTransformer2DModel` via `joint_attention_kwargs` is not supported "
            "for limitation of static graph syntax. Do it in LoRA inference/training scripts instead."
        )
    hidden_states = self.x_embedder(hidden_states)

    # Timestep (and optional guidance) are cast to the activation dtype and scaled by 1000 before embedding.
    timestep = timestep.to(hidden_states.dtype) * 1000
    if guidance is not None:
        guidance = guidance.to(hidden_states.dtype) * 1000
    temb = (
        self.time_text_embed(timestep, pooled_projections)
        if guidance is None
        else self.time_text_embed(timestep, guidance, pooled_projections)
    )
    encoder_hidden_states = self.context_embedder(encoder_hidden_states)

    # Text ids come first, matching the [text, image] token order used by the single-stream blocks.
    ids = ops.cat((txt_ids, img_ids), axis=1)
    image_rotary_emb = self.pos_embed(ids)

    for block in self.transformer_blocks:
        encoder_hidden_states, hidden_states = block(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            temb=temb,
            image_rotary_emb=image_rotary_emb,
        )

    hidden_states = ops.cat([encoder_hidden_states, hidden_states], axis=1)

    for block in self.single_transformer_blocks:
        hidden_states = block(
            hidden_states=hidden_states,
            temb=temb,
            image_rotary_emb=image_rotary_emb,
        )

    # Drop the leading text tokens; only the image tokens are projected to the output.
    hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

    hidden_states = self.norm_out(hidden_states, temb)
    output = self.proj_out(hidden_states)

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)