Skip to content

LTXVideoTransformer3DModel

A Diffusion Transformer model for 3D data from LTX, introduced by Lightricks.

The model can be loaded with the following code snippet.

from mindone.diffusers import LTXVideoTransformer3DModel
import mindspore as ms

transformer = LTXVideoTransformer3DModel.from_pretrained("Lightricks/LTX-Video", subfolder="transformer", mindspore_dtype=ms.bfloat16)

mindone.diffusers.LTXVideoTransformer3DModel

Bases: ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin

A Transformer model for video-like data used in LTX.

PARAMETER DESCRIPTION
in_channels

The number of channels in the input.

TYPE: `int`, defaults to `128` DEFAULT: 128

out_channels

The number of channels in the output.

TYPE: `int`, defaults to `128` DEFAULT: 128

patch_size

The size of the spatial patches to use in the patch embedding layer.

TYPE: `int`, defaults to `1` DEFAULT: 1

patch_size_t

The size of the temporal patches to use in the patch embedding layer.

TYPE: `int`, defaults to `1` DEFAULT: 1

num_attention_heads

The number of heads to use for multi-head attention.

TYPE: `int`, defaults to `32` DEFAULT: 32

attention_head_dim

The number of channels in each head.

TYPE: `int`, defaults to `64` DEFAULT: 64

cross_attention_dim

The number of channels for cross attention heads.

TYPE: `int`, defaults to `2048` DEFAULT: 2048

num_layers

The number of layers of Transformer blocks to use.

TYPE: `int`, defaults to `28` DEFAULT: 28

activation_fn

Activation function to use in feed-forward.

TYPE: `str`, defaults to `"gelu-approximate"` DEFAULT: 'gelu-approximate'

qk_norm

The normalization layer to use.

TYPE: `str`, defaults to `"rms_norm_across_heads"` DEFAULT: 'rms_norm_across_heads'

Source code in mindone/diffusers/models/transformers/transformer_ltx.py
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
class LTXVideoTransformer3DModel(
    ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin
):
    r"""
    A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).

    Args:
        in_channels (`int`, defaults to `128`):
            The number of channels in the input.
        out_channels (`int`, defaults to `128`):
            The number of channels in the output. A falsy value (e.g. `None` or `0`) falls back to `in_channels`.
        patch_size (`int`, defaults to `1`):
            The size of the spatial patches to use in the patch embedding layer.
        patch_size_t (`int`, defaults to `1`):
            The size of the temporal patches to use in the patch embedding layer.
        num_attention_heads (`int`, defaults to `32`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head.
        cross_attention_dim (`int`, defaults to `2048`):
            The number of channels for cross attention heads.
        num_layers (`int`, defaults to `28`):
            The number of layers of Transformer blocks to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        qk_norm (`str`, defaults to `"rms_norm_across_heads"`):
            The normalization layer to use.
        norm_elementwise_affine (`bool`, defaults to `False`):
            Forwarded to every transformer block as `elementwise_affine`.
        norm_eps (`float`, defaults to `1e-6`):
            Forwarded to every transformer block as `eps`. The final `norm_out` layer does not use this value; it
            is built with a fixed `eps=1e-6`.
        caption_channels (`int`, defaults to `4096`):
            The input feature size (`in_features`) of the caption projection applied to `encoder_hidden_states`.
        attention_bias (`bool`, defaults to `True`):
            Forwarded to every transformer block as `attention_bias`.
        attention_out_bias (`bool`, defaults to `True`):
            Forwarded to every transformer block as `attention_out_bias`.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["norm"]
    _repeated_blocks = ["LTXVideoTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 128,
        out_channels: int = 128,
        patch_size: int = 1,
        patch_size_t: int = 1,
        num_attention_heads: int = 32,
        attention_head_dim: int = 64,
        cross_attention_dim: int = 2048,
        num_layers: int = 28,
        activation_fn: str = "gelu-approximate",
        qk_norm: str = "rms_norm_across_heads",
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        caption_channels: int = 4096,
        attention_bias: bool = True,
        attention_out_bias: bool = True,
    ) -> None:
        """Build the input/output projections, conditioning embedders, rotary embedding and transformer blocks.

        See the class docstring for the meaning of each argument.
        """
        super().__init__()

        # A falsy `out_channels` (None / 0) means "same as the input".
        out_channels = out_channels or in_channels
        # Width of the hidden representation shared by every sub-layer below.
        inner_dim = num_attention_heads * attention_head_dim

        self.proj_in = mint.nn.Linear(in_channels, inner_dim)

        # Learned (2, inner_dim) table; row 0 is used as shift and row 1 as scale for the output modulation
        # in `construct`. Initialised from a normal distribution scaled by 1 / sqrt(inner_dim).
        self.scale_shift_table = ms.Parameter(mint.randn(2, inner_dim) / inner_dim**0.5, name="scale_shift_table")
        self.time_embed = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)

        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

        # The base frame count / resolution and theta are fixed here and are not exposed as config arguments.
        self.rope = LTXVideoRotaryPosEmbed(
            dim=inner_dim,
            base_num_frames=20,
            base_height=2048,
            base_width=2048,
            patch_size=patch_size,
            patch_size_t=patch_size_t,
            theta=10000.0,
        )

        # `num_layers` identically configured blocks.
        self.transformer_blocks = nn.CellList(
            [
                LTXVideoTransformerBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                    qk_norm=qk_norm,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    attention_out_bias=attention_out_bias,
                    eps=norm_eps,
                    elementwise_affine=norm_elementwise_affine,
                )
                for _ in range(num_layers)
            ]
        )

        # NOTE: the final norm uses a hard-coded eps of 1e-6 rather than `norm_eps`.
        self.norm_out = mint.nn.LayerNorm(inner_dim, eps=1e-6)
        self.proj_out = mint.nn.Linear(inner_dim, out_channels)

        # NOTE(review): this flag is only initialised here; `construct` below does not read it.
        self.gradient_checkpointing = False

    def construct(
        self,
        hidden_states: ms.Tensor,
        encoder_hidden_states: ms.Tensor,
        timestep: ms.Tensor,
        encoder_attention_mask: ms.Tensor,
        num_frames: Optional[int] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        rope_interpolation_scale: Optional[Union[Tuple[float, float, float], ms.Tensor]] = None,
        video_coords: Optional[ms.Tensor] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = False,
    ) -> ms.Tensor:
        """Run the transformer on a batch of latent tokens.

        Args:
            hidden_states (`ms.Tensor`):
                Input tokens. The first axis is taken as the batch size and the tensor is fed to `proj_in`, so the
                last axis is presumably `in_channels` (TODO confirm against the calling pipeline).
            encoder_hidden_states (`ms.Tensor`):
                Conditioning embeddings. They are passed through `caption_projection` and reshaped to
                `(batch_size, -1, inner_dim)`.
            timestep (`ms.Tensor`):
                Timestep tensor; it is flattened before being given to `time_embed`.
            encoder_attention_mask (`ms.Tensor`):
                Mask for `encoder_hidden_states`. This argument has no default. `None` is accepted and skips the
                conversion. A 2-D mask is converted to an additive bias `(1 - mask) * -10000.0` with an extra
                axis inserted at position 1; any other rank is forwarded to the blocks unchanged.
            num_frames (`int`, *optional*):
                Forwarded to the rotary position embedding.
            height (`int`, *optional*):
                Forwarded to the rotary position embedding.
            width (`int`, *optional*):
                Forwarded to the rotary position embedding.
            rope_interpolation_scale (`Tuple[float, float, float]` or `ms.Tensor`, *optional*):
                Forwarded to the rotary position embedding.
            video_coords (`ms.Tensor`, *optional*):
                Forwarded to the rotary position embedding.
            attention_kwargs (`dict`, *optional*):
                Shallow-copied when provided; the copy is not used anywhere else in this method.
            return_dict (`bool`, defaults to `False`):
                Whether to wrap the result in a `Transformer2DModelOutput`.

        Returns:
            A one-element tuple `(output,)` when `return_dict` is `False`, otherwise
            `Transformer2DModelOutput(sample=output)`. Note that the `-> ms.Tensor` annotation does not reflect
            this.
        """
        # Copy so the caller's dict is never mutated (the copy itself is unused below).
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()

        # Rotary embeddings are computed from the raw (pre-projection) hidden states.
        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        # (for a 0/1 mask: 1 -> 0.0, 0 -> -10000.0; then an axis is inserted at position 1)
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        batch_size = hidden_states.shape[0]
        hidden_states = self.proj_in(hidden_states)

        # `time_embed` returns two tensors: `temb` is fed to every block, `embedded_timestep` drives the
        # final output modulation.
        temb, embedded_timestep = self.time_embed(
            timestep.flatten(),
            batch_size=batch_size,
            hidden_dtype=hidden_states.dtype,
        )

        # Restore the batch axis after the flatten above: (batch_size, -1, last_dim).
        temb = temb.view(batch_size, -1, temb.shape[-1])
        embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])

        # Project the conditioning to the model width and reshape to (batch_size, -1, inner_dim).
        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                encoder_attention_mask=encoder_attention_mask,
            )

        # Broadcast the learned (2, inner_dim) table against the timestep embedding, then split along the
        # size-2 axis: index 0 is the shift, index 1 is the scale.
        scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

        # Timestep-conditioned modulation of the normalised output, followed by the output projection.
        hidden_states = self.norm_out(hidden_states)
        hidden_states = hidden_states * (1 + scale) + shift
        output = self.proj_out(hidden_states)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

mindone.diffusers.models.modeling_outputs.Transformer2DModelOutput dataclass

Bases: BaseOutput

The output of [Transformer2DModel].

PARAMETER DESCRIPTION
sample

The hidden states output conditioned on the encoder_hidden_states input. If discrete, returns probability distributions for the unnoised latent pixels.

TYPE: `ms.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete

Source code in mindone/diffusers/models/modeling_outputs.py
24
25
26
27
28
29
30
31
32
33
34
35
36
@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    The output of [`Transformer2DModel`].

    Args:
        sample (`ms.Tensor`):
            The hidden states output conditioned on the `encoder_hidden_states` input, of shape
            `(batch_size, num_channels, height, width)`, or `(batch size, num_vector_embeds - 1, num_latent_pixels)`
            if [`Transformer2DModel`] is discrete. If discrete, returns probability distributions for the unnoised
            latent pixels.
    """

    # String annotation: `ms` is resolved lazily, hence the F821 suppression.
    sample: "ms.Tensor"  # noqa: F821