SkyReelsV2Transformer3DModel

A Diffusion Transformer model for 3D video-like data was introduced in SkyReels-V2 by Skywork AI.

The model can be loaded with the following code snippet.

import mindspore as ms

from mindone.diffusers import SkyReelsV2Transformer3DModel

transformer = SkyReelsV2Transformer3DModel.from_pretrained(
    "Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", subfolder="transformer", mindspore_dtype=ms.bfloat16
)
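
The transformer can then be dropped into a SkyReels-V2 pipeline. A minimal sketch, assuming SkyReelsV2DiffusionForcingPipeline is exposed by mindone.diffusers for this diffusion-forcing checkpoint:

from mindone.diffusers import SkyReelsV2DiffusionForcingPipeline

# Reuse the transformer loaded above; the remaining components (text encoder, VAE,
# scheduler) are fetched from the same repository.
pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
    "Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", transformer=transformer, mindspore_dtype=ms.bfloat16
)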

mindone.diffusers.SkyReelsV2Transformer3DModel

Bases: ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin

A Transformer model for video-like data used in the Wan-based SkyReels-V2 model.

PARAMETER DESCRIPTION
patch_size

3D patch dimensions for video embedding (t_patch, h_patch, w_patch).

TYPE: `Tuple[int]`, defaults to `(1, 2, 2)` DEFAULT: (1, 2, 2)

num_attention_heads

The number of attention heads.

TYPE: `int`, defaults to `16` DEFAULT: 16

attention_head_dim

The number of channels in each head.

TYPE: `int`, defaults to `128` DEFAULT: 128

in_channels

The number of channels in the input.

TYPE: `int`, defaults to `16` DEFAULT: 16

out_channels

The number of channels in the output.

TYPE: `int`, defaults to `16` DEFAULT: 16

text_dim

Input dimension for text embeddings.

TYPE: `int`, defaults to `4096` DEFAULT: 4096

freq_dim

Dimension for sinusoidal time embeddings.

TYPE: `int`, defaults to `256` DEFAULT: 256

ffn_dim

Intermediate dimension in feed-forward network.

TYPE: `int`, defaults to `8192` DEFAULT: 8192

num_layers

The number of layers of transformer blocks to use.

TYPE: `int`, defaults to `32` DEFAULT: 32

window_size

Window size for local attention (-1 indicates global attention).

TYPE: `Tuple[int]`, defaults to `(-1, -1)`

cross_attn_norm

Enable cross-attention normalization.

TYPE: `bool`, defaults to `True` DEFAULT: True

qk_norm

Enable query/key normalization.

TYPE: `str`, *optional*, defaults to `"rms_norm_across_heads"` DEFAULT: 'rms_norm_across_heads'

eps

Epsilon value for normalization layers.

TYPE: `float`, defaults to `1e-6` DEFAULT: 1e-06

inject_sample_info

Whether to inject sample information into the model.

TYPE: `bool`, defaults to `False` DEFAULT: False

image_dim

The dimension of the image embeddings.

TYPE: `int`, *optional* DEFAULT: None

added_kv_proj_dim

The dimension of the added key/value projection.

TYPE: `int`, *optional* DEFAULT: None

rope_max_seq_len

The maximum sequence length for the rotary embeddings.

TYPE: `int`, defaults to `1024` DEFAULT: 1024

pos_embed_seq_len

The sequence length for the positional embeddings.

TYPE: `int`, *optional* DEFAULT: None
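
The parameters above map one-to-one onto the constructor. A minimal sketch that instantiates a deliberately tiny, hypothetical configuration (not a released checkpoint) to illustrate how they relate, e.g. inner_dim = num_attention_heads * attention_head_dim:

from mindone.diffusers import SkyReelsV2Transformer3DModel

# Hypothetical toy configuration for shape experiments; released checkpoints use the
# defaults documented above (e.g. num_layers=32, ffn_dim=8192).
tiny = SkyReelsV2Transformer3DModel(
    patch_size=(1, 2, 2),    # (t_patch, h_patch, w_patch)
    num_attention_heads=2,   # inner_dim = 2 * 32 = 64
    attention_head_dim=32,
    in_channels=16,
    out_channels=16,
    text_dim=4096,           # width of the text encoder embeddings
    freq_dim=256,
    ffn_dim=256,
    num_layers=2,
)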

Source code in mindone/diffusers/models/transformers/transformer_skyreels_v2.py
class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    r"""
    A Transformer model for video-like data used in the Wan-based SkyReels-V2 model.

    Args:
        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `16`):
            The number of attention heads.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `8192`):
            Intermediate dimension in feed-forward network.
        num_layers (`int`, defaults to `32`):
            The number of layers of transformer blocks to use.
        window_size (`Tuple[int]`, defaults to `(-1, -1)`):
            Window size for local attention (-1 indicates global attention).
        cross_attn_norm (`bool`, defaults to `True`):
            Enable cross-attention normalization.
        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
            Enable query/key normalization.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        inject_sample_info (`bool`, defaults to `False`):
            Whether to inject sample information into the model.
        image_dim (`int`, *optional*):
            The dimension of the image embeddings.
        added_kv_proj_dim (`int`, *optional*):
            The dimension of the added key/value projection.
        rope_max_seq_len (`int`, defaults to `1024`):
            The maximum sequence length for the rotary embeddings.
        pos_embed_seq_len (`int`, *optional*):
            The sequence length for the positional embeddings.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
    _no_split_modules = ["SkyReelsV2TransformerBlock"]
    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]

    @register_to_config
    def __init__(
        self,
        patch_size: Tuple[int] = (1, 2, 2),
        num_attention_heads: int = 16,
        attention_head_dim: int = 128,
        in_channels: int = 16,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 8192,
        num_layers: int = 32,
        cross_attn_norm: bool = True,
        qk_norm: Optional[str] = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: Optional[int] = None,
        added_kv_proj_dim: Optional[int] = None,
        rope_max_seq_len: int = 1024,
        pos_embed_seq_len: Optional[int] = None,
        inject_sample_info: bool = False,
        num_frame_per_block: int = 1,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        # 1. Patch & position embedding
        self.rope = SkyReelsV2RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
        self.patch_embedding = mint.nn.Conv3d(
            in_channels, inner_dim, kernel_size=tuple(patch_size), stride=tuple(patch_size)
        )

        # 2. Condition embeddings
        # image_embedding_dim=1280 for I2V model
        self.condition_embedder = SkyReelsV2TimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            image_embed_dim=image_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )

        # 3. Transformer blocks
        self.blocks = nn.CellList(
            [
                SkyReelsV2TransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output norm & projection
        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = mint.nn.Linear(inner_dim, out_channels * math.prod(patch_size))
        self.scale_shift_table = ms.Parameter(mint.randn(1, 2, inner_dim) / inner_dim**0.5, name="scale_shift_table")

        if inject_sample_info:
            self.fps_embedding = mint.nn.Embedding(2, inner_dim)
            self.fps_projection = FeedForward(inner_dim, inner_dim * 6, mult=1, activation_fn="linear-silu")

        self.gradient_checkpointing = False

    def construct(
        self,
        hidden_states: ms.Tensor,
        timestep: ms.Tensor,
        encoder_hidden_states: ms.Tensor,
        encoder_hidden_states_image: Optional[ms.Tensor] = None,
        enable_diffusion_forcing: bool = False,
        fps: Optional[ms.Tensor] = None,
        return_dict: bool = False,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[ms.Tensor, Dict[str, ms.Tensor]]:
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()

        if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective.")

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config["patch_size"]
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        rotary_emb = self.rope(hidden_states)

        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).swapaxes(1, 2)

        causal_mask = None
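        # Diffusion forcing: when num_frame_per_block > 1, build a block-causal mask so each
        # latent frame only attends to tokens from its own frame block and earlier blocks.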
        if self.config["num_frame_per_block"] > 1:
            block_num = post_patch_num_frames // self.config["num_frame_per_block"]
            range_tensor = mint.arange(block_num).repeat_interleave(self.config["num_frame_per_block"])
            causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1)  # f, f
            causal_mask = causal_mask.view(post_patch_num_frames, 1, 1, post_patch_num_frames, 1, 1)
            causal_mask = causal_mask.tile(
                (1, post_patch_height, post_patch_width, 1, post_patch_height, post_patch_width)
            )
            causal_mask = causal_mask.reshape(
                post_patch_num_frames * post_patch_height * post_patch_width,
                post_patch_num_frames * post_patch_height * post_patch_width,
            )
            causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image
        )

        timestep_proj = unflatten(timestep_proj, -1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = mint.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        if self.config["inject_sample_info"]:
            fps = ms.tensor(fps, dtype=ms.int64)

            fps_emb = self.fps_embedding(fps)
            if enable_diffusion_forcing:
                timestep_proj = timestep_proj + unflatten(self.fps_projection(fps_emb), 1, (6, -1)).tile(
                    (timestep.shape[1], 1, 1)
                )
            else:
                timestep_proj = timestep_proj + unflatten(self.fps_projection(fps_emb), 1, (6, -1))

        if enable_diffusion_forcing:
            b, f = timestep.shape
            temb = temb.view(b, f, 1, 1, -1)
            timestep_proj = timestep_proj.view(b, f, 1, 1, 6, -1)  # (b, f, 1, 1, 6, inner_dim)
            temb = temb.tile((1, 1, post_patch_height, post_patch_width, 1)).flatten(1, 3)
            timestep_proj = timestep_proj.tile((1, 1, post_patch_height, post_patch_width, 1, 1)).flatten(
                1, 3
            )  # (b, f, pp_h, pp_w, 6, inner_dim) -> (b, f * pp_h * pp_w, 6, inner_dim)
            timestep_proj = timestep_proj.swapaxes(1, 2).contiguous()  # (b, 6, f * pp_h * pp_w, inner_dim)

        # 4. Transformer blocks
        for block in self.blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states,
                timestep_proj,
                rotary_emb,
                causal_mask,
            )

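        # 5. Output norm, projection & unpatchify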
        shift, scale = None, None
        if temb.dim() == 2:
            # If temb is 2D, we assume it has 1-D time embedding values for each batch.
            # For models:
            # - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers
            # - Skywork/SkyReels-V2-T2V-14B-720P-Diffusers
            # - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers
            # - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers
            # - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers
            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
        elif temb.dim() == 3:
            # If temb is 3D, we assume it has 2-D time embedding values for each batch.
            # Each time embedding tensor includes values for each latent frame; thus Diffusion Forcing.
            # For models:
            # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
            # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
            # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
            shift, scale = (self.scale_shift_table.unsqueeze(2) + temb.unsqueeze(1)).chunk(2, dim=1)
            shift, scale = shift.squeeze(1), scale.squeeze(1)

        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)

        hidden_states = self.proj_out(hidden_states)

        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)

    def _set_ar_attention(self, causal_block_size: int):
        self.register_to_config(num_frame_per_block=causal_block_size)

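For quick reference, a minimal sketch of calling the model directly (outside a pipeline) with placeholder tensors, reusing the hypothetical tiny configuration from the earlier sketch; the shapes follow the construct signature above:

import numpy as np
import mindspore as ms

# Placeholder inputs: latent video of shape (batch, channels, frames, height, width),
# one timestep per batch element, and text embeddings of width text_dim.
latents = ms.Tensor(np.random.randn(1, 16, 4, 32, 32), ms.float32)
timestep = ms.Tensor([500], ms.int64)
text_embeds = ms.Tensor(np.random.randn(1, 512, 4096), ms.float32)

sample = tiny(latents, timestep, text_embeds)[0]  # return_dict defaults to False, so a tuple is returned
print(sample.shape)  # (1, 16, 4, 32, 32): out_channels at the original spatio-temporal resolution
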
mindone.diffusers.models.modeling_outputs.Transformer2DModelOutput dataclass

Bases: BaseOutput

The output of [Transformer2DModel].

PARAMETER DESCRIPTION
sample

The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability distributions for the unnoised latent pixels.

TYPE: `ms.Tensor` of shape `(batch_size, num_channels, height, width)`, or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete

Source code in mindone/diffusers/models/modeling_outputs.py
@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    The output of [`Transformer2DModel`].

    Args:
        sample (`ms.Tensor` of shape `(batch_size, num_channels, height, width)` or
        `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    sample: "ms.Tensor"  # noqa: F821