LatteTransformer3DModel

A Diffusion Transformer model for 3D data from Latte.
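A minimal loading sketch (the checkpoint id "maxin-cn/Latte-1" and the "transformer" subfolder are assumptions about the usual Latte pipeline layout, not taken from this page; `from_pretrained` is provided by the `ModelMixin` base class):

from mindone.diffusers import LatteTransformer3DModel

# Load pretrained transformer weights from an assumed checkpoint layout.
transformer = LatteTransformer3DModel.from_pretrained(
    "maxin-cn/Latte-1", subfolder="transformer"
)
transformer.set_train(False)  # inference mode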

mindone.diffusers.LatteTransformer3DModel

Bases: ModelMixin, ConfigMixin

Source code in mindone/diffusers/models/transformers/latte_transformer_3d.py
class LatteTransformer3DModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    """
    A 3D Transformer model for video-like data. Paper: https://arxiv.org/abs/2401.03048; official code:
    https://github.com/Vchitect/Latte

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input.
        out_channels (`int`, *optional*):
            The number of channels in the output.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            This is fixed during training since it is used to learn a number of position embeddings.
        patch_size (`int`, *optional*):
            The size of the patches to use in the patch embedding layer.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*):
            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
            added to the hidden states. During inference, you can denoise for up to but not more steps than
            `num_embeds_ada_norm`.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether or not to use elementwise affine in normalization layers.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers.
        caption_channels (`int`, *optional*):
            The number of channels in the caption embeddings.
        video_length (`int`, *optional*):
            The number of frames in the video-like data.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: int = 64,
        patch_size: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        caption_channels: int = None,
        video_length: int = 16,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Define input layers
        self.height = sample_size
        self.width = sample_size

        interpolation_scale = self.config.sample_size // 64
        interpolation_scale = max(interpolation_scale, 1)
        self.pos_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            interpolation_scale=interpolation_scale,
        )
        self.patch_size = self.config.patch_size

        # 2. Define spatial transformers blocks
        self.transformer_blocks = nn.CellList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for d in range(num_layers)
            ]
        )

        # 3. Define temporal transformers blocks
        self.temporal_transformer_blocks = nn.CellList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=None,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm_out = LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
        self.scale_shift_table = ms.Parameter(ops.randn((2, inner_dim)) / inner_dim**0.5)
        self.proj_out = nn.Dense(inner_dim, patch_size * patch_size * self.out_channels)

        # 5. Latte other blocks.
        self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)
        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

        # define temporal positional embedding
        temp_pos_embed = get_1d_sincos_pos_embed_from_grid(
            inner_dim, ops.arange(0, video_length).unsqueeze(1).numpy()
        )  # 1152 hidden size
        self.temp_pos_embed = ms.Tensor.from_numpy(temp_pos_embed).float().unsqueeze(0)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

    def construct(
        self,
        hidden_states: ms.Tensor,
        timestep: Optional[ms.Tensor] = None,
        encoder_hidden_states: Optional[ms.Tensor] = None,
        encoder_attention_mask: Optional[ms.Tensor] = None,
        enable_temporal_attentions: bool = True,
        return_dict: bool = False,
    ):
        """
        The [`LatteTransformer3DModel`] forward method.

        Args:
            hidden_states (`ms.Tensor` of shape `(batch size, channel, num_frame, height, width)`):
                Input `hidden_states`.
            timestep ( `ms.Tensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            encoder_hidden_states ( `ms.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            encoder_attention_mask ( `ms.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            enable_temporal_attentions (`bool`, *optional*, defaults to `True`):
                Whether to enable temporal attentions.
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """

        # Reshape hidden states
        batch_size, channels, num_frame, height, width = hidden_states.shape
        # batch_size channels num_frame height width -> (batch_size * num_frame) channels height width
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width)

        # Input
        height, width = (
            hidden_states.shape[-2] // self.patch_size,
            hidden_states.shape[-1] // self.patch_size,
        )
        num_patches = height * width

        hidden_states = self.pos_embed(hidden_states)  # already adds positional embeddings

        added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
        timestep, embedded_timestep = self.adaln_single(
            timestep, added_cond_kwargs=added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )

        # Prepare text embeddings for spatial block
        # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
        encoder_hidden_states = self.caption_projection(encoder_hidden_states)  # 3 120 1152
        encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view(
            -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]
        )

        # Prepare timesteps for spatial and temporal block
        timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1])
        timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1])

        # Spatial and temporal transformer blocks
        for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
            hidden_states = spatial_block(
                hidden_states,
                None,  # attention_mask
                encoder_hidden_states_spatial,
                encoder_attention_mask,
                timestep_spatial,
                None,  # cross_attention_kwargs
                None,  # class_labels
            )

            if enable_temporal_attentions:
                # (batch_size * num_frame) num_tokens hidden_size -> (batch_size * num_tokens) num_frame hidden_size
                hidden_states = hidden_states.reshape(
                    batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1]
                ).permute(0, 2, 1, 3)
                hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])

                if i == 0 and num_frame > 1:
                    hidden_states = (hidden_states + self.temp_pos_embed).to(hidden_states.dtype)

                hidden_states = temp_block(
                    hidden_states,
                    None,  # attention_mask
                    None,  # encoder_hidden_states
                    None,  # encoder_attention_mask
                    timestep_temp,
                    None,  # cross_attention_kwargs
                    None,  # class_labels
                )

                # (batch_size * num_tokens) num_frame hidden_size -> (batch_size * num_frame) num_tokens hidden_size
                hidden_states = hidden_states.reshape(
                    batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1]
                ).permute(0, 2, 1, 3)
                hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])

        embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
        shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, axis=1)
        hidden_states = self.norm_out(hidden_states)
        # Modulation
        hidden_states = hidden_states * (1 + scale) + shift
        hidden_states = self.proj_out(hidden_states)

        # unpatchify
        if self.adaln_single is None:
            height = width = int(hidden_states.shape[1] ** 0.5)
        hidden_states = hidden_states.reshape((-1, height, width, self.patch_size, self.patch_size, self.out_channels))
        hidden_states = hidden_states.transpose(0, 5, 1, 3, 2, 4)
        output = hidden_states.reshape((-1, self.out_channels, height * self.patch_size, width * self.patch_size))
        output = output.reshape(batch_size, -1, output.shape[-3], output.shape[-2], output.shape[-1]).permute(
            0, 2, 1, 3, 4
        )

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
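The configuration parameters documented above map directly onto the constructor. A toy instantiation sketch (the values are illustrative only and do not correspond to any released Latte checkpoint; `norm_type="ada_norm_single"` is assumed here because Latte uses PixArt-style timestep conditioning):

from mindone.diffusers import LatteTransformer3DModel

# Toy config: one spatial/temporal transformer block pair; values are illustrative.
transformer = LatteTransformer3DModel(
    num_attention_heads=16,
    attention_head_dim=72,        # inner_dim = 16 * 72 = 1152
    in_channels=4,                # latent channels from the VAE
    out_channels=8,
    num_layers=1,
    cross_attention_dim=1152,
    sample_size=64,               # 64x64 latent grid
    patch_size=2,
    norm_type="ada_norm_single",
    caption_channels=4096,        # e.g. T5 text-encoder hidden size
    video_length=16,
)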

mindone.diffusers.LatteTransformer3DModel.construct(hidden_states, timestep=None, encoder_hidden_states=None, encoder_attention_mask=None, enable_temporal_attentions=True, return_dict=False)

The [LatteTransformer3DModel] forward method.

PARAMETER DESCRIPTION
hidden_states

Input hidden_states.

TYPE: `ms.Tensor` of shape `(batch size, channel, num_frame, height, width)`

timestep

Used to indicate denoising step. Optional timestep to be applied as an embedding in AdaLayerNorm.

TYPE: `ms.Tensor`, *optional* DEFAULT: None

encoder_hidden_states

Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention.

TYPE: `ms.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional* DEFAULT: None

encoder_attention_mask

Cross-attention mask applied to encoder_hidden_states. Two formats supported:

* Mask `(batch, sequence_length)` True = keep, False = discard.
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

If ndim == 2: will be interpreted as a mask, then converted into a bias consistent with the format above. This bias will be added to the cross-attention scores.

TYPE: `ms.Tensor`, *optional* DEFAULT: None

enable_temporal_attentions

Whether to enable temporal attentions.

TYPE: `bool`, *optional*, defaults to `True` DEFAULT: True

return_dict

Whether or not to return a [~models.transformer_2d.Transformer2DModelOutput] instead of a plain tuple.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

RETURNS DESCRIPTION

If return_dict is True, an [~models.transformer_2d.Transformer2DModelOutput] is returned, otherwise a tuple where the first element is the sample tensor.
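For orientation, a minimal call sketch with dummy inputs (shapes are illustrative and assume a model configured as in the sketches above: 4 latent channels, a 64x64 latent grid, 16 frames, and 4096-dimensional caption embeddings):

import mindspore as ms
from mindspore import ops

# Latent video: (batch, channels, num_frame, height, width).
hidden_states = ops.randn((1, 4, 16, 64, 64), dtype=ms.float32)
# Caption embeddings: (batch, sequence_length, caption_channels).
encoder_hidden_states = ops.randn((1, 120, 4096), dtype=ms.float32)
# One denoising timestep per batch element.
timestep = ms.Tensor([999], dtype=ms.int64)

(sample,) = transformer(
    hidden_states,
    timestep=timestep,
    encoder_hidden_states=encoder_hidden_states,
    enable_temporal_attentions=True,
    return_dict=False,
)
# With return_dict=False a tuple is returned; `sample` has shape
# (batch, out_channels, num_frame, height, width).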

Source code in mindone/diffusers/models/transformers/latte_transformer_3d.py
def construct(
    self,
    hidden_states: ms.Tensor,
    timestep: Optional[ms.Tensor] = None,
    encoder_hidden_states: Optional[ms.Tensor] = None,
    encoder_attention_mask: Optional[ms.Tensor] = None,
    enable_temporal_attentions: bool = True,
    return_dict: bool = False,
):
    """
    The [`LatteTransformer3DModel`] forward method.

    Args:
        hidden_states (`ms.Tensor` of shape `(batch size, channel, num_frame, height, width)`):
            Input `hidden_states`.
        timestep ( `ms.Tensor`, *optional*):
            Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
        encoder_hidden_states ( `ms.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
            Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
            self-attention.
        encoder_attention_mask ( `ms.Tensor`, *optional*):
            Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                * Mask `(batch, sequence_length)` True = keep, False = discard.
                * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

            If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
            above. This bias will be added to the cross-attention scores.
        enable_temporal_attentions (`bool`, *optional*, defaults to `True`):
            Whether to enable temporal attentions.
        return_dict (`bool`, *optional*, defaults to `False`):
            Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
            tuple.

    Returns:
        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
        `tuple` where the first element is the sample tensor.
    """

    # Reshape hidden states
    batch_size, channels, num_frame, height, width = hidden_states.shape
    # batch_size channels num_frame height width -> (batch_size * num_frame) channels height width
    hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width)

    # Input
    height, width = (
        hidden_states.shape[-2] // self.patch_size,
        hidden_states.shape[-1] // self.patch_size,
    )
    num_patches = height * width

    hidden_states = self.pos_embed(hidden_states)  # already adds positional embeddings

    added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
    timestep, embedded_timestep = self.adaln_single(
        timestep, added_cond_kwargs=added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
    )

    # Prepare text embeddings for spatial block
    # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
    encoder_hidden_states = self.caption_projection(encoder_hidden_states)  # 3 120 1152
    encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view(
        -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]
    )

    # Prepare timesteps for spatial and temporal block
    timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1])
    timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1])

    # Spatial and temporal transformer blocks
    for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
        hidden_states = spatial_block(
            hidden_states,
            None,  # attention_mask
            encoder_hidden_states_spatial,
            encoder_attention_mask,
            timestep_spatial,
            None,  # cross_attention_kwargs
            None,  # class_labels
        )

        if enable_temporal_attentions:
            # (batch_size * num_frame) num_tokens hidden_size -> (batch_size * num_tokens) num_frame hidden_size
            hidden_states = hidden_states.reshape(
                batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1]
            ).permute(0, 2, 1, 3)
            hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])

            if i == 0 and num_frame > 1:
                hidden_states = (hidden_states + self.temp_pos_embed).to(hidden_states.dtype)

            hidden_states = temp_block(
                hidden_states,
                None,  # attention_mask
                None,  # encoder_hidden_states
                None,  # encoder_attention_mask
                timestep_temp,
                None,  # cross_attention_kwargs
                None,  # class_labels
            )

            # (batch_size * num_tokens) num_frame hidden_size -> (batch_size * num_frame) num_tokens hidden_size
            hidden_states = hidden_states.reshape(
                batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1]
            ).permute(0, 2, 1, 3)
            hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])

    embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
    shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, axis=1)
    hidden_states = self.norm_out(hidden_states)
    # Modulation
    hidden_states = hidden_states * (1 + scale) + shift
    hidden_states = self.proj_out(hidden_states)

    # unpatchify
    if self.adaln_single is None:
        height = width = int(hidden_states.shape[1] ** 0.5)
    hidden_states = hidden_states.reshape((-1, height, width, self.patch_size, self.patch_size, self.out_channels))
    hidden_states = hidden_states.transpose(0, 5, 1, 3, 2, 4)
    output = hidden_states.reshape((-1, self.out_channels, height * self.patch_size, width * self.patch_size))
    output = output.reshape(batch_size, -1, output.shape[-3], output.shape[-2], output.shape[-1]).permute(
        0, 2, 1, 3, 4
    )

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)