
DiTTransformer2DModel

A Transformer model for image-like data from DiT.

mindone.diffusers.DiTTransformer2DModel

Bases: ModelMixin, ConfigMixin

A 2D Transformer model as introduced in DiT (https://arxiv.org/abs/2212.09748).

PARAMETER DESCRIPTION
num_attention_heads

The number of heads to use for multi-head attention.

TYPE: int, optional, defaults to 16 DEFAULT: 16

attention_head_dim

The number of channels in each head.

TYPE: int, optional, defaults to 72 DEFAULT: 72

in_channels

The number of channels in the input.

TYPE: int, defaults to 4 DEFAULT: 4

out_channels

The number of channels in the output. Specify this parameter if the output channel number differs from the input.

TYPE: int DEFAULT: None

num_layers

The number of layers of Transformer blocks to use.

TYPE: int, optional, defaults to 28 DEFAULT: 28

dropout

The dropout probability to use within the Transformer blocks.

TYPE: float, optional, defaults to 0.0 DEFAULT: 0.0

norm_num_groups

Number of groups for group normalization within Transformer blocks.

TYPE: int, optional, defaults to 32 DEFAULT: 32

attention_bias

Configure if the Transformer blocks' attention should contain a bias parameter.

TYPE: bool, optional, defaults to True DEFAULT: True

sample_size

The width of the latent images. This parameter is fixed during training.

TYPE: int, defaults to 32 DEFAULT: 32

patch_size

Size of the patches the model processes, relevant for architectures working on non-sequential data.

TYPE: int, defaults to 2 DEFAULT: 2

activation_fn

Activation function to use in feed-forward networks within Transformer blocks.

TYPE: str, optional, defaults to "gelu-approximate" DEFAULT: 'gelu-approximate'

num_embeds_ada_norm

Number of embeddings for AdaLayerNorm, which is fixed during training and affects the maximum number of denoising steps during inference.

TYPE: int, optional, defaults to 1000 DEFAULT: 1000

upcast_attention

If true, upcasts the attention mechanism dimensions for potentially improved performance.

TYPE: bool, optional, defaults to False DEFAULT: False

norm_type

Specifies the type of normalization used, can be 'ada_norm_zero'.

TYPE: str, optional, defaults to "ada_norm_zero" DEFAULT: 'ada_norm_zero'

norm_elementwise_affine

If true, enables element-wise affine parameters in the normalization layers.

TYPE: bool, optional, defaults to False DEFAULT: False

norm_eps

A small constant added to the denominator in normalization layers to prevent division by zero.

TYPE: float, optional, defaults to 1e-5 DEFAULT: 1e-05
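
The defaults above correspond to a DiT-XL/2-style configuration (inner dimension 16 × 72 = 1152, 28 blocks, 2×2 patches on 32×32 latents). As a minimal sketch, not an official recipe, the model can be instantiated by restating those defaults explicitly; the commented `from_pretrained` call assumes a hub checkpoint with a `transformer` subfolder and may differ in practice:

from mindone.diffusers import DiTTransformer2DModel

# Instantiate with the defaults documented above (32x32 latents, 2x2 patches, 28 blocks).
model = DiTTransformer2DModel(
    num_attention_heads=16,
    attention_head_dim=72,
    in_channels=4,
    num_layers=28,
    sample_size=32,
    patch_size=2,
    num_embeds_ada_norm=1000,
    norm_type="ada_norm_zero",
)

# Weights can also be loaded via ModelMixin.from_pretrained; the repo id below is an assumption.
# model = DiTTransformer2DModel.from_pretrained("facebook/DiT-XL-2-256", subfolder="transformer")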

Source code in mindone/diffusers/models/transformers/dit_transformer_2d.py
class DiTTransformer2DModel(ModelMixin, ConfigMixin):
    r"""
    A 2D Transformer model as introduced in DiT (https://arxiv.org/abs/2212.09748).

    Parameters:
        num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
        in_channels (int, defaults to 4): The number of channels in the input.
        out_channels (int, optional):
            The number of channels in the output. Specify this parameter if the output channel number differs from the
            input.
        num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
        dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
        norm_num_groups (int, optional, defaults to 32):
            Number of groups for group normalization within Transformer blocks.
        attention_bias (bool, optional, defaults to True):
            Configure if the Transformer blocks' attention should contain a bias parameter.
        sample_size (int, defaults to 32):
            The width of the latent images. This parameter is fixed during training.
        patch_size (int, defaults to 2):
            Size of the patches the model processes, relevant for architectures working on non-sequential data.
        activation_fn (str, optional, defaults to "gelu-approximate"):
            Activation function to use in feed-forward networks within Transformer blocks.
        num_embeds_ada_norm (int, optional, defaults to 1000):
            Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
            inference.
        upcast_attention (bool, optional, defaults to False):
            If true, upcasts the attention mechanism dimensions for potentially improved performance.
        norm_type (str, optional, defaults to "ada_norm_zero"):
            Specifies the type of normalization used, can be 'ada_norm_zero'.
        norm_elementwise_affine (bool, optional, defaults to False):
            If true, enables element-wise affine parameters in the normalization layers.
        norm_eps (float, optional, defaults to 1e-5):
            A small constant added to the denominator in normalization layers to prevent division by zero.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 72,
        in_channels: int = 4,
        out_channels: Optional[int] = None,
        num_layers: int = 28,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        attention_bias: bool = True,
        sample_size: int = 32,
        patch_size: int = 2,
        activation_fn: str = "gelu-approximate",
        num_embeds_ada_norm: Optional[int] = 1000,
        upcast_attention: bool = False,
        norm_type: str = "ada_norm_zero",
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-5,
    ):
        super().__init__()

        # Validate inputs.
        if norm_type != "ada_norm_zero":
            raise NotImplementedError(
                f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
            )
        elif norm_type == "ada_norm_zero" and num_embeds_ada_norm is None:
            raise ValueError(
                f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
            )

        # Set some common variables used across the board.
        self.attention_head_dim = attention_head_dim
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.out_channels = in_channels if out_channels is None else out_channels
        self.gradient_checkpointing = False

        # 2. Initialize the position embedding and transformer blocks.
        self.height = self.config.sample_size
        self.width = self.config.sample_size

        self.patch_size = self.config.patch_size
        self.pos_embed = PatchEmbed(
            height=self.config.sample_size,
            width=self.config.sample_size,
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            embed_dim=self.inner_dim,
        )

        self.transformer_blocks = nn.CellList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    dropout=self.config.dropout,
                    activation_fn=self.config.activation_fn,
                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
                    attention_bias=self.config.attention_bias,
                    upcast_attention=self.config.upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=self.config.norm_elementwise_affine,
                    norm_eps=self.config.norm_eps,
                )
                for _ in range(self.config.num_layers)
            ]
        )

        # 3. Output blocks.
        self.norm_out = LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out_1 = nn.Dense(self.inner_dim, 2 * self.inner_dim)
        self.proj_out_2 = nn.Dense(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def construct(
        self,
        hidden_states: ms.Tensor,
        timestep: Optional[ms.Tensor] = None,
        class_labels: Optional[ms.Tensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        return_dict: bool = False,
    ):
        """
        The [`DiTTransformer2DModel`] forward method.

        Args:
            hidden_states (`ms.Tensor` of shape `(batch size, num latent pixels)` if discrete, `ms.Tensor` of shape `(batch size, channel, height, width)` if continuous):  # noqa: E501
                Input `hidden_states`.
            timestep ( `ms.Tensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `ms.Tensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # 1. Input
        height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
        hidden_states = self.pos_embed(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=None,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        conditioning = self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=hidden_states.dtype)
        shift, scale = self.proj_out_1(ops.silu(conditioning)).chunk(2, axis=1)
        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
        hidden_states = self.proj_out_2(hidden_states)

        # unpatchify
        height = width = int(hidden_states.shape[1] ** 0.5)
        hidden_states = hidden_states.reshape(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
        # hidden_states = ops.einsum("nhwpqc->nchpwq", hidden_states)
        hidden_states = hidden_states.transpose(0, 5, 1, 3, 2, 4)
        output = hidden_states.reshape(-1, self.out_channels, height * self.patch_size, width * self.patch_size)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)

mindone.diffusers.DiTTransformer2DModel.construct(hidden_states, timestep=None, class_labels=None, cross_attention_kwargs=None, return_dict=False)

The [DiTTransformer2DModel] forward method.

PARAMETER DESCRIPTION
hidden_states

Input hidden_states.

TYPE: `ms.Tensor` of shape `(batch size, num latent pixels)` if discrete, `ms.Tensor` of shape `(batch size, channel, height, width)` if continuous

timestep

Used to indicate denoising step. Optional timestep to be applied as an embedding in AdaLayerNorm.

TYPE: `ms.Tensor`, *optional* DEFAULT: None

class_labels

Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in AdaLayerZeroNorm.

TYPE: `ms.Tensor` of shape `(batch size, num classes)`, *optional* DEFAULT: None

cross_attention_kwargs

A kwargs dictionary that if specified is passed along to the AttentionProcessor as defined under self.processor in diffusers.models.attention_processor.

TYPE: `Dict[str, Any]`, *optional* DEFAULT: None

return_dict

Whether or not to return a [~models.unets.unet_2d_condition.UNet2DConditionOutput] instead of a plain tuple.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

RETURNS DESCRIPTION

If return_dict is True, a [~models.transformer_2d.Transformer2DModelOutput] is returned; otherwise, a tuple is returned where the first element is the sample tensor.
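
As a minimal sketch of a forward call, assuming the default configuration (sample_size=32, in_channels=4, 1000 class embeddings) and using random placeholder tensors rather than real latents:

import mindspore as ms
from mindspore import ops
from mindone.diffusers import DiTTransformer2DModel

model = DiTTransformer2DModel()  # defaults documented above

# Dummy class-conditional latent batch: (batch, channels, height, width) = (2, 4, 32, 32).
latents = ops.randn(2, 4, 32, 32, dtype=ms.float32)
timestep = ms.Tensor([10, 10], dtype=ms.int64)    # one denoising-step index per sample
class_labels = ms.Tensor([3, 7], dtype=ms.int64)  # class ids in [0, num_embeds_ada_norm)

# return_dict defaults to False, so the output is a tuple whose first element is the sample.
(sample,) = model(latents, timestep=timestep, class_labels=class_labels)
print(sample.shape)  # (2, 4, 32, 32): patches are folded back to the input spatial size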

Source code in mindone/diffusers/models/transformers/dit_transformer_2d.py
def construct(
    self,
    hidden_states: ms.Tensor,
    timestep: Optional[ms.Tensor] = None,
    class_labels: Optional[ms.Tensor] = None,
    cross_attention_kwargs: Dict[str, Any] = None,
    return_dict: bool = False,
):
    """
    The [`DiTTransformer2DModel`] forward method.

    Args:
        hidden_states (`ms.Tensor` of shape `(batch size, num latent pixels)` if discrete, `ms.Tensor` of shape `(batch size, channel, height, width)` if continuous):  # noqa: E501
            Input `hidden_states`.
        timestep ( `ms.Tensor`, *optional*):
            Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
        class_labels ( `ms.Tensor` of shape `(batch size, num classes)`, *optional*):
            Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
            `AdaLayerZeroNorm`.
        cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        return_dict (`bool`, *optional*, defaults to `False`):
            Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
            tuple.

    Returns:
        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
        `tuple` where the first element is the sample tensor.
    """
    # 1. Input
    height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
    hidden_states = self.pos_embed(hidden_states)

    # 2. Blocks
    for block in self.transformer_blocks:
        hidden_states = block(
            hidden_states,
            attention_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            timestep=timestep,
            cross_attention_kwargs=cross_attention_kwargs,
            class_labels=class_labels,
        )

    # 3. Output
    conditioning = self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=hidden_states.dtype)
    shift, scale = self.proj_out_1(ops.silu(conditioning)).chunk(2, axis=1)
    hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
    hidden_states = self.proj_out_2(hidden_states)

    # unpatchify
    height = width = int(hidden_states.shape[1] ** 0.5)
    hidden_states = hidden_states.reshape(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
    # hidden_states = ops.einsum("nhwpqc->nchpwq", hidden_states)
    hidden_states = hidden_states.transpose(0, 5, 1, 3, 2, 4)
    output = hidden_states.reshape(-1, self.out_channels, height * self.patch_size, width * self.patch_size)

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)