
AllegroTransformer3DModel

A Diffusion Transformer model for 3D (video-like) data, used by Allegro. It was introduced in "Allegro: Open the Black Box of Commercial-Level Video Generation Model" by RhymesAI.

The model can be loaded with the following code snippet.

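import mindspore as ms
from mindone.diffusers import AllegroTransformer3DModel

# Load the transformer sub-model from the "transformer" subfolder in bfloat16.
transformer = AllegroTransformer3DModel.from_pretrained("rhymes-ai/Allegro", subfolder="transformer", mindspore_dtype=ms.bfloat16)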
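
For orientation, the default configuration values listed in the class docstring further down imply a fixed hidden width and token count. The following is a minimal sketch in plain Python (no MindSpore required) of that arithmetic. The helper name `allegro_token_counts` is hypothetical and not part of mindone; the formulas follow `inner_dim = num_attention_heads * attention_head_dim` and the integer divisions by `patch_size` and `patch_size_t` used in `construct`.

```python
def allegro_token_counts(
    sample_frames=22,
    sample_height=90,
    sample_width=160,
    patch_size=2,
    patch_size_t=1,
    num_attention_heads=24,
    attention_head_dim=96,
):
    """Hypothetical helper: derive sizes from the documented config defaults."""
    # Hidden width of every transformer block.
    inner_dim = num_attention_heads * attention_head_dim
    # Latent grid after patching, mirroring the `//` divisions in `construct`.
    frames = sample_frames // patch_size_t
    grid_h = sample_height // patch_size
    grid_w = sample_width // patch_size
    return {
        "inner_dim": inner_dim,
        "tokens_per_frame": grid_h * grid_w,
        "total_tokens": frames * grid_h * grid_w,
    }


if __name__ == "__main__":
    print(allegro_token_counts())
```

With the defaults this gives an `inner_dim` of 2304, 3600 tokens per latent frame, and 79200 tokens for a 22-frame latent, which is the sequence length the transformer blocks attend over.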

mindone.diffusers.AllegroTransformer3DModel

Bases: ModelMixin, ConfigMixin

Source code in mindone/diffusers/models/transformers/transformer_allegro.py
class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    """
    A 3D Transformer model for video-like data.

    Args:
        patch_size (`int`, defaults to `2`):
            The size of spatial patches to use in the patch embedding layer.
        patch_size_t (`int`, defaults to `1`):
            The size of temporal patches to use in the patch embedding layer.
        num_attention_heads (`int`, defaults to `24`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `96`):
            The number of channels in each head.
        in_channels (`int`, defaults to `4`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `4`):
            The number of channels in the output.
        num_layers (`int`, defaults to `32`):
            The number of layers of Transformer blocks to use.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        cross_attention_dim (`int`, defaults to `2304`):
            The dimension of the cross attention features.
        attention_bias (`bool`, defaults to `True`):
            Whether or not to use bias in the attention projection layers.
        sample_height (`int`, defaults to `90`):
            The height of the input latents.
        sample_width (`int`, defaults to `160`):
            The width of the input latents.
        sample_frames (`int`, defaults to `22`):
            The number of frames in the input latents.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        norm_elementwise_affine (`bool`, defaults to `False`):
            Whether or not to use elementwise affine in normalization layers.
        norm_eps (`float`, defaults to `1e-6`):
            The epsilon value to use in normalization layers.
        caption_channels (`int`, defaults to `4096`):
            Number of channels to use for projecting the caption embeddings.
        interpolation_scale_h (`float`, defaults to `2.0`):
            Scaling factor to apply in 3D positional embeddings across height dimension.
        interpolation_scale_w (`float`, defaults to `2.0`):
            Scaling factor to apply in 3D positional embeddings across width dimension.
        interpolation_scale_t (`float`, defaults to `2.2`):
            Scaling factor to apply in 3D positional embeddings across time dimension.
    """

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        patch_size_t: int = 1,
        num_attention_heads: int = 24,
        attention_head_dim: int = 96,
        in_channels: int = 4,
        out_channels: int = 4,
        num_layers: int = 32,
        dropout: float = 0.0,
        cross_attention_dim: int = 2304,
        attention_bias: bool = True,
        sample_height: int = 90,
        sample_width: int = 160,
        sample_frames: int = 22,
        activation_fn: str = "gelu-approximate",
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        caption_channels: int = 4096,
        interpolation_scale_h: float = 2.0,
        interpolation_scale_w: float = 2.0,
        interpolation_scale_t: float = 2.2,
    ):
        super().__init__()

        self.inner_dim = num_attention_heads * attention_head_dim

        interpolation_scale_t = (
            interpolation_scale_t
            if interpolation_scale_t is not None
            else ((sample_frames - 1) // 16 + 1)
            if sample_frames % 2 == 1
            else sample_frames // 16
        )
        interpolation_scale_h = interpolation_scale_h if interpolation_scale_h is not None else sample_height / 30
        interpolation_scale_w = interpolation_scale_w if interpolation_scale_w is not None else sample_width / 40

        # 1. Patch embedding
        self.pos_embed = PatchEmbed(
            height=sample_height,
            width=sample_width,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=self.inner_dim,
            pos_embed_type=None,
        )

        # 2. Transformer blocks
        self.transformer_blocks = nn.CellList(
            [
                AllegroTransformerBlock(
                    self.inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        # 3. Output projection & norm
        self.norm_out = LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.scale_shift_table = ms.Parameter(
            ops.randn(2, self.inner_dim) / self.inner_dim**0.5, name="scale_shift_table"
        )
        self.proj_out = nn.Dense(self.inner_dim, patch_size * patch_size * out_channels)

        # 4. Timestep embeddings
        self.adaln_single = AdaLayerNormSingle(self.inner_dim, use_additional_conditions=False)

        # 5. Caption projection
        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=self.inner_dim)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

    def construct(
        self,
        hidden_states: ms.Tensor,
        encoder_hidden_states: ms.Tensor,
        timestep: ms.Tensor,
        attention_mask: Optional[ms.Tensor] = None,
        encoder_attention_mask: Optional[ms.Tensor] = None,
        image_rotary_emb: Optional[Tuple[ms.Tensor, ms.Tensor]] = None,
        return_dict: bool = False,
    ):
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t = self.config["patch_size_t"]
        p = self.config["patch_size"]

        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p
        post_patch_width = width // p

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 4:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0,     discard = -10000.0)
            # b, frame+use_image_num, h, w -> a video with images
            # b, 1, h, w -> only images
            attention_mask = attention_mask.to(hidden_states.dtype)
            attention_mask = attention_mask[:, :num_frames]  # [batch_size, num_frames, height, width]

            if attention_mask.numel() > 0:
                attention_mask = attention_mask.unsqueeze(1)  # [batch_size, 1, num_frames, height, width]
                attention_mask = ops.max_pool3d(attention_mask, kernel_size=(p_t, p, p), stride=(p_t, p, p))
                attention_mask = attention_mask.flatten(start_dim=1).view(batch_size, 1, -1)

            attention_mask = (
                (1 - attention_mask.bool().to(hidden_states.dtype)) * -10000.0 if attention_mask.numel() > 0 else None
            )

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Timestep embeddings
        timestep, embedded_timestep = self.adaln_single(
            timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )

        # 2. Patch embeddings
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(start_dim=0, end_dim=1)
        hidden_states = self.pos_embed(hidden_states)
        hidden_states = hidden_states.reshape(
            hidden_states.shape[:0] + (batch_size, -1) + hidden_states.shape[1:]
        ).flatten(start_dim=1, end_dim=2)

        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1])

        # 3. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=timestep,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                image_rotary_emb=image_rotary_emb,
            )

        # 4. Output normalization & projection
        shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, axis=1)
        hidden_states = self.norm_out(hidden_states)

        # Modulation
        hidden_states = hidden_states * (1 + scale) + shift
        hidden_states = self.proj_out(hidden_states)
        # In torch, squeezing a dim whose size is not 1 (e.g. dim 0 of an (A×1×B) input) leaves the tensor
        # unchanged. MindSpore does not support that behavior, so squeeze only when an extra dim is present.
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 5. Unpatchify
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.reshape(batch_size, -1, num_frames, height, width)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
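
Two tensor manipulations in `construct` are easier to follow on small arrays: the conversion of a 2D keep/discard mask into an additive attention bias, and the final "unpatchify" reshape. The sketch below reproduces both in NumPy rather than MindSpore, so it is an illustration under that substitution and not the library code. `mask_to_bias`, `patchify` and `unpatchify` are hypothetical names. `patchify` is only a plain reshape used to check the round trip; the real model embeds patches with a learned `PatchEmbed` projection.

```python
import numpy as np


def mask_to_bias(mask, neg=-10000.0, dtype=np.float32):
    """(1 = keep, 0 = discard) -> additive bias (keep = 0, discard = -10000)."""
    mask = np.asarray(mask).astype(bool).astype(dtype)
    bias = (1.0 - mask) * neg
    # Add a singleton query axis so the bias broadcasts over attention scores.
    return bias[:, None, :]


def patchify(video, p_t=1, p=2):
    """Hypothetical inverse of `unpatchify`, used here only for the round-trip check."""
    b, c, frames, height, width = video.shape
    f, h, w = frames // p_t, height // p, width // p
    x = video.reshape(b, c, f, p_t, h, p, w, p)
    x = x.transpose(0, 2, 4, 6, 3, 5, 7, 1)  # -> (b, f, h, w, p_t, p, p, c)
    return x.reshape(b, f * h * w, p_t * p * p * c)


def unpatchify(tokens, num_frames, height, width, p_t=1, p=2):
    """Same reshape/permute/reshape sequence as step 5 of `construct`."""
    b = tokens.shape[0]
    f, h, w = num_frames // p_t, height // p, width // p
    x = tokens.reshape(b, f, h, w, p_t, p, p, -1)
    x = x.transpose(0, 7, 1, 4, 2, 5, 3, 6)  # -> (b, c, f, p_t, h, p, w, p)
    return x.reshape(b, -1, num_frames, height, width)


if __name__ == "__main__":
    video = np.arange(1 * 3 * 2 * 4 * 6, dtype=np.float32).reshape(1, 3, 2, 4, 6)
    tokens = patchify(video)
    print(tokens.shape)                                          # token layout
    print(np.array_equal(unpatchify(tokens, 2, 4, 6), video))    # round trip
    print(mask_to_bias([[1, 1, 0]]))
```

The permutation `(0, 7, 1, 4, 2, 5, 3, 6)` places each patch offset next to the grid axis it belongs to, so the final reshape merges `(f, p_t)`, `(h, p)` and `(w, p)` back into frames, height and width.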

mindone.diffusers.models.modeling_outputs.Transformer2DModelOutput dataclass

Bases: BaseOutput

The output of [Transformer2DModel].

PARAMETER DESCRIPTION
sample

The hidden states output conditioned on the encoder_hidden_states input. If discrete, returns probability distributions for the unnoised latent pixels.

TYPE: ms.Tensor of shape (batch_size, num_channels, height, width), or (batch_size, num_vector_embeds - 1, num_latent_pixels) if Transformer2DModel is discrete

Source code in mindone/diffusers/models/modeling_outputs.py
@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    The output of [`Transformer2DModel`].

    Args:
        sample (`ms.Tensor` of shape `(batch_size, num_channels, height, width)` or
        `(batch_size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    sample: "ms.Tensor"  # noqa: F821
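
As the source above shows, `construct` returns a one-element tuple when `return_dict` is `False` (the default in this signature) and a `Transformer2DModelOutput` otherwise. Calling code that must accept either form could normalize the result with a small helper. This is a minimal sketch: `extract_sample` is a hypothetical name, and `FakeOutput` is a stand-in dataclass that lets the example run without MindSpore or mindone installed.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeOutput:
    """Stand-in for Transformer2DModelOutput (assumption: only `.sample` is needed)."""

    sample: Any


def extract_sample(result):
    """Return the output tensor from either return form of `construct`."""
    if isinstance(result, tuple):
        # return_dict=False path: `(output,)`
        return result[0]
    # return_dict=True path: object exposing a `sample` attribute
    return result.sample


if __name__ == "__main__":
    print(extract_sample(("tensor-from-tuple",)))
    print(extract_sample(FakeOutput(sample="tensor-from-dataclass")))
```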