EasyAnimateTransformer3DModel

A Diffusion Transformer model for 3D (video) data from EasyAnimate, introduced by Alibaba PAI.

The model can be loaded with the following code snippet.

import mindspore as ms

from mindone.diffusers import EasyAnimateTransformer3DModel

transformer = EasyAnimateTransformer3DModel.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="transformer", mindspore_dtype=ms.float16
)
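
The transformer is normally used inside a text-to-video pipeline rather than on its own. The snippet below is a minimal sketch, assuming EasyAnimatePipeline is exported from mindone.diffusers as in upstream diffusers; the prompt and num_frames values are illustrative only.

import mindspore as ms

from mindone.diffusers import EasyAnimatePipeline, EasyAnimateTransformer3DModel

# Load the transformer separately (e.g. in float16) and hand it to the pipeline.
transformer = EasyAnimateTransformer3DModel.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="transformer", mindspore_dtype=ms.float16
)
pipe = EasyAnimatePipeline.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh", transformer=transformer, mindspore_dtype=ms.float16
)

# Illustrative call; the generated frames are returned in the pipeline output.
output = pipe(prompt="A cat walks on the grass, realistic style.", num_frames=49)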

mindone.diffusers.EasyAnimateTransformer3DModel

Bases: ModelMixin, ConfigMixin

A Transformer model for video-like data in EasyAnimate.

PARAMETER DESCRIPTION
num_attention_heads

The number of heads to use for multi-head attention.

TYPE: `int`, defaults to `48` DEFAULT: 48

attention_head_dim

The number of channels in each head.

TYPE: `int`, defaults to `64` DEFAULT: 64

in_channels

The number of channels in the input.

TYPE: `int`, *optional*, defaults to `None` DEFAULT: None

out_channels

The number of channels in the output.

TYPE: `int`, *optional*, defaults to `None` DEFAULT: None

patch_size

The size of the patches to use in the patch embedding layer.

TYPE: `int`, *optional*, defaults to `None` DEFAULT: None

sample_width

The width of the input latents.

TYPE: `int`, defaults to `90` DEFAULT: 90

sample_height

The height of the input latents.

TYPE: `int`, defaults to `60` DEFAULT: 60

activation_fn

Activation function to use in feed-forward.

TYPE: `str`, defaults to `"gelu-approximate"` DEFAULT: 'gelu-approximate'

timestep_activation_fn

Activation function to use when generating the timestep embeddings.

TYPE: `str`, defaults to `"silu"` DEFAULT: 'silu'

num_layers

The number of layers of Transformer blocks to use.

TYPE: `int`, defaults to `48` DEFAULT: 48

mmdit_layers

The number of layers of Multi Modal Transformer blocks to use.

TYPE: `int`, defaults to `48` DEFAULT: 48

dropout

The dropout probability to use.

TYPE: `float`, defaults to `0.0` DEFAULT: 0.0

time_embed_dim

Output dimension of timestep embeddings.

TYPE: `int`, defaults to `512` DEFAULT: 512

text_embed_dim

Input dimension of text embeddings from the text encoder.

TYPE: `int`, defaults to `3584` DEFAULT: 3584

norm_eps

The epsilon value to use in normalization layers.

TYPE: `float`, defaults to `1e-5` DEFAULT: 1e-05

norm_elementwise_affine

Whether to use elementwise affine in normalization layers.

TYPE: `bool`, defaults to `True` DEFAULT: True

flip_sin_to_cos

Whether to flip the sin to cos in the time embedding.

TYPE: `bool`, defaults to `True` DEFAULT: True

time_position_encoding_type

Type of time position encoding.

TYPE: `str`, defaults to `3d_rope` DEFAULT: '3d_rope'

after_norm

Flag to apply normalization after.

TYPE: `bool`, defaults to `False` DEFAULT: False

resize_inpaint_mask_directly

Flag to resize inpaint mask directly.

TYPE: `bool`, defaults to `True` DEFAULT: True

enable_text_attention_mask

Flag to enable text attention mask.

TYPE: `bool`, defaults to `True` DEFAULT: True

add_noise_in_inpaint_model

Flag to add noise in inpaint model.

TYPE: `bool`, defaults to `True` DEFAULT: True

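For illustration, a randomly initialized instance can also be built directly from these parameters. This is a minimal sketch with deliberately tiny, made-up sizes (it is not the EasyAnimateV5.1 configuration); in practice the pretrained weights are loaded with from_pretrained as shown above.

from mindone.diffusers import EasyAnimateTransformer3DModel

# Tiny illustrative configuration; every value below is an assumption for the sketch.
transformer = EasyAnimateTransformer3DModel(
    num_attention_heads=4,
    attention_head_dim=64,
    in_channels=16,
    out_channels=16,
    patch_size=2,
    sample_width=16,
    sample_height=16,
    num_layers=2,
    mmdit_layers=2,
    text_embed_dim=32,
)
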
Source code in mindone/diffusers/models/transformers/transformer_easyanimate.py
class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
    """
    A Transformer model for video-like data in [EasyAnimate](https://github.com/aigc-apps/EasyAnimate).

    Parameters:
        num_attention_heads (`int`, defaults to `48`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head.
        in_channels (`int`, *optional*, defaults to `None`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of channels in the output.
        patch_size (`int`, *optional*, defaults to `None`):
            The size of the patches to use in the patch embedding layer.
        sample_width (`int`, defaults to `90`):
            The width of the input latents.
        sample_height (`int`, defaults to `60`):
            The height of the input latents.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        timestep_activation_fn (`str`, defaults to `"silu"`):
            Activation function to use when generating the timestep embeddings.
        num_layers (`int`, defaults to `48`):
            The number of layers of Transformer blocks to use.
        mmdit_layers (`int`, defaults to `48`):
            The number of layers of Multi Modal Transformer blocks to use.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        time_embed_dim (`int`, defaults to `512`):
            Output dimension of timestep embeddings.
        text_embed_dim (`int`, defaults to `3584`):
            Input dimension of text embeddings from the text encoder.
        norm_eps (`float`, defaults to `1e-5`):
            The epsilon value to use in normalization layers.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use elementwise affine in normalization layers.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        time_position_encoding_type (`str`, defaults to `3d_rope`):
            Type of time position encoding.
        after_norm (`bool`, defaults to `False`):
            Flag to apply normalization after.
        resize_inpaint_mask_directly (`bool`, defaults to `True`):
            Flag to resize inpaint mask directly.
        enable_text_attention_mask (`bool`, defaults to `True`):
            Flag to enable text attention mask.
        add_noise_in_inpaint_model (`bool`, defaults to `True`):
            Flag to add noise in inpaint model.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["EasyAnimateTransformerBlock"]
    _skip_layerwise_casting_patterns = ["^proj$", "norm", "^proj_out$"]

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 48,
        attention_head_dim: int = 64,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        patch_size: Optional[int] = None,
        sample_width: int = 90,
        sample_height: int = 60,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        freq_shift: int = 0,
        num_layers: int = 48,
        mmdit_layers: int = 48,
        dropout: float = 0.0,
        time_embed_dim: int = 512,
        add_norm_text_encoder: bool = False,
        text_embed_dim: int = 3584,
        text_embed_dim_t5: int = None,
        norm_eps: float = 1e-5,
        norm_elementwise_affine: bool = True,
        flip_sin_to_cos: bool = True,
        time_position_encoding_type: str = "3d_rope",
        after_norm=False,
        resize_inpaint_mask_directly: bool = True,
        enable_text_attention_mask: bool = True,
        add_noise_in_inpaint_model: bool = True,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Timestep embedding
        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
        self.rope_embedding = EasyAnimateRotaryPosEmbed(patch_size, attention_head_dim)

        # 2. Patch embedding
        self.proj = mint.nn.Conv2d(
            in_channels, inner_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True
        )

        # 3. Text refined embedding
        self.text_proj = None
        self.text_proj_t5 = None
        if not add_norm_text_encoder:
            self.text_proj = nn.Dense(text_embed_dim, inner_dim)
            if text_embed_dim_t5 is not None:
                self.text_proj_t5 = nn.Dense(text_embed_dim_t5, inner_dim)
        else:
            self.text_proj = nn.SequentialCell(
                RMSNorm(text_embed_dim, 1e-6, elementwise_affine=True), nn.Dense(text_embed_dim, inner_dim)
            )
            if text_embed_dim_t5 is not None:
                self.text_proj_t5 = nn.SequentialCell(
                    RMSNorm(text_embed_dim, 1e-6, elementwise_affine=True), nn.Dense(text_embed_dim_t5, inner_dim)
                )

        # 4. Transformer blocks
        self.transformer_blocks = nn.CellList(
            [
                EasyAnimateTransformerBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                    after_norm=after_norm,
                    is_mmdit_block=True if _ < mmdit_layers else False,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = mint.nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)

        # 5. Output norm & projection
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * inner_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Dense(inner_dim, patch_size * patch_size * out_channels)

        self.gradient_checkpointing = False

    def construct(
        self,
        hidden_states: ms.Tensor,
        timestep: ms.Tensor,
        timestep_cond: Optional[ms.Tensor] = None,
        encoder_hidden_states: Optional[ms.Tensor] = None,
        encoder_hidden_states_t5: Optional[ms.Tensor] = None,
        inpaint_latents: Optional[ms.Tensor] = None,
        control_latents: Optional[ms.Tensor] = None,
        return_dict: bool = False,
    ) -> Union[Tuple[ms.Tensor], Transformer2DModelOutput]:
        batch_size, channels, video_length, height, width = hidden_states.shape
        p = self.config["patch_size"]
        post_patch_height = height // p
        post_patch_width = width // p

        # 1. Time embedding
        temb = self.time_proj(timestep).to(dtype=hidden_states.dtype)
        temb = self.time_embedding(temb, timestep_cond)
        image_rotary_emb = self.rope_embedding(hidden_states)

        # 2. Patch embedding
        if inpaint_latents is not None:
            hidden_states = mint.concat([hidden_states, inpaint_latents], 1)
        if control_latents is not None:
            hidden_states = mint.concat([hidden_states, control_latents], 1)

        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)  # [B, C, F, H, W] -> [BF, C, H, W]
        hidden_states = self.proj(hidden_states)
        # hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(
        #     0, 2, 1, 3, 4
        # )  # [BF, C, H, W] -> [B, F, C, H, W]
        hidden_states = hidden_states.reshape((batch_size, -1) + hidden_states.shape[1:]).permute(
            0, 2, 1, 3, 4
        )  # [BF, C, H, W] -> [B, F, C, H, W]
        hidden_states = hidden_states.flatten(2, 4).swapaxes(1, 2)  # [B, F, C, H, W] -> [B, FHW, C]

        # 3. Text embedding
        encoder_hidden_states = self.text_proj(encoder_hidden_states)
        if encoder_hidden_states_t5 is not None:
            encoder_hidden_states_t5 = self.text_proj_t5(encoder_hidden_states_t5)
            encoder_hidden_states = mint.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1).contiguous()

        # 4. Transformer blocks
        for block in self.transformer_blocks:
            hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, temb, image_rotary_emb)

        hidden_states = self.norm_final(hidden_states)

        # 5. Output norm & projection
        hidden_states = self.norm_out(hidden_states, temb=temb)
        hidden_states = self.proj_out(hidden_states)

        # 6. Unpatchify
        p = self.config["patch_size"]
        output = hidden_states.reshape(batch_size, video_length, post_patch_height, post_patch_width, channels, p, p)
        output = output.permute(0, 4, 1, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

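A short, hedged sketch of a forward pass through construct follows. It uses a tiny randomly initialized configuration together with dummy tensors; the shapes follow the layout construct expects (latents as [batch, channels, frames, height, width], text embeddings as [batch, sequence_length, text_embed_dim]) and are purely illustrative, not taken from a released checkpoint.

import mindspore as ms
from mindspore import ops

from mindone.diffusers import EasyAnimateTransformer3DModel

# Tiny illustrative configuration (an assumption, not the EasyAnimateV5.1 config).
transformer = EasyAnimateTransformer3DModel(
    num_attention_heads=4,
    attention_head_dim=64,
    in_channels=16,
    out_channels=16,
    patch_size=2,
    sample_width=16,
    sample_height=16,
    num_layers=2,
    mmdit_layers=2,
    text_embed_dim=32,
)

batch_size, num_frames, height, width = 1, 5, 16, 16
hidden_states = ops.randn(batch_size, 16, num_frames, height, width)  # [B, C, F, H, W] latents
timestep = ms.Tensor([999], dtype=ms.int64)                           # one timestep per sample
encoder_hidden_states = ops.randn(batch_size, 8, 32)                  # [B, seq_len, text_embed_dim]

sample = transformer(
    hidden_states=hidden_states,
    timestep=timestep,
    encoder_hidden_states=encoder_hidden_states,
)[0]
# With in_channels == out_channels, `sample` has the same shape as `hidden_states`.
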
mindone.diffusers.models.modeling_outputs.Transformer2DModelOutput dataclass

Bases: BaseOutput

The output of [Transformer2DModel].

PARAMETER DESCRIPTION
sample

The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability distributions for the unnoised latent pixels.

TYPE: `ms.Tensor` of shape `(batch_size, num_channels, height, width)`, or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete

Source code in mindone/diffusers/models/modeling_outputs.py
@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    The output of [`Transformer2DModel`].

    Args:
        sample (`ms.Tensor` of shape `(batch_size, num_channels, height, width)` or
        `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    sample: "ms.Tensor"  # noqa: F821
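
The construct method of EasyAnimateTransformer3DModel returns a plain tuple by default; passing return_dict=True wraps the same tensor in this dataclass. Continuing from the forward-pass sketch above, as a hedged illustration:

# With return_dict=True the prediction is exposed as the `sample` field of Transformer2DModelOutput.
output = transformer(
    hidden_states=hidden_states,
    timestep=timestep,
    encoder_hidden_states=encoder_hidden_states,
    return_dict=True,
)
sample = output.sample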