Flux2Transformer2DModel

A Transformer model for image-like data from Flux2.

mindone.diffusers.Flux2Transformer2DModel

Bases: ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin, AttentionMixin

The Transformer model introduced in Flux 2.

Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

PARAMETER DESCRIPTION
patch_size

Patch size to turn the input data into small patches.

TYPE: `int`, defaults to `1` DEFAULT: 1

in_channels

The number of channels in the input.

TYPE: `int`, defaults to `128` DEFAULT: 128

out_channels

The number of channels in the output. If not specified, it defaults to in_channels.

TYPE: `int`, *optional*, defaults to `None` DEFAULT: None

num_layers

The number of layers of dual stream DiT blocks to use.

TYPE: `int`, defaults to `8` DEFAULT: 8

num_single_layers

The number of layers of single stream DiT blocks to use.

TYPE: `int`, defaults to `48` DEFAULT: 48

attention_head_dim

The number of dimensions to use for each attention head.

TYPE: `int`, defaults to `128` DEFAULT: 128

num_attention_heads

The number of attention heads to use.

TYPE: `int`, defaults to `48` DEFAULT: 48

joint_attention_dim

The number of dimensions to use for the joint attention (embedding/channel dimension of encoder_hidden_states).

TYPE: `int`, defaults to `15360` DEFAULT: 15360

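timestep_guidance_channels

The number of input channels of the combined timestep and guidance embedding.

TYPE: `int`, defaults to `256` DEFAULT: 256

mlp_ratio

The MLP ratio passed to each dual stream and single stream DiT block.

TYPE: `float`, defaults to `3.0` DEFAULT: 3.0

axes_dims_rope

The dimensions to use for the rotary positional embeddings.

TYPE: `Tuple[int]`, defaults to `(32, 32, 32, 32)` DEFAULT: (32, 32, 32, 32)

rope_theta

The theta value passed to the rotary positional embedding.

TYPE: `int`, defaults to `2000` DEFAULT: 2000

eps

The epsilon passed to the transformer blocks and the output normalization layer.

TYPE: `float`, defaults to `1e-6` DEFAULT: 1e-6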
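
As a rough illustration of how the default values relate to one another, the sketch below reproduces the size arithmetic that `__init__` performs. The helper name `derived_dims` is hypothetical. The check that `sum(axes_dims_rope)` equals `attention_head_dim` is an assumption based on how rotary embeddings are usually applied per attention head; the docstring does not state it.

```python
from typing import Dict, Optional, Tuple


def derived_dims(
    patch_size: int = 1,
    in_channels: int = 128,
    out_channels: Optional[int] = None,
    attention_head_dim: int = 128,
    num_attention_heads: int = 48,
    axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
) -> Dict[str, int]:
    """Hypothetical helper mirroring the size arithmetic in __init__."""
    # Same fallback as `self.out_channels = out_channels or in_channels`.
    resolved_out = out_channels or in_channels
    # Same product as `self.inner_dim`.
    inner_dim = num_attention_heads * attention_head_dim
    # Assumption: per-head RoPE needs the axis dims to add up to the head dim.
    if sum(axes_dims_rope) != attention_head_dim:
        raise ValueError("sum(axes_dims_rope) should match attention_head_dim")
    return {
        "out_channels": resolved_out,
        "inner_dim": inner_dim,
        # Number of output features of `proj_out`.
        "proj_out_features": patch_size * patch_size * resolved_out,
    }


dims = derived_dims()
print(dims)
```

With the defaults, the hidden width shared by all blocks is 48 × 128 = 6144, and `proj_out` maps back to 128 channels per token.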

Source code in mindone/diffusers/models/transformers/transformer_flux2.py
class Flux2Transformer2DModel(
    ModelMixin,
    ConfigMixin,
    PeftAdapterMixin,
    FromOriginalModelMixin,
    FluxTransformer2DLoadersMixin,
    CacheMixin,
    AttentionMixin,
):
    """
    The Transformer model introduced in Flux 2.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    Args:
        patch_size (`int`, defaults to `1`):
            Patch size to turn the input data into small patches.
        in_channels (`int`, defaults to `128`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of channels in the output. If not specified, it defaults to `in_channels`.
        num_layers (`int`, defaults to `8`):
            The number of layers of dual stream DiT blocks to use.
        num_single_layers (`int`, defaults to `48`):
            The number of layers of single stream DiT blocks to use.
        attention_head_dim (`int`, defaults to `128`):
            The number of dimensions to use for each attention head.
        num_attention_heads (`int`, defaults to `48`):
            The number of attention heads to use.
        joint_attention_dim (`int`, defaults to `15360`):
            The number of dimensions to use for the joint attention (embedding/channel dimension of
            `encoder_hidden_states`).
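        timestep_guidance_channels (`int`, defaults to `256`):
            The number of input channels of the combined timestep and guidance embedding.
        mlp_ratio (`float`, defaults to `3.0`):
            The MLP ratio passed to each dual stream and single stream DiT block.
        axes_dims_rope (`Tuple[int]`, defaults to `(32, 32, 32, 32)`):
            The dimensions to use for the rotary positional embeddings.
        rope_theta (`int`, defaults to `2000`):
            The theta value passed to the rotary positional embedding.
        eps (`float`, defaults to `1e-6`):
            The epsilon passed to the transformer blocks and the output normalization layer.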
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
    _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
    _cp_plan = {
        "": {
            "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
            "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
            "img_ids": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
            "txt_ids": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
        },
        "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
    }

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 128,
        out_channels: Optional[int] = None,
        num_layers: int = 8,
        num_single_layers: int = 48,
        attention_head_dim: int = 128,
        num_attention_heads: int = 48,
        joint_attention_dim: int = 15360,
        timestep_guidance_channels: int = 256,
        mlp_ratio: float = 3.0,
        axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
        rope_theta: int = 2000,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        # 1. Sinusoidal positional embedding for RoPE on image and text tokens
        self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)

        # 2. Combined timestep + guidance embedding
        self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
            in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
        )

        # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
        # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks
        self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
        self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
        # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream
        self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)

        # 4. Input projections
        self.x_embedder = mint.nn.Linear(in_channels, self.inner_dim, bias=False)
        self.context_embedder = mint.nn.Linear(joint_attention_dim, self.inner_dim, bias=False)

        # 5. Double Stream Transformer Blocks
        self.transformer_blocks = nn.CellList(
            [
                Flux2TransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_ratio=mlp_ratio,
                    eps=eps,
                    bias=False,
                )
                for _ in range(num_layers)
            ]
        )

        # 6. Single Stream Transformer Blocks
        self.single_transformer_blocks = nn.CellList(
            [
                Flux2SingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_ratio=mlp_ratio,
                    eps=eps,
                    bias=False,
                )
                for _ in range(num_single_layers)
            ]
        )

        # 7. Output layers
        self.norm_out = AdaLayerNormContinuous(
            self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False
        )
        self.proj_out = mint.nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)

        self.gradient_checkpointing = False

    def construct(
        self,
        hidden_states: ms.Tensor,
        encoder_hidden_states: ms.Tensor = None,
        timestep: ms.Tensor = None,
        img_ids: ms.Tensor = None,
        txt_ids: ms.Tensor = None,
        guidance: ms.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[ms.Tensor, Transformer2DModelOutput]:
        """
        The [`Flux2Transformer2DModel`] forward method.

        Args:
            hidden_states (`ms.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input `hidden_states`.
            encoder_hidden_states (`ms.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            timestep ( `ms.Tensor`):
                Used to indicate denoising step.
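            img_ids (`ms.Tensor`):
                Position ids of the image tokens, used to compute the rotary embeddings. If 3-D, only the first
                batch entry is used.
            txt_ids (`ms.Tensor`):
                Position ids of the text tokens, used to compute the rotary embeddings. If 3-D, only the first
                batch entry is used.
            guidance (`ms.Tensor`):
                Guidance value, scaled by 1000 and embedded together with `timestep`.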
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # 0. Handle input arguments
        if joint_attention_kwargs is not None:
            # weight the lora layers by setting `lora_scale` for each PEFT layer here
            # and remove `lora_scale` from each PEFT layer at the end.
            # scale_lora_layers & unscale_lora_layers may contain operations that are forbidden in graph mode
            raise RuntimeError(
                f"You are trying to set scaling of lora layer by passing {joint_attention_kwargs['scale']=}. "
                f"However it's not allowed in on-the-fly model forwarding. "
                f"Please manually call `scale_lora_layers(model, lora_scale)` before model forwarding and "
                f"`unscale_lora_layers(model, lora_scale)` after model forwarding. "
                f"For example, it can be done in a pipeline call like `StableDiffusionPipeline.__call__`."
            )

        num_txt_tokens = encoder_hidden_states.shape[1]

        # 1. Calculate timestep embedding and modulation parameters
        timestep = timestep.to(hidden_states.dtype) * 1000
        guidance = guidance.to(hidden_states.dtype) * 1000

        temb = self.time_guidance_embed(timestep, guidance)

        double_stream_mod_img = self.double_stream_modulation_img(temb)
        double_stream_mod_txt = self.double_stream_modulation_txt(temb)
        single_stream_mod = self.single_stream_modulation(temb)[0]

        # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
        hidden_states = self.x_embedder(hidden_states)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        # 3. Calculate RoPE embeddings from image and text tokens
        # NOTE: the logic below means that we can't support batched inference with images of different resolutions
        # or text prompts of different lengths. Is this a use case we want to support?
        if img_ids.ndim == 3:
            img_ids = img_ids[0]
        if txt_ids.ndim == 3:
            txt_ids = txt_ids[0]

        image_rotary_emb = self.pos_embed(img_ids)
        text_rotary_emb = self.pos_embed(txt_ids)
        concat_rotary_emb = (
            mint.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
            mint.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
        )

        # 4. Double Stream Transformer Blocks
        for index_block, block in enumerate(self.transformer_blocks):
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb_mod_params_img=double_stream_mod_img,
                temb_mod_params_txt=double_stream_mod_txt,
                image_rotary_emb=concat_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )
        # Concatenate text and image streams for single-block inference
        hidden_states = mint.cat([encoder_hidden_states, hidden_states], dim=1)

        # 5. Single Stream Transformer Blocks
        for index_block, block in enumerate(self.single_transformer_blocks):
            hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=None,
                temb_mod_params=single_stream_mod,
                image_rotary_emb=concat_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )
        # Remove text tokens from concatenated stream
        hidden_states = hidden_states[:, num_txt_tokens:, ...]

        # 6. Output layers
        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
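
Step 3 of `construct` drops the batch axis of `img_ids` and `txt_ids`, embeds each with `pos_embed`, and concatenates text before image along the token axis. The pure-Python sketch below illustrates that flow with a textbook multi-axis RoPE table. It is not the `Flux2PosEmbed` implementation, which is not shown in this section: the function names, the frequency formula, and the example id layout are assumptions. Only `theta=2000`, the four axis sizes, and the text-then-image ordering come from the code above.

```python
import math
from typing import List, Sequence, Tuple

Table = List[List[float]]


def rope_table(
    ids: Sequence[Sequence[int]], axes_dim: Sequence[int], theta: float = 2000.0
) -> Tuple[Table, Table]:
    """Return (cos, sin) tables with one row per token and sum(axes_dim) columns."""
    cos_rows, sin_rows = [], []
    for token in ids:
        cos_row, sin_row = [], []
        for pos, dim in zip(token, axes_dim):
            for k in range(dim // 2):
                angle = pos * theta ** (-2.0 * k / dim)
                # Each frequency is repeated so the axis contributes `dim` columns.
                cos_row += [math.cos(angle)] * 2
                sin_row += [math.sin(angle)] * 2
        cos_rows.append(cos_row)
        sin_rows.append(sin_row)
    return cos_rows, sin_rows


def drop_batch_axis(ids):
    """Mimic `if ids.ndim == 3: ids = ids[0]` for nested lists."""
    is_batched = isinstance(ids[0][0], (list, tuple))
    return ids[0] if is_batched else ids


axes_dims_rope = (32, 32, 32, 32)
# Hypothetical ids: 3 text tokens and a 2x2 image grid, batch size 1.
txt_ids = [[[0, 0, 0, i] for i in range(3)]]
img_ids = [[[0, h, w, 0] for h in range(2) for w in range(2)]]

txt_cos, txt_sin = rope_table(drop_batch_axis(txt_ids), axes_dims_rope)
img_cos, img_sin = rope_table(drop_batch_axis(img_ids), axes_dims_rope)

# Text first, then image, matching `mint.cat([text..., image...], dim=0)`.
concat_cos = txt_cos + img_cos
concat_sin = txt_sin + img_sin
print(len(concat_cos), len(concat_cos[0]))
```

The concatenated table has one row per text and image token, in the same order as the tokens fed to the single stream blocks, which is why the same `concat_rotary_emb` can be reused in steps 4 and 5.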

mindone.diffusers.Flux2Transformer2DModel.construct(hidden_states, encoder_hidden_states=None, timestep=None, img_ids=None, txt_ids=None, guidance=None, joint_attention_kwargs=None, return_dict=True)

The [Flux2Transformer2DModel] forward method.

PARAMETER DESCRIPTION
hidden_states

Input hidden_states.

TYPE: `ms.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`

encoder_hidden_states

Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.

TYPE: `ms.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)` DEFAULT: None

timestep

Used to indicate denoising step.

TYPE: `ms.Tensor` DEFAULT: None

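img_ids

Position ids of the image tokens, used to compute the rotary embeddings. If 3-D, only the first batch entry is used.

TYPE: `ms.Tensor` DEFAULT: None

txt_ids

Position ids of the text tokens, used to compute the rotary embeddings. If 3-D, only the first batch entry is used.

TYPE: `ms.Tensor` DEFAULT: None

guidance

Guidance value, scaled by 1000 and embedded together with timestep.

TYPE: `ms.Tensor` DEFAULT: None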

joint_attention_kwargs

A kwargs dictionary that if specified is passed along to the AttentionProcessor as defined under self.processor in diffusers.models.attention_processor.

TYPE: `dict`, *optional* DEFAULT: None

return_dict

Whether or not to return a [~models.transformer_2d.Transformer2DModelOutput] instead of a plain tuple.

TYPE: `bool`, *optional*, defaults to `True` DEFAULT: True

RETURNS DESCRIPTION
Union[Tensor, Transformer2DModelOutput]

If return_dict is True, an [~models.transformer_2d.Transformer2DModelOutput] is returned, otherwise a tuple where the first element is the sample tensor.
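
To make the tensor shapes concrete, here is a shape-only trace of `construct`, written in plain Python so it runs without MindSpore. `trace_shapes` is a hypothetical helper. It assumes that each transformer block preserves the shape of its inputs, which the source implies but does not show.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def trace_shapes(
    batch_size: int,
    image_seq_len: int,
    text_seq_len: int,
    in_channels: int = 128,
    joint_attention_dim: int = 15360,
    inner_dim: int = 48 * 128,
    patch_size: int = 1,
) -> Dict[str, Shape]:
    """Hypothetical shape-only walk-through of the forward pass."""
    shapes: Dict[str, Shape] = {}
    shapes["hidden_states"] = (batch_size, image_seq_len, in_channels)
    shapes["encoder_hidden_states"] = (batch_size, text_seq_len, joint_attention_dim)
    # x_embedder / context_embedder project both streams to inner_dim.
    shapes["image_stream"] = (batch_size, image_seq_len, inner_dim)
    shapes["text_stream"] = (batch_size, text_seq_len, inner_dim)
    # Text tokens are placed before image tokens for the single stream blocks.
    shapes["single_stream"] = (batch_size, text_seq_len + image_seq_len, inner_dim)
    # hidden_states[:, num_txt_tokens:, ...] keeps only the image tokens.
    shapes["after_text_removed"] = (batch_size, image_seq_len, inner_dim)
    # proj_out maps to patch_size * patch_size * out_channels
    # (out_channels falls back to in_channels by default).
    shapes["sample"] = (batch_size, image_seq_len, patch_size * patch_size * in_channels)
    return shapes


shapes = trace_shapes(batch_size=1, image_seq_len=4096, text_seq_len=512)
for name, shape in shapes.items():
    print(f"{name}: {shape}")
```

The returned sample therefore has the same sequence length as the input `hidden_states`; the text tokens only take part in attention and are discarded before the output layers.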

Source code in mindone/diffusers/models/transformers/transformer_flux2.py
def construct(
    self,
    hidden_states: ms.Tensor,
    encoder_hidden_states: ms.Tensor = None,
    timestep: ms.Tensor = None,
    img_ids: ms.Tensor = None,
    txt_ids: ms.Tensor = None,
    guidance: ms.Tensor = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
) -> Union[ms.Tensor, Transformer2DModelOutput]:
    """
    The [`Flux2Transformer2DModel`] forward method.

    Args:
        hidden_states (`ms.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
            Input `hidden_states`.
        encoder_hidden_states (`ms.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
            Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
        timestep ( `ms.Tensor`):
            Used to indicate denoising step.
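        img_ids (`ms.Tensor`):
            Position ids of the image tokens, used to compute the rotary embeddings. If 3-D, only the first
            batch entry is used.
        txt_ids (`ms.Tensor`):
            Position ids of the text tokens, used to compute the rotary embeddings. If 3-D, only the first
            batch entry is used.
        guidance (`ms.Tensor`):
            Guidance value, scaled by 1000 and embedded together with `timestep`.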
        joint_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
            tuple.

    Returns:
        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
        `tuple` where the first element is the sample tensor.
    """
    # 0. Handle input arguments
    if joint_attention_kwargs is not None:
        # weight the lora layers by setting `lora_scale` for each PEFT layer here
        # and remove `lora_scale` from each PEFT layer at the end.
        # scale_lora_layers & unscale_lora_layers may contain operations that are forbidden in graph mode
        raise RuntimeError(
            f"You are trying to set scaling of lora layer by passing {joint_attention_kwargs['scale']=}. "
            f"However it's not allowed in on-the-fly model forwarding. "
            f"Please manually call `scale_lora_layers(model, lora_scale)` before model forwarding and "
            f"`unscale_lora_layers(model, lora_scale)` after model forwarding. "
            f"For example, it can be done in a pipeline call like `StableDiffusionPipeline.__call__`."
        )

    num_txt_tokens = encoder_hidden_states.shape[1]

    # 1. Calculate timestep embedding and modulation parameters
    timestep = timestep.to(hidden_states.dtype) * 1000
    guidance = guidance.to(hidden_states.dtype) * 1000

    temb = self.time_guidance_embed(timestep, guidance)

    double_stream_mod_img = self.double_stream_modulation_img(temb)
    double_stream_mod_txt = self.double_stream_modulation_txt(temb)
    single_stream_mod = self.single_stream_modulation(temb)[0]

    # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
    hidden_states = self.x_embedder(hidden_states)
    encoder_hidden_states = self.context_embedder(encoder_hidden_states)

    # 3. Calculate RoPE embeddings from image and text tokens
    # NOTE: the logic below means that we can't support batched inference with images of different resolutions
    # or text prompts of different lengths. Is this a use case we want to support?
    if img_ids.ndim == 3:
        img_ids = img_ids[0]
    if txt_ids.ndim == 3:
        txt_ids = txt_ids[0]

    image_rotary_emb = self.pos_embed(img_ids)
    text_rotary_emb = self.pos_embed(txt_ids)
    concat_rotary_emb = (
        mint.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
        mint.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
    )

    # 4. Double Stream Transformer Blocks
    for index_block, block in enumerate(self.transformer_blocks):
        encoder_hidden_states, hidden_states = block(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            temb_mod_params_img=double_stream_mod_img,
            temb_mod_params_txt=double_stream_mod_txt,
            image_rotary_emb=concat_rotary_emb,
            joint_attention_kwargs=joint_attention_kwargs,
        )
    # Concatenate text and image streams for single-block inference
    hidden_states = mint.cat([encoder_hidden_states, hidden_states], dim=1)

    # 5. Single Stream Transformer Blocks
    for index_block, block in enumerate(self.single_transformer_blocks):
        hidden_states = block(
            hidden_states=hidden_states,
            encoder_hidden_states=None,
            temb_mod_params=single_stream_mod,
            image_rotary_emb=concat_rotary_emb,
            joint_attention_kwargs=joint_attention_kwargs,
        )
    # Remove text tokens from concatenated stream
    hidden_states = hidden_states[:, num_txt_tokens:, ...]

    # 6. Output layers
    hidden_states = self.norm_out(hidden_states, temb)
    output = self.proj_out(hidden_states)

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)
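
The `RuntimeError` raised in step 0 asks callers to scale LoRA layers before the forward pass and unscale them afterwards. One way to keep the two calls paired is a context manager, sketched below. `lora_scale_scope` is a hypothetical helper, and the `scale_fn` / `unscale_fn` arguments stand in for `scale_lora_layers` and `unscale_lora_layers`, whose import path is not given in this section. The stub functions exist only so the sketch runs on its own.

```python
from contextlib import contextmanager
from typing import Any, Callable, Iterator, List, Tuple


@contextmanager
def lora_scale_scope(
    model: Any,
    lora_scale: float,
    scale_fn: Callable[[Any, float], None],
    unscale_fn: Callable[[Any, float], None],
) -> Iterator[Any]:
    """Apply LoRA scaling around a block of code and always undo it."""
    scale_fn(model, lora_scale)
    try:
        yield model
    finally:
        # Runs even if the forward pass raises.
        unscale_fn(model, lora_scale)


# Stubs standing in for the real scaling helpers (assumption).
calls: List[Tuple[str, float]] = []


def fake_scale(model: Any, scale: float) -> None:
    calls.append(("scale", scale))


def fake_unscale(model: Any, scale: float) -> None:
    calls.append(("unscale", scale))


with lora_scale_scope(object(), 0.8, fake_scale, fake_unscale) as model:
    # A real caller would run the transformer here, leaving
    # joint_attention_kwargs=None so that step 0 does not raise.
    calls.append(("forward", 0.8))

print(calls)
```

Putting the unscale call in a `finally` clause means a failed forward pass cannot leave the LoRA layers in a scaled state.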