Lumina2Transformer2DModel

A Diffusion Transformer model for image data, introduced in Lumina Image 2.0 by Alpha-VLLM.
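
A minimal loading sketch in the style of other diffusers model docs. The Hub repository id, the `subfolder`, and the `mindspore_dtype` keyword are assumptions here; substitute whatever checkpoint and dtype you actually use.

import mindspore as ms
from mindone.diffusers import Lumina2Transformer2DModel

# Load only the transformer component of an assumed Lumina Image 2.0 checkpoint.
transformer = Lumina2Transformer2DModel.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0",  # assumed Hub repository id
    subfolder="transformer",
    mindspore_dtype=ms.bfloat16,
)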

mindone.diffusers.Lumina2Transformer2DModel

Bases: ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin

PARAMETER DESCRIPTION
sample_size

The width of the latent images. This is fixed during training since it is used to learn a number of position embeddings.

TYPE: `int` DEFAULT: 128

patch_size

The size of each patch in the image. This parameter defines the resolution of patches fed into the model.

TYPE: `int`, *optional*, defaults to 2 DEFAULT: 2

in_channels

The number of input channels for the model. Typically, this matches the number of channels in the input images.

TYPE: `int`, *optional*, defaults to 16 DEFAULT: 16

hidden_size

The dimensionality of the hidden layers in the model. This parameter determines the width of the model's hidden representations.

TYPE: `int`, *optional*, defaults to 2304 DEFAULT: 2304

num_layers

The number of layers in the model. This defines the depth of the neural network.

TYPE: `int`, *optional*, defaults to 26 DEFAULT: 26

num_attention_heads

The number of attention heads in each attention layer. This parameter specifies how many separate attention mechanisms are used.

TYPE: `int`, *optional*, defaults to 24 DEFAULT: 24

num_kv_heads

The number of key-value heads in the attention mechanism, if different from the number of attention heads. If None, it defaults to num_attention_heads.

TYPE: `int`, *optional*, defaults to 8 DEFAULT: 8

multiple_of

A factor that the hidden size should be a multiple of. This can help optimize certain hardware configurations.

TYPE: `int`, *optional*, defaults to 256 DEFAULT: 256

ffn_dim_multiplier

A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on the model configuration.

TYPE: `float`, *optional* DEFAULT: None

norm_eps

A small value added to the denominator for numerical stability in normalization layers.

TYPE: `float`, *optional*, defaults to 1e-5 DEFAULT: 1e-05

scaling_factor

A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the overall scale of the model's operations.

TYPE: `float`, *optional*, defaults to 1.0 DEFAULT: 1.0
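
For quick experimentation, the model can also be instantiated directly with a reduced configuration. The values below are illustrative assumptions, not those of any released checkpoint; the per-head dimension `hidden_size // num_attention_heads` is assumed to need to equal `sum(axes_dim_rope)`, as it does for the defaults (2304 // 24 = 32 + 32 + 32 = 96).

from mindone.diffusers import Lumina2Transformer2DModel

# A deliberately small configuration (hypothetical values) for smoke tests.
tiny = Lumina2Transformer2DModel(
    sample_size=32,
    patch_size=2,
    in_channels=16,
    hidden_size=256,              # 256 // 4 heads = 64 per-head dim
    num_layers=2,
    num_refiner_layers=1,
    num_attention_heads=4,
    num_kv_heads=2,
    axes_dim_rope=(32, 16, 16),   # sums to the 64 per-head dim above
    cap_feat_dim=1024,
)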

Source code in mindone/diffusers/models/transformers/transformer_lumina2.py
class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    r"""
    Lumina2NextDiT: Diffusion model with a Transformer backbone.

    Parameters:
        sample_size (`int`): The width of the latent images. This is fixed during training since
            it is used to learn a number of position embeddings.
        patch_size (`int`, *optional*, defaults to 2):
            The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
        in_channels (`int`, *optional*, defaults to 16):
            The number of input channels for the model. Typically, this matches the number of channels in the input
            images.
        hidden_size (`int`, *optional*, defaults to 2304):
            The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
            hidden representations.
        num_layers (`int`, *optional*, defaults to 26):
            The number of layers in the model. This defines the depth of the neural network.
        num_attention_heads (`int`, *optional*, defaults to 24):
            The number of attention heads in each attention layer. This parameter specifies how many separate attention
            mechanisms are used.
        num_kv_heads (`int`, *optional*, defaults to 8):
            The number of key-value heads in the attention mechanism, if different from the number of attention heads.
            If None, it defaults to num_attention_heads.
        multiple_of (`int`, *optional*, defaults to 256):
            A factor that the hidden size should be a multiple of. This can help optimize certain hardware
            configurations.
        ffn_dim_multiplier (`float`, *optional*):
            A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
            the model configuration.
        norm_eps (`float`, *optional*, defaults to 1e-5):
            A small value added to the denominator for numerical stability in normalization layers.
        scaling_factor (`float`, *optional*, defaults to 1.0):
            A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
            overall scale of the model's operations.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["Lumina2TransformerBlock"]
    _skip_layerwise_casting_patterns = ["x_embedder", "norm"]

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: int = 2,
        in_channels: int = 16,
        out_channels: Optional[int] = None,
        hidden_size: int = 2304,
        num_layers: int = 26,
        num_refiner_layers: int = 2,
        num_attention_heads: int = 24,
        num_kv_heads: int = 8,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        scaling_factor: float = 1.0,
        axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
        axes_lens: Tuple[int, int, int] = (300, 512, 512),
        cap_feat_dim: int = 1024,
    ) -> None:
        super().__init__()
        self.out_channels = out_channels or in_channels

        # 1. Positional, patch & conditional embeddings
        self.rope_embedder = Lumina2RotaryPosEmbed(
            theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, patch_size=patch_size
        )

        self.x_embedder = nn.Dense(in_channels=patch_size * patch_size * in_channels, out_channels=hidden_size)

        self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
            hidden_size=hidden_size, cap_feat_dim=cap_feat_dim, norm_eps=norm_eps
        )

        # 2. Noise and context refinement blocks
        self.noise_refiner = nn.CellList(
            [
                Lumina2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=True,
                )
                for _ in range(num_refiner_layers)
            ]
        )

        self.context_refiner = nn.CellList(
            [
                Lumina2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=False,
                )
                for _ in range(num_refiner_layers)
            ]
        )

        # 3. Transformer blocks
        self.layers = nn.CellList(
            [
                Lumina2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=True,
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output norm & projection
        self.norm_out = LuminaLayerNormContinuous(
            embedding_dim=hidden_size,
            conditioning_embedding_dim=min(hidden_size, 1024),
            elementwise_affine=False,
            eps=1e-6,
            bias=True,
            out_dim=patch_size * patch_size * self.out_channels,
        )

        self.gradient_checkpointing = False

    def construct(
        self,
        hidden_states: ms.Tensor,
        timestep: ms.Tensor,
        encoder_hidden_states: ms.Tensor,
        encoder_attention_mask: ms.Tensor,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = False,
    ) -> Union[ms.Tensor, Transformer2DModelOutput]:
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()

        if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
            # weight the lora layers by setting `lora_scale` for each PEFT layer here
            # and remove `lora_scale` from each PEFT layer at the end.
            # scale_lora_layers & unscale_lora_layers maybe contains some operation forbidden in graph mode
            raise RuntimeError(
                f"You are trying to set scaling of lora layer by passing {attention_kwargs['scale']=}. "
                f"However it's not allowed in on-the-fly model forwarding. "
                f"Please manually call `scale_lora_layers(model, lora_scale)` before model forwarding and "
                f"`unscale_lora_layers(model, lora_scale)` after model forwarding. "
                f"For example, it can be done in a pipeline call like `StableDiffusionPipeline.__call__`."
            )

        # 1. Condition, positional & patch embedding
        batch_size, _, height, width = hidden_states.shape

        temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)

        (
            hidden_states,
            context_rotary_emb,
            noise_rotary_emb,
            rotary_emb,
            encoder_seq_lengths,
            seq_lengths,
        ) = self.rope_embedder(hidden_states, encoder_attention_mask)

        hidden_states = self.x_embedder(hidden_states)

        # 2. Context & noise refinement
        for layer in self.context_refiner:
            encoder_hidden_states = layer(encoder_hidden_states, encoder_attention_mask, context_rotary_emb)

        for layer in self.noise_refiner:
            hidden_states = layer(hidden_states, None, noise_rotary_emb, temb)

        # 3. Joint Transformer blocks
        # TODO: the type of seq_lengths is tensor
        max_seq_len = mint.max(seq_lengths).item()
        # TODO: `set` may not be supported in graph mode
        use_mask = mint.unique(seq_lengths).shape[0] > 1

        attention_mask = hidden_states.new_zeros((batch_size, max_seq_len), dtype=ms.bool_)
        joint_hidden_states = hidden_states.new_zeros((batch_size, max_seq_len, self.config["hidden_size"]))
        attention_mask_tmp = []
        joint_hidden_states_tmp = []
        # TODO: Rewrite it since implement above might call ops.ScatterNdUpdate which is super slow and cause RuntimeError!
        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
            seq_len = seq_len.item()
            attention_mask_tmp.append(
                mint.cat(
                    (
                        mint.full(attention_mask[i : i + 1, :seq_len].shape, True, dtype=ms.bool_),
                        mint.split(
                            attention_mask[i : i + 1], (seq_len, attention_mask[i : i + 1].shape[1] - seq_len), dim=1
                        )[1],
                    ),
                    dim=1,
                )
            )
            joint_hidden_states_tmp.append(
                mint.cat(
                    (
                        encoder_hidden_states[i, :encoder_seq_len].unsqueeze(0),
                        hidden_states[i].unsqueeze(0),
                        joint_hidden_states[i : i + 1, seq_len:],
                    ),
                    dim=1,
                )
            )

        attention_mask = mint.cat(attention_mask_tmp, dim=0)
        joint_hidden_states = mint.cat(joint_hidden_states_tmp, dim=0)
        hidden_states = joint_hidden_states

        for layer in self.layers:
            if use_mask:
                hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
            else:
                hidden_states = layer(hidden_states, None, rotary_emb, temb)

        # 4. Output norm & projection
        hidden_states = self.norm_out(hidden_states, temb)

        # 5. Unpatchify
        p = self.config["patch_size"]
        output = []
        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
            output.append(
                hidden_states[i][encoder_seq_len:seq_len]
                .view(height // p, width // p, p, p, self.out_channels)
                .permute(4, 0, 2, 1, 3)
                .flatten(3, 4)
                .flatten(1, 2)
            )
        output = mint.stack(output, dim=0)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
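
A forward-pass sketch with random inputs, using the small `tiny` model from the configuration sketch above; the shapes follow the signature of `construct`, and the caption feature width must match the model's `cap_feat_dim` (1024 for `tiny`).

import mindspore as ms
from mindspore import ops

batch, channels, height, width = 1, 16, 64, 64   # latent (not pixel) resolution
seq_len, cap_feat_dim = 32, 1024

hidden_states = ops.randn(batch, channels, height, width, dtype=ms.float32)
timestep = ms.Tensor([0.5], dtype=ms.float32)
encoder_hidden_states = ops.randn(batch, seq_len, cap_feat_dim, dtype=ms.float32)
encoder_attention_mask = ops.ones((batch, seq_len)).astype(ms.bool_)

(sample,) = tiny(
    hidden_states,
    timestep,
    encoder_hidden_states,
    encoder_attention_mask,
)  # return_dict defaults to False, so a one-element tuple is returned
print(sample.shape)  # (batch, out_channels, height, width)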

mindone.diffusers.models.modeling_outputs.Transformer2DModelOutput dataclass

Bases: BaseOutput

The output of [Transformer2DModel].

PARAMETER DESCRIPTION
sample

The hidden states output conditioned on the encoder_hidden_states input. If discrete, returns probability distributions for the unnoised latent pixels.

TYPE: `ms.Tensor` of shape `(batch_size, num_channels, height, width)`, or `(batch_size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete

Source code in mindone/diffusers/models/modeling_outputs.py
@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    The output of [`Transformer2DModel`].

    Args:
        sample (`ms.Tensor` of shape `(batch_size, num_channels, height, width)` or
        `(batch_size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    sample: "ms.Tensor"  # noqa: F821
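
When `return_dict=True` is passed to the model's `construct`, this dataclass wraps the result; a brief sketch reusing the inputs from the forward-pass example above:

out = tiny(
    hidden_states,
    timestep,
    encoder_hidden_states,
    encoder_attention_mask,
    return_dict=True,
)
print(out.sample.shape)  # same tensor as the tuple element returned by default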