Skip to content

CosmosTransformer3DModel

A Diffusion Transformer model for 3D video-like data was introduced in the paper *Cosmos World Foundation Model Platform for Physical AI* by NVIDIA.

The model can be loaded with the following code snippet.

import mindspore
from mindone.diffusers import CosmosTransformer3DModel

transformer = CosmosTransformer3DModel.from_pretrained("nvidia/Cosmos-1.0-Diffusion-7B-Text2World", subfolder="transformer", mindspore_dtype=mindspore.bfloat16)

mindone.diffusers.CosmosTransformer3DModel

Bases: ModelMixin, ConfigMixin

A Transformer model for video-like data used in Cosmos.

PARAMETER DESCRIPTION
in_channels

The number of channels in the input.

TYPE: `int`, defaults to `16` DEFAULT: 16

out_channels

The number of channels in the output.

TYPE: `int`, defaults to `16` DEFAULT: 16

num_attention_heads

The number of heads to use for multi-head attention.

TYPE: `int`, defaults to `32` DEFAULT: 32

attention_head_dim

The number of channels in each attention head.

TYPE: `int`, defaults to `128` DEFAULT: 128

num_layers

The number of layers of transformer blocks to use.

TYPE: `int`, defaults to `28` DEFAULT: 28

mlp_ratio

The ratio of the hidden layer size to the input size in the feedforward network.

TYPE: `float`, defaults to `4.0` DEFAULT: 4.0

text_embed_dim

Input dimension of text embeddings from the text encoder.

TYPE: `int`, defaults to `1024` DEFAULT: 1024

adaln_lora_dim

The hidden dimension of the Adaptive LayerNorm LoRA layer.

TYPE: `int`, defaults to `256` DEFAULT: 256

max_size

The maximum size of the input latent tensors in the temporal, height, and width dimensions.

TYPE: `Tuple[int, int, int]`, defaults to `(128, 240, 240)` DEFAULT: (128, 240, 240)

patch_size

The patch size to use for patchifying the input latent tensors in the temporal, height, and width dimensions.

TYPE: `Tuple[int, int, int]`, defaults to `(1, 2, 2)` DEFAULT: (1, 2, 2)

rope_scale

The scaling factor to use for RoPE in the temporal, height, and width dimensions.

TYPE: `Tuple[float, float, float]`, defaults to `(2.0, 1.0, 1.0)` DEFAULT: (2.0, 1.0, 1.0)

concat_padding_mask

Whether to concatenate the padding mask to the input latent tensors.

TYPE: `bool`, defaults to `True` DEFAULT: True

extra_pos_embed_type

The type of extra positional embeddings to use. Can be one of `None` or `"learnable"`.

TYPE: `str`, *optional*, defaults to `learnable` DEFAULT: 'learnable'

Source code in mindone/diffusers/models/transformers/transformer_cosmos.py (lines 368–566)
class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
    r"""
    A Transformer model for video-like data used in [Cosmos](https://github.com/NVIDIA/Cosmos).

    Args:
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output.
        num_attention_heads (`int`, defaults to `32`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each attention head.
        num_layers (`int`, defaults to `28`):
            The number of layers of transformer blocks to use.
        mlp_ratio (`float`, defaults to `4.0`):
            The ratio of the hidden layer size to the input size in the feedforward network.
        text_embed_dim (`int`, defaults to `1024`):
            Input dimension of text embeddings from the text encoder (used as the cross-attention dimension).
        adaln_lora_dim (`int`, defaults to `256`):
            The hidden dimension of the Adaptive LayerNorm LoRA layer.
        max_size (`Tuple[int, int, int]`, defaults to `(128, 240, 240)`):
            The maximum size of the input latent tensors in the temporal, height, and width dimensions.
        patch_size (`Tuple[int, int, int]`, defaults to `(1, 2, 2)`):
            The patch size to use for patchifying the input latent tensors in the temporal, height, and width
            dimensions.
        rope_scale (`Tuple[float, float, float]`, defaults to `(2.0, 1.0, 1.0)`):
            The scaling factor to use for RoPE in the temporal, height, and width dimensions.
        concat_padding_mask (`bool`, defaults to `True`):
            Whether to concatenate the padding mask to the input latent tensors.
        extra_pos_embed_type (`str`, *optional*, defaults to `learnable`):
            The type of extra positional embeddings to use. Can be one of `None` or `learnable`.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embed", "final_layer", "norm"]
    _no_split_modules = ["CosmosTransformerBlock"]
    _keep_in_fp32_modules = ["learnable_pos_embed"]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 16,
        num_attention_heads: int = 32,
        attention_head_dim: int = 128,
        num_layers: int = 28,
        mlp_ratio: float = 4.0,
        text_embed_dim: int = 1024,
        adaln_lora_dim: int = 256,
        max_size: Tuple[int, int, int] = (128, 240, 240),
        patch_size: Tuple[int, int, int] = (1, 2, 2),
        rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
        concat_padding_mask: bool = True,
        extra_pos_embed_type: Optional[str] = "learnable",
    ) -> None:
        super().__init__()
        hidden_size = num_attention_heads * attention_head_dim

        # 1. Patch Embedding (one extra input channel when the padding mask is concatenated)
        patch_embed_in_channels = in_channels + 1 if concat_padding_mask else in_channels
        self.patch_embed = CosmosPatchEmbed(patch_embed_in_channels, hidden_size, patch_size, bias=False)

        # 2. Positional Embedding
        self.rope = CosmosRotaryPosEmbed(
            hidden_size=attention_head_dim, max_size=max_size, patch_size=patch_size, rope_scale=rope_scale
        )

        # Only built for extra_pos_embed_type == "learnable"; stays None otherwise.
        self.learnable_pos_embed = None
        if extra_pos_embed_type == "learnable":
            self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
                hidden_size=hidden_size,
                max_size=max_size,
                patch_size=patch_size,
            )

        # 3. Time Embedding
        self.time_embed = CosmosEmbedding(hidden_size, hidden_size)

        # 4. Transformer Blocks
        self.transformer_blocks = nn.CellList(
            [
                CosmosTransformerBlock(
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    cross_attention_dim=text_embed_dim,
                    mlp_ratio=mlp_ratio,
                    adaln_lora_dim=adaln_lora_dim,
                    qk_norm="rms_norm",
                    out_bias=False,
                )
                for _ in range(num_layers)
            ]
        )

        # 5. Output norm & projection
        self.norm_out = CosmosAdaLayerNorm(hidden_size, adaln_lora_dim)
        self.proj_out = mint.nn.Linear(
            hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False
        )

        # NOTE(review): this flag is not consulted in `construct` below.
        self.gradient_checkpointing = False

        self.p_t, self.p_h, self.p_w = self.config.patch_size
        self.concat_padding_mask = self.config.concat_padding_mask
        self.extra_pos_embed_type = self.config.extra_pos_embed_type

    def construct(
        self,
        hidden_states: ms.Tensor,
        timestep: ms.Tensor,
        encoder_hidden_states: ms.Tensor,
        attention_mask: Optional[ms.Tensor] = None,
        fps: Optional[int] = None,
        condition_mask: Optional[ms.Tensor] = None,
        padding_mask: Optional[ms.Tensor] = None,
        return_dict: bool = False,
    ):
        r"""
        Run the Cosmos transformer on a batch of video latents.

        Args:
            hidden_states (`ms.Tensor`):
                Input latents of shape `[B, C, T, H, W]`.
            timestep (`ms.Tensor`):
                Either a 1-D tensor passed directly to the time embedding, or a 5-D tensor of shape
                `[B, 1, T, 1, 1]` holding one timestep per frame (requires a temporal patch size of 1).
            encoder_hidden_states (`ms.Tensor`):
                Text-encoder embeddings forwarded to every transformer block.
            attention_mask (`ms.Tensor`, *optional*):
                Mask forwarded to the blocks after two singleton axes are inserted at dim 1
                (presumably `[B, S]` -> `[B, 1, 1, S]`).
            fps (`int`, *optional*):
                Frame rate forwarded to the rotary position embedding.
            condition_mask (`ms.Tensor`, *optional*):
                If given, concatenated to `hidden_states` along the channel axis.
            padding_mask (`ms.Tensor`, *optional*):
                Required when `concat_padding_mask=True`. It is resized (nearest) to the latent spatial size,
                broadcast over frames and concatenated along the channel axis.
            return_dict (`bool`, defaults to `False`):
                Whether to return a `Transformer2DModelOutput` instead of a one-element tuple.

        Returns:
            `Tuple[ms.Tensor]` or `Transformer2DModelOutput`: the output sample, laid out as
            `[B, out_channels, T, H, W]` (for spatial/temporal sizes divisible by the patch size).

        Raises:
            ValueError: if `padding_mask` is missing while `concat_padding_mask=True`, or if `timestep` has an
                unsupported rank/shape.
        """
        batch_size, _, num_frames, height, width = hidden_states.shape

        # 1. Concatenate condition / padding masks along the channel axis & prepare attention mask
        if condition_mask is not None:
            hidden_states = mint.cat([hidden_states, condition_mask], dim=1)

        if self.concat_padding_mask:
            if padding_mask is None:
                raise ValueError("`padding_mask` must be provided when `concat_padding_mask=True`.")
            # Nearest-neighbour resize of the mask to the latent spatial size (H, W).
            padding_mask = mint.functional.interpolate(
                padding_mask, size=list(hidden_states.shape[-2:]), mode="nearest"
            )
            # NOTE(review): `repeat(batch_size, ...)` multiplies the mask's leading axis, so this assumes the
            # mask arrives with a leading axis of size 1 -- confirm against callers.
            hidden_states = mint.cat(
                [hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
            )

        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, S]

        # 2. Generate positional embeddings
        image_rotary_emb = self.rope(hidden_states, fps=fps)
        # Guard on the module itself: it only exists for extra_pos_embed_type == "learnable", so any other
        # non-None type must not try to call it.
        extra_pos_emb = self.learnable_pos_embed(hidden_states) if self.learnable_pos_embed is not None else None

        # 3. Patchify input
        post_patch_num_frames = num_frames // self.p_t
        post_patch_height = height // self.p_h
        post_patch_width = width // self.p_w
        hidden_states = self.patch_embed(hidden_states)
        hidden_states = hidden_states.flatten(1, 3)  # [B, T, H, W, C] -> [B, THW, C]

        # 4. Timestep embeddings
        if timestep.ndim == 1:
            temb, embedded_timestep = self.time_embed(hidden_states, timestep)
        elif timestep.ndim == 5:
            expected_shape = (batch_size, 1, num_frames, 1, 1)
            if tuple(timestep.shape) != expected_shape:
                raise ValueError(f"Expected timestep to have shape [B, 1, T, 1, 1], but got {timestep.shape}")
            if self.p_t != 1:
                # The per-frame broadcast below relies on num_frames == post_patch_num_frames.
                raise ValueError(
                    f"Per-frame (5-D) timesteps require a temporal patch size of 1, but got p_t={self.p_t}"
                )
            timestep = timestep.flatten()
            temb, embedded_timestep = self.time_embed(hidden_states, timestep)
            # Broadcast each frame's embedding to every spatial token of that frame.
            temb, embedded_timestep = (
                x.view(batch_size, post_patch_num_frames, 1, 1, -1)
                .expand((-1, -1, post_patch_height, post_patch_width, -1))
                .flatten(1, 3)
                for x in (temb, embedded_timestep)
            )  # [BT, C] -> [B, T, 1, 1, C] -> [B, T, H, W, C] -> [B, THW, C]
        else:
            raise ValueError(f"Expected timestep to be a 1-D or 5-D tensor, but got {timestep.ndim} dimensions")

        # 5. Transformer blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                embedded_timestep=embedded_timestep,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                extra_pos_emb=extra_pos_emb,
                attention_mask=attention_mask,
            )

        # 6. Output norm & projection & unpatchify
        hidden_states = self.norm_out(hidden_states, embedded_timestep, temb)
        hidden_states = self.proj_out(hidden_states)
        # The projected per-token features are unpacked as (p_h, p_w, p_t, C) -- an order different from the
        # (T, H, W) token grid -- so unflatten the feature axis first, then the token axis:
        # [B, THW, p_h*p_w*p_t*C] -> [B, T', H', W', p_h, p_w, p_t, C]
        hidden_states = unflatten(hidden_states, 2, (self.p_h, self.p_w, self.p_t, -1))
        hidden_states = unflatten(hidden_states, 1, (post_patch_num_frames, post_patch_height, post_patch_width))
        # Move channels forward and pair every grid axis with its patch axis:
        # [B, T', H', W', p_h, p_w, p_t, C] -> [B, C, T', p_t, H', p_h, W', p_w]
        hidden_states = hidden_states.permute(0, 7, 1, 6, 2, 4, 3, 5)
        # Merge (W', p_w), (H', p_h), (T', p_t) -> [B, C, T, H, W]
        hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if not return_dict:
            return (hidden_states,)

        return Transformer2DModelOutput(sample=hidden_states)