SkyReelsV2Transformer3DModel

A Diffusion Transformer model for 3D video-like data was introduced in SkyReels-V2 by Skywork AI.

The model can be loaded with the following code snippet.

import mindspore as ms
from mindone.diffusers import SkyReelsV2Transformer3DModel

transformer = SkyReelsV2Transformer3DModel.from_pretrained(
    "Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", subfolder="transformer", mindspore_dtype=ms.bfloat16
)
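
The transformer is typically passed to a SkyReels-V2 pipeline rather than called on its own. The snippet below is a hedged sketch: it assumes the SkyReelsV2DiffusionForcingPipeline port is available in mindone.diffusers and matches the Diffusion Forcing checkpoint loaded above; the prompt is illustrative.

from mindone.diffusers import SkyReelsV2DiffusionForcingPipeline

pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
    "Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", transformer=transformer, mindspore_dtype=ms.bfloat16
)
output = pipe(prompt="A drone shot gliding over a misty mountain lake at sunrise")  # see the pipeline docs for output handling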

mindone.diffusers.SkyReelsV2Transformer3DModel

Bases: ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin

A Transformer model for video-like data used in the Wan-based SkyReels-V2 model.

PARAMETER DESCRIPTION
patch_size

3D patch dimensions for video embedding (t_patch, h_patch, w_patch).

TYPE: `Tuple[int]`, defaults to `(1, 2, 2)` DEFAULT: (1, 2, 2)

num_attention_heads

The number of attention heads.

TYPE: `int`, defaults to `16` DEFAULT: 16

attention_head_dim

The number of channels in each head.

TYPE: `int`, defaults to `128` DEFAULT: 128

in_channels

The number of channels in the input.

TYPE: `int`, defaults to `16` DEFAULT: 16

out_channels

The number of channels in the output.

TYPE: `int`, defaults to `16` DEFAULT: 16

text_dim

Input dimension for text embeddings.

TYPE: `int`, defaults to `4096` DEFAULT: 4096

freq_dim

Dimension for sinusoidal time embeddings.

TYPE: `int`, defaults to `256` DEFAULT: 256

ffn_dim

Intermediate dimension in feed-forward network.

TYPE: `int`, defaults to `8192` DEFAULT: 8192

num_layers

The number of layers of transformer blocks to use.

TYPE: `int`, defaults to `32` DEFAULT: 32

window_size

Window size for local attention (-1 indicates global attention).

TYPE: `Tuple[int]`, defaults to `(-1, -1)`

cross_attn_norm

Enable cross-attention normalization.

TYPE: `bool`, defaults to `True` DEFAULT: True

qk_norm

Enable query/key normalization.

TYPE: `str`, *optional*, defaults to `"rms_norm_across_heads"` DEFAULT: 'rms_norm_across_heads'

eps

Epsilon value for normalization layers.

TYPE: `float`, defaults to `1e-6` DEFAULT: 1e-06

inject_sample_info

Whether to inject sample information into the model.

TYPE: `bool`, defaults to `False` DEFAULT: False

image_dim

The dimension of the image embeddings.

TYPE: `int`, *optional* DEFAULT: None

added_kv_proj_dim

The dimension of the added key/value projection.

TYPE: `int`, *optional* DEFAULT: None

rope_max_seq_len

The maximum sequence length for the rotary embeddings.

TYPE: `int`, defaults to `1024` DEFAULT: 1024

pos_embed_seq_len

The sequence length for the positional embeddings.

TYPE: `int`, *optional* DEFAULT: None
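
For orientation: the transformer's hidden width is num_attention_heads * attention_head_dim, and a latent video of shape (F, H, W) is patchified into (F / t_patch) * (H / h_patch) * (W / w_patch) tokens by the 3D patch embedding. The following is a minimal sketch of building a small, randomly initialized model for experimentation; the toy values are hypothetical, and real checkpoints ship their own configuration.

from mindone.diffusers import SkyReelsV2Transformer3DModel

# Hypothetical toy configuration for illustration only.
tiny = SkyReelsV2Transformer3DModel(
    patch_size=(1, 2, 2),
    num_attention_heads=4,   # hidden width = 4 * 64 = 256
    attention_head_dim=64,
    in_channels=16,
    out_channels=16,
    text_dim=4096,
    freq_dim=256,
    ffn_dim=1024,
    num_layers=2,
)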

Source code in mindone/diffusers/models/transformers/transformer_skyreels_v2.py
class SkyReelsV2Transformer3DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
):
    r"""
    A Transformer model for video-like data used in the Wan-based SkyReels-V2 model.

    Args:
        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `16`):
            The number of attention heads.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `8192`):
            Intermediate dimension in feed-forward network.
        num_layers (`int`, defaults to `32`):
            The number of layers of transformer blocks to use.
        window_size (`Tuple[int]`, defaults to `(-1, -1)`):
            Window size for local attention (-1 indicates global attention).
        cross_attn_norm (`bool`, defaults to `True`):
            Enable cross-attention normalization.
        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
            Enable query/key normalization.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        inject_sample_info (`bool`, defaults to `False`):
            Whether to inject sample information into the model.
        image_dim (`int`, *optional*):
            The dimension of the image embeddings.
        added_kv_proj_dim (`int`, *optional*):
            The dimension of the added key/value projection.
        rope_max_seq_len (`int`, defaults to `1024`):
            The maximum sequence length for the rotary embeddings.
        pos_embed_seq_len (`int`, *optional*):
            The sequence length for the positional embeddings.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
    _no_split_modules = ["SkyReelsV2TransformerBlock"]
    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
    _repeated_blocks = ["SkyReelsV2TransformerBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size: Tuple[int, ...] = (1, 2, 2),
        num_attention_heads: int = 16,
        attention_head_dim: int = 128,
        in_channels: int = 16,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 8192,
        num_layers: int = 32,
        cross_attn_norm: bool = True,
        qk_norm: Optional[str] = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: Optional[int] = None,
        added_kv_proj_dim: Optional[int] = None,
        rope_max_seq_len: int = 1024,
        pos_embed_seq_len: Optional[int] = None,
        inject_sample_info: bool = False,
        num_frame_per_block: int = 1,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        # 1. Patch & position embedding
        self.rope = SkyReelsV2RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
        self.patch_embedding = mint.nn.Conv3d(
            in_channels, inner_dim, kernel_size=tuple(patch_size), stride=tuple(patch_size)
        )

        # 2. Condition embeddings
        # image_embedding_dim=1280 for I2V model
        self.condition_embedder = SkyReelsV2TimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            image_embed_dim=image_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )

        # 3. Transformer blocks
        self.blocks = nn.CellList(
            [
                SkyReelsV2TransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output norm & projection
        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = mint.nn.Linear(inner_dim, out_channels * math.prod(patch_size))
        self.scale_shift_table = ms.Parameter(mint.randn(1, 2, inner_dim) / inner_dim**0.5, name="scale_shift_table")

        if inject_sample_info:
            self.fps_embedding = mint.nn.Embedding(2, inner_dim)
            self.fps_projection = FeedForward(inner_dim, inner_dim * 6, mult=1, activation_fn="linear-silu")

        self.gradient_checkpointing = False

    def construct(
        self,
        hidden_states: ms.Tensor,
        timestep: ms.Tensor,
        encoder_hidden_states: ms.Tensor,
        encoder_hidden_states_image: Optional[ms.Tensor] = None,
        enable_diffusion_forcing: bool = False,
        fps: Optional[ms.Tensor] = None,
        return_dict: bool = False,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[ms.Tensor, Dict[str, ms.Tensor]]:
        if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
            # weight the lora layers by setting `lora_scale` for each PEFT layer here
            # and remove `lora_scale` from each PEFT layer at the end.
            # scale_lora_layers & unscale_lora_layers maybe contains some operation forbidden in graph mode
            raise RuntimeError(
                f"You are trying to set scaling of lora layer by passing {attention_kwargs['scale']=}. "
                f"However it's not allowed in on-the-fly model forwarding. "
                f"Please manually call `scale_lora_layers(model, lora_scale)` before model forwarding and "
                f"`unscale_lora_layers(model, lora_scale)` after model forwarding. "
                f"For example, it can be done in a pipeline call like `StableDiffusionPipeline.__call__`."
            )

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config["patch_size"]
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        rotary_emb = self.rope(hidden_states)

        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).swapaxes(1, 2)

        causal_mask = None
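        # For autoregressive (diffusion-forcing) generation with num_frame_per_block > 1, build a
        # block-causal attention mask over latent frames: a token may only attend to tokens whose frame
        # block index is the same or earlier. The frame-level mask is broadcast over the spatial patch
        # grid and then flattened to token resolution.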
        if self.config["num_frame_per_block"] > 1:
            block_num = post_patch_num_frames // self.config["num_frame_per_block"]
            range_tensor = mint.arange(block_num).repeat_interleave(self.config["num_frame_per_block"])
            causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1)  # f, f
            causal_mask = causal_mask.view(post_patch_num_frames, 1, 1, post_patch_num_frames, 1, 1)
            causal_mask = causal_mask.tile(
                (1, post_patch_height, post_patch_width, 1, post_patch_height, post_patch_width)
            )
            causal_mask = causal_mask.reshape(
                post_patch_num_frames * post_patch_height * post_patch_width,
                post_patch_num_frames * post_patch_height * post_patch_width,
            )
            causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image
        )

        timestep_proj = unflatten(timestep_proj, -1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = mint.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        if self.config["inject_sample_info"]:
            fps = ms.tensor(fps, dtype=ms.int64)

            fps_emb = self.fps_embedding(fps)
            if enable_diffusion_forcing:
                timestep_proj = timestep_proj + unflatten(self.fps_projection(fps_emb), 1, (6, -1)).tile(
                    (timestep.shape[1], 1, 1)
                )
            else:
                timestep_proj = timestep_proj + unflatten(self.fps_projection(fps_emb), 1, (6, -1))

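        # With diffusion forcing, `timestep` carries one value per latent frame (shape (batch, num_frames)),
        # so the per-frame time embeddings are broadcast over the spatial patch grid to give every token
        # the conditioning of its own frame.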
        if enable_diffusion_forcing:
            b, f = timestep.shape
            temb = temb.view(b, f, 1, 1, -1)
            timestep_proj = timestep_proj.view(b, f, 1, 1, 6, -1)  # (b, f, 1, 1, 6, inner_dim)
            temb = temb.tile((1, 1, post_patch_height, post_patch_width, 1)).flatten(1, 3)
            timestep_proj = timestep_proj.tile((1, 1, post_patch_height, post_patch_width, 1, 1)).flatten(
                1, 3
            )  # (b, f, pp_h, pp_w, 6, inner_dim) -> (b, f * pp_h * pp_w, 6, inner_dim)
            timestep_proj = timestep_proj.swapaxes(1, 2).contiguous()  # (b, 6, f * pp_h * pp_w, inner_dim)

        # 4. Transformer blocks
        for block in self.blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states,
                timestep_proj,
                rotary_emb,
                causal_mask,
            )

        shift, scale = None, None
        if temb.dim() == 2:
            # If temb is 2D, we assume it has 1-D time embedding values for each batch.
            # For models:
            # - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers
            # - Skywork/SkyReels-V2-T2V-14B-720P-Diffusers
            # - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers
            # - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers
            # - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers
            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
        elif temb.dim() == 3:
            # If temb is 3D, we assume it has 2-D time embedding values for each batch.
            # Each time embedding tensor includes values for each latent frame; thus Diffusion Forcing.
            # For models:
            # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
            # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
            # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
            shift, scale = (self.scale_shift_table.unsqueeze(2) + temb.unsqueeze(1)).chunk(2, dim=1)
            shift, scale = shift.squeeze(1), scale.squeeze(1)

        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)

        hidden_states = self.proj_out(hidden_states)

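        # Un-patchify: fold the (p_t, p_h, p_w) patch dimensions back into the frame/height/width axes,
        # recovering a latent video of shape (batch_size, out_channels, num_frames, height, width).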
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)

    def _set_ar_attention(self, causal_block_size: int):
        self.register_to_config(num_frame_per_block=causal_block_size)
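
The construct (forward) method can be exercised directly with dummy tensors. A minimal sketch, assuming the hypothetical tiny model from the configuration example above and PyNative execution; the tensor shapes follow the construct signature and the patchification shown in the source:

import mindspore as ms
from mindspore import mint

# Illustrative dummy inputs; real latents and prompt embeddings come from the VAE and text encoder.
latents = mint.randn(1, 16, 4, 32, 32)       # (batch, in_channels, num_frames, height, width)
timestep = ms.tensor([500], dtype=ms.int64)  # one diffusion timestep per batch element
prompt_embeds = mint.randn(1, 512, 4096)     # (batch, text_seq_len, text_dim)

sample = tiny(latents, timestep, prompt_embeds, return_dict=False)[0]
print(sample.shape)  # (1, 16, 4, 32, 32): the patchified tokens are folded back to the input layout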

mindone.diffusers.models.modeling_outputs.Transformer2DModelOutput dataclass

Bases: BaseOutput

The output of [Transformer2DModel].

PARAMETER DESCRIPTION
sample

The hidden states output conditioned on the encoder_hidden_states input. If discrete, returns probability distributions for the unnoised latent pixels.

TYPE: `ms.Tensor` of shape `(batch_size, num_channels, height, width)`, or `(batch_size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete

Source code in mindone/diffusers/models/modeling_outputs.py
@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    The output of [`Transformer2DModel`].

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or
        `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    sample: "ms.Tensor"  # noqa: F821
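
As in the construct method above, setting return_dict=False returns a plain tuple while return_dict=True wraps the result in this dataclass. A short sketch, reusing the hypothetical tiny model and dummy tensors from the forward-call example:

# Tuple access versus dataclass access; both hold the same denoised latent prediction.
sample_from_tuple = tiny(latents, timestep, prompt_embeds, return_dict=False)[0]
sample_from_output = tiny(latents, timestep, prompt_embeds, return_dict=True).sample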