LuminaNextDiT2DModel

A next-generation Diffusion Transformer model for 2D data from Lumina-T2X.

mindone.diffusers.LuminaNextDiT2DModel

Bases: ModelMixin, ConfigMixin

Inherits from ModelMixin and ConfigMixin so the model is compatible with diffusers samplers such as StableDiffusionPipeline.

PARAMETER DESCRIPTION
sample_size

The width of the latent images. This is fixed during training since it is used to learn a number of position embeddings.

TYPE: `int` DEFAULT: 128

patch_size

The size of each patch in the image. This parameter defines the resolution of patches fed into the model.

TYPE: `int`, *optional*, defaults to 2 DEFAULT: 2

in_channels

The number of input channels for the model. Typically, this matches the number of channels in the input images.

TYPE: `int`, *optional*, defaults to 4 DEFAULT: 4

hidden_size

The dimensionality of the hidden layers in the model. This parameter determines the width of the model's hidden representations.

TYPE: `int`, *optional*, defaults to 2304 DEFAULT: 2304

num_layers

The number of layers in the model. This defines the depth of the neural network.

TYPE: `int`, *optional*, defaults to 32 DEFAULT: 32

num_attention_heads

The number of attention heads in each attention layer. This parameter specifies how many separate attention mechanisms are used.

TYPE: `int`, *optional*, defaults to 32 DEFAULT: 32

num_kv_heads

The number of key-value heads in the attention mechanism, if different from the number of attention heads. If None, it defaults to num_attention_heads.

TYPE: `int`, *optional*, defaults to None DEFAULT: None

multiple_of

A factor that the hidden size should be a multiple of. This can help optimize certain hardware configurations.

TYPE: `int`, *optional*, defaults to 256 DEFAULT: 256

ffn_dim_multiplier

A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on the model configuration.

TYPE: `float`, *optional* DEFAULT: None

norm_eps

A small value added to the denominator for numerical stability in normalization layers.

TYPE: `float`, *optional*, defaults to 1e-5 DEFAULT: 1e-05

learn_sigma

Whether the model should learn sigma (the predicted variance). When enabled, the number of output channels is doubled so the model predicts both the noise and its variance.

TYPE: `bool`, *optional*, defaults to True DEFAULT: True

qk_norm

Indicates if the queries and keys in the attention mechanism should be normalized.

TYPE: `bool`, *optional*, defaults to True DEFAULT: True

cross_attention_dim

The dimensionality of the text embeddings. This parameter defines the size of the text representations used in the model.

TYPE: `int`, *optional*, defaults to 2048 DEFAULT: 2048

scaling_factor

A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the overall scale of the model's operations.

TYPE: `float`, *optional*, defaults to 1.0 DEFAULT: 1.0
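
For orientation, the following is a minimal sketch of constructing the model directly. The reduced hidden_size, num_layers, and num_attention_heads are arbitrary illustration values rather than the published defaults; note that hidden_size // num_attention_heads must stay divisible by 4 because of the 2D RoPE check in __init__.

from mindone.diffusers import LuminaNextDiT2DModel

# Minimal sketch with a deliberately small configuration (illustration values,
# not the published defaults).
model = LuminaNextDiT2DModel(
    sample_size=128,
    patch_size=2,
    in_channels=4,
    hidden_size=512,          # default is 2304; head_dim = 512 // 8 = 64 is divisible by 4
    num_layers=2,             # default is 32
    num_attention_heads=8,    # default is 32
    cross_attention_dim=2048,
)

# learn_sigma defaults to True, which doubles the output channels relative to in_channels.
print(model.out_channels)  # 8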

Source code in mindone/diffusers/models/transformers/lumina_nextdit2d.py
class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
    """
    LuminaNextDiT: Diffusion model with a Transformer backbone.

    Inherits from ModelMixin and ConfigMixin so the model is compatible with diffusers samplers such as StableDiffusionPipeline.

    Parameters:
        sample_size (`int`): The width of the latent images. This is fixed during training since
            it is used to learn a number of position embeddings.
        patch_size (`int`, *optional*, defaults to 2):
            The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
        in_channels (`int`, *optional*, defaults to 4):
            The number of input channels for the model. Typically, this matches the number of channels in the input
            images.
        hidden_size (`int`, *optional*, defaults to 2304):
            The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
            hidden representations.
        num_layers (`int`, *optional*, defaults to 32):
            The number of layers in the model. This defines the depth of the neural network.
        num_attention_heads (`int`, *optional*, defaults to 32):
            The number of attention heads in each attention layer. This parameter specifies how many separate attention
            mechanisms are used.
        num_kv_heads (`int`, *optional*, defaults to None):
            The number of key-value heads in the attention mechanism, if different from the number of attention heads.
            If None, it defaults to num_attention_heads.
        multiple_of (`int`, *optional*, defaults to 256):
            A factor that the hidden size should be a multiple of. This can help optimize certain hardware
            configurations.
        ffn_dim_multiplier (`float`, *optional*):
            A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
            the model configuration.
        norm_eps (`float`, *optional*, defaults to 1e-5):
            A small value added to the denominator for numerical stability in normalization layers.
        learn_sigma (`bool`, *optional*, defaults to True):
            Whether the model should learn sigma (the predicted variance). When enabled, the number of output
            channels is doubled so the model predicts both the noise and its variance.
        qk_norm (`bool`, *optional*, defaults to True):
            Indicates if the queries and keys in the attention mechanism should be normalized.
        cross_attention_dim (`int`, *optional*, defaults to 2048):
            The dimensionality of the text embeddings. This parameter defines the size of the text representations used
            in the model.
        scaling_factor (`float`, *optional*, defaults to 1.0):
            A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
            overall scale of the model's operations.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: Optional[int] = 2,
        in_channels: Optional[int] = 4,
        hidden_size: Optional[int] = 2304,
        num_layers: Optional[int] = 32,
        num_attention_heads: Optional[int] = 32,
        num_kv_heads: Optional[int] = None,
        multiple_of: Optional[int] = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: Optional[float] = 1e-5,
        learn_sigma: Optional[bool] = True,
        qk_norm: Optional[bool] = True,
        cross_attention_dim: Optional[int] = 2048,
        scaling_factor: Optional[float] = 1.0,
    ) -> None:
        super().__init__()
        self.sample_size = sample_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        self.scaling_factor = scaling_factor

        self.patch_embedder = LuminaPatchEmbed(
            patch_size=patch_size, in_channels=in_channels, embed_dim=hidden_size, bias=True
        )

        self.pad_token = ms.Parameter(ops.zeros((hidden_size,)))

        self.time_caption_embed = LuminaCombinedTimestepCaptionEmbedding(
            hidden_size=min(hidden_size, 1024), cross_attention_dim=cross_attention_dim
        )

        self.layers = nn.CellList(
            [
                LuminaNextDiTBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    cross_attention_dim,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_out = LuminaLayerNormContinuous(
            embedding_dim=hidden_size,
            conditioning_embedding_dim=min(hidden_size, 1024),
            elementwise_affine=False,
            eps=1e-6,
            bias=True,
            out_dim=patch_size * patch_size * self.out_channels,
        )
        # self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels)

        assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"

    def construct(
        self,
        hidden_states: ms.Tensor,
        timestep: ms.Tensor,
        encoder_hidden_states: ms.Tensor,
        encoder_mask: ms.Tensor,
        image_rotary_emb: ms.Tensor,
        cross_attention_kwargs: Dict[str, Any] = None,
        return_dict=False,
    ) -> ms.Tensor:
        """
        Forward pass of LuminaNextDiT.

        Parameters:
            hidden_states (ms.Tensor): Input tensor of shape (N, C, H, W).
            timestep (ms.Tensor): Tensor of diffusion timesteps of shape (N,).
            encoder_hidden_states (ms.Tensor): Tensor of caption features of shape (N, D).
            encoder_mask (ms.Tensor): Tensor of caption masks of shape (N, L).
        """
        hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb)

        temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask)

        encoder_mask = encoder_mask.bool()
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                mask,
                image_rotary_emb,
                encoder_hidden_states,
                encoder_mask,
                temb=temb,
                cross_attention_kwargs=cross_attention_kwargs,
            )

        hidden_states = self.norm_out(hidden_states, temb)

        # unpatchify
        height_tokens = width_tokens = self.patch_size
        height, width = img_size[0]
        batch_size = hidden_states.shape[0]
        sequence_length = (height // height_tokens) * (width // width_tokens)
        hidden_states = hidden_states[:, :sequence_length].view(
            batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels
        )
        output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(start_dim=4, end_dim=5).flatten(start_dim=2, end_dim=3)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)

mindone.diffusers.LuminaNextDiT2DModel.construct(hidden_states, timestep, encoder_hidden_states, encoder_mask, image_rotary_emb, cross_attention_kwargs=None, return_dict=False)

Forward pass of LuminaNextDiT.

PARAMETER DESCRIPTION
hidden_states

Input tensor of shape (N, C, H, W).

TYPE: Tensor

timestep

Tensor of diffusion timesteps of shape (N,).

TYPE: Tensor

encoder_hidden_states

Tensor of caption features of shape (N, D).

TYPE: Tensor

encoder_mask

Tensor of caption masks of shape (N, L).

TYPE: Tensor
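
By default, construct returns a one-element tuple containing the output sample; pass return_dict=True to get a Transformer2DModelOutput instead. The sketch below only builds example inputs with plausible shapes: the caption features are shown as a padded sequence of length L with feature size equal to cross_attention_dim, which is how the Lumina pipeline supplies them, and L=77 plus the dtypes are assumptions. image_rotary_emb must be produced by the pipeline's 2D rotary position-embedding helper for the target resolution and is not constructed here.

import mindspore as ms
from mindspore import ops

# Assumed sizes: batch N=1, latent 4x128x128, caption length L=77,
# caption feature size D=2048 (the default cross_attention_dim).
hidden_states = ops.randn((1, 4, 128, 128), dtype=ms.float32)       # (N, C, H, W)
timestep = ops.ones((1,), dtype=ms.float32)                         # (N,)
encoder_hidden_states = ops.randn((1, 77, 2048), dtype=ms.float32)  # padded caption features
encoder_mask = ops.ones((1, 77), dtype=ms.float32)                  # (N, L), 1 marks valid tokens

# With a model instance and a matching image_rotary_emb in hand, the call looks like:
# (sample,) = model(hidden_states, timestep, encoder_hidden_states, encoder_mask, image_rotary_emb)
# sample.shape -> (N, out_channels, H, W)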

Source code in mindone/diffusers/models/transformers/lumina_nextdit2d.py
def construct(
    self,
    hidden_states: ms.Tensor,
    timestep: ms.Tensor,
    encoder_hidden_states: ms.Tensor,
    encoder_mask: ms.Tensor,
    image_rotary_emb: ms.Tensor,
    cross_attention_kwargs: Dict[str, Any] = None,
    return_dict=False,
) -> ms.Tensor:
    """
    Forward pass of LuminaNextDiT.

    Parameters:
        hidden_states (ms.Tensor): Input tensor of shape (N, C, H, W).
        timestep (ms.Tensor): Tensor of diffusion timesteps of shape (N,).
        encoder_hidden_states (ms.Tensor): Tensor of caption features of shape (N, D).
        encoder_mask (ms.Tensor): Tensor of caption masks of shape (N, L).
    """
    hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb)

    temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask)

    encoder_mask = encoder_mask.bool()
    for layer in self.layers:
        hidden_states = layer(
            hidden_states,
            mask,
            image_rotary_emb,
            encoder_hidden_states,
            encoder_mask,
            temb=temb,
            cross_attention_kwargs=cross_attention_kwargs,
        )

    hidden_states = self.norm_out(hidden_states, temb)

    # unpatchify
    height_tokens = width_tokens = self.patch_size
    height, width = img_size[0]
    batch_size = hidden_states.shape[0]
    sequence_length = (height // height_tokens) * (width // width_tokens)
    hidden_states = hidden_states[:, :sequence_length].view(
        batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels
    )
    output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(start_dim=4, end_dim=5).flatten(start_dim=2, end_dim=3)

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)
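
To make the unpatchify step at the end of construct concrete, here is a small self-contained sketch with assumed toy sizes that reproduces the same view/permute/flatten sequence and recovers a (B, C, H, W) image from patch tokens.

from mindspore import ops

# Assumed toy sizes: patch_size p = 2, out_channels C = 8, latent height = width = 16.
B, p, C, H, W = 1, 2, 8, 16, 16
seq_len = (H // p) * (W // p)                # 64 patch tokens
tokens = ops.randn((B, seq_len, p * p * C))  # (B, L, p*p*C), as produced by norm_out

x = tokens[:, :seq_len].view(B, H // p, W // p, p, p, C)
image = x.permute(0, 5, 1, 3, 2, 4).flatten(start_dim=4, end_dim=5).flatten(start_dim=2, end_dim=3)
print(image.shape)  # (1, 8, 16, 16) -> (B, C, H, W)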