HunyuanDiT2DModel

A Diffusion Transformer model for 2D data from Hunyuan-DiT.

mindone.diffusers.HunyuanDiT2DModel

Bases: ModelMixin, ConfigMixin

Inherits from ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.

PARAMETER DESCRIPTION
num_attention_heads

The number of heads to use for multi-head attention.

TYPE: `int`, *optional*, defaults to 16 DEFAULT: 16

attention_head_dim

The number of channels in each head.

TYPE: `int`, *optional*, defaults to 88 DEFAULT: 88

in_channels

The number of channels in the input and output (specify if the input is continuous).

TYPE: `int`, *optional* DEFAULT: None

patch_size

The size of the patch to use for the input.

TYPE: `int`, *optional* DEFAULT: None

activation_fn

Activation function to use in feed-forward.

TYPE: `str`, *optional*, defaults to `"gelu-approximate"` DEFAULT: 'gelu-approximate'

sample_size

The width of the latent images. This is fixed during training since it is used to learn a number of position embeddings.

TYPE: `int`, *optional* DEFAULT: 32

dropout

The dropout probability to use. This parameter is listed in the docstring but does not appear in the `__init__` signature.

TYPE: `float`, *optional*, defaults to 0.0

cross_attention_dim

The number of dimensions in the CLIP text embedding.

TYPE: `int`, *optional* DEFAULT: 1024

hidden_size

The size of hidden layer in the conditioning embedding layers.

TYPE: `int`, *optional* DEFAULT: 1152

num_layers

The number of layers of Transformer blocks to use.

TYPE: `int`, *optional*, defaults to 28 DEFAULT: 28

mlp_ratio

The ratio of the hidden layer size to the input size.

TYPE: `float`, *optional*, defaults to 4.0 DEFAULT: 4.0

learn_sigma

Whether to predict variance.

TYPE: `bool`, *optional*, defaults to `True` DEFAULT: True

cross_attention_dim_t5

The number of dimensions in the T5 text embedding.

TYPE: `int`, *optional* DEFAULT: 2048

pooled_projection_dim

The size of the pooled projection.

TYPE: `int`, *optional* DEFAULT: 1024

text_len

The length of the CLIP text embedding.

TYPE: `int`, *optional* DEFAULT: 77

text_len_t5

The length of the T5 text embedding.

TYPE: `int`, *optional* DEFAULT: 256

use_style_cond_and_image_meta_size

Whether or not to use style condition and image meta size. `True` for versions <= 1.1, `False` for versions >= 1.2.

TYPE: `bool`, *optional* DEFAULT: True
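
Several internal sizes follow directly from these parameters. The sketch below is plain Python arithmetic that mirrors the expressions in `__init__`; it is not part of the library. The helper name `derived_dims` is hypothetical, and `in_channels=4` and `patch_size=2` are assumed values chosen for illustration, since both default to `None` in the constructor.

```python
def derived_dims(
    num_attention_heads=16,
    attention_head_dim=88,
    in_channels=4,   # assumed for illustration; the constructor default is None
    patch_size=2,    # assumed for illustration; the constructor default is None
    num_layers=28,
    mlp_ratio=4.0,
    learn_sigma=True,
    text_len=77,
    text_len_t5=256,
):
    """Recompute the sizes that __init__ derives from the config values."""
    inner_dim = num_attention_heads * attention_head_dim
    out_channels = in_channels * 2 if learn_sigma else in_channels
    return {
        "inner_dim": inner_dim,                        # width of each transformer block
        "ff_inner_dim": int(inner_dim * mlp_ratio),    # feed-forward hidden width
        "out_channels": out_channels,                  # doubled when variance is predicted
        "proj_out_features": patch_size * patch_size * out_channels,
        "text_tokens": text_len + text_len_t5,         # rows of text_embedding_padding
        "skip_layers": [i for i in range(num_layers) if i > num_layers // 2],
    }


dims = derived_dims()
print(dims["inner_dim"], dims["proj_out_features"], dims["text_tokens"])
```

With the defaults, `inner_dim` is 1408 while `hidden_size` defaults to 1152. Because `pos_embed` emits `hidden_size` features and the blocks are built with `dim=inner_dim`, the two values appear to need to match in a working configuration.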

Source code in mindone/diffusers/models/transformers/hunyuan_transformer_2d.py
class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
    """
    HunYuanDiT: Diffusion model with a Transformer backbone.

    Inherits from ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88):
            The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        patch_size (`int`, *optional*):
            The size of the patch to use for the input.
        activation_fn (`str`, *optional*, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        sample_size (`int`, *optional*):
            The width of the latent images. This is fixed during training since it is used to learn a number of
            position embeddings.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        cross_attention_dim (`int`, *optional*):
            The number of dimensions in the CLIP text embedding.
        hidden_size (`int`, *optional*):
            The size of hidden layer in the conditioning embedding layers.
        num_layers (`int`, *optional*, defaults to 28):
            The number of layers of Transformer blocks to use.
        mlp_ratio (`float`, *optional*, defaults to 4.0):
            The ratio of the hidden layer size to the input size.
        learn_sigma (`bool`, *optional*, defaults to `True`):
            Whether to predict variance.
        cross_attention_dim_t5 (`int`, *optional*):
            The number of dimensions in the T5 text embedding.
        pooled_projection_dim (`int`, *optional*):
            The size of the pooled projection.
        text_len (`int`, *optional*):
            The length of the CLIP text embedding.
        text_len_t5 (`int`, *optional*):
            The length of the T5 text embedding.
        use_style_cond_and_image_meta_size (`bool`, *optional*):
            Whether or not to use style condition and image meta size. True for versions <= 1.1, False for
            versions >= 1.2.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "gelu-approximate",
        sample_size=32,
        hidden_size=1152,
        num_layers: int = 28,
        mlp_ratio: float = 4.0,
        learn_sigma: bool = True,
        cross_attention_dim: int = 1024,
        norm_type: str = "layer_norm",
        cross_attention_dim_t5: int = 2048,
        pooled_projection_dim: int = 1024,
        text_len: int = 77,
        text_len_t5: int = 256,
        use_style_cond_and_image_meta_size: bool = True,
    ):
        super().__init__()
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.num_heads = num_attention_heads
        self.inner_dim = num_attention_heads * attention_head_dim

        self.text_embedder = PixArtAlphaTextProjection(
            in_features=cross_attention_dim_t5,
            hidden_size=cross_attention_dim_t5 * 4,
            out_features=cross_attention_dim,
            act_fn="silu_fp32",
        )

        self.text_embedding_padding = ms.Parameter(
            ops.randn(text_len + text_len_t5, cross_attention_dim, dtype=ms.float32),
            name="text_embedding_padding",
        )

        self.pos_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            in_channels=in_channels,
            embed_dim=hidden_size,
            patch_size=patch_size,
            pos_embed_type=None,
        )

        self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
            hidden_size,
            pooled_projection_dim=pooled_projection_dim,
            seq_len=text_len_t5,
            cross_attention_dim=cross_attention_dim_t5,
            use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size,
        )

        # HunyuanDiT Blocks
        self.blocks = nn.CellList(
            [
                HunyuanDiTBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    activation_fn=activation_fn,
                    ff_inner_dim=int(self.inner_dim * mlp_ratio),
                    cross_attention_dim=cross_attention_dim,
                    qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
                    skip=layer > num_layers // 2,
                )
                for layer in range(num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Dense(self.inner_dim, patch_size * patch_size * self.out_channels, has_bias=True)

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: nn.Cell, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.name_cells().items():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.name_cells().items():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: nn.Cell, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.name_cells().items():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.name_cells().items():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(HunyuanAttnProcessor2_0())

    def construct(
        self,
        hidden_states,
        timestep,
        encoder_hidden_states=None,
        text_embedding_mask=None,
        encoder_hidden_states_t5=None,
        text_embedding_mask_t5=None,
        image_meta_size=None,
        style=None,
        image_rotary_emb=None,
        controlnet_block_samples=None,
        return_dict=False,
    ):
        """
        The [`HunyuanDiT2DModel`] forward method.

        Args:
            hidden_states (`ms.Tensor` of shape `(batch size, dim, height, width)`):
                The input tensor.
            timestep (`ms.Tensor`, *optional*):
                Used to indicate the denoising step.
            encoder_hidden_states (`ms.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for the cross-attention layers. This is the output of `BertModel`.
            text_embedding_mask (`ms.Tensor`, *optional*):
                An attention mask of shape `(batch, key_tokens)` applied to `encoder_hidden_states`. This is the
                output of `BertModel`.
            encoder_hidden_states_t5 (`ms.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for the cross-attention layers. This is the output of the T5 text encoder.
            text_embedding_mask_t5 (`ms.Tensor`, *optional*):
                An attention mask of shape `(batch, key_tokens)` applied to `encoder_hidden_states_t5`. This is
                the output of the T5 text encoder.
            image_meta_size (`ms.Tensor`, *optional*):
                Conditional embedding indicating the image sizes.
            style (`ms.Tensor`, *optional*):
                Conditional embedding indicating the style.
            image_rotary_emb (`ms.Tensor`, *optional*):
                The image rotary embeddings to apply to the query and key tensors during attention calculation.
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether to return a `Transformer2DModelOutput` instead of a plain tuple.
        """

        height, width = hidden_states.shape[-2:]

        hidden_states = self.pos_embed(hidden_states)

        temb = self.time_extra_emb(
            timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype
        )  # [B, D]

        # text projection
        batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
        encoder_hidden_states_t5 = self.text_embedder(
            encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
        )
        encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1)

        encoder_hidden_states = ops.cat([encoder_hidden_states, encoder_hidden_states_t5], axis=1)
        text_embedding_mask = ops.cat([text_embedding_mask, text_embedding_mask_t5], axis=-1)
        text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()

        encoder_hidden_states = ops.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)

        skips = []
        for layer, block in enumerate(self.blocks):
            if layer > self.config["num_layers"] // 2:
                if controlnet_block_samples is not None:
                    skip = skips[-1] + controlnet_block_samples[-1]
                    controlnet_block_samples = controlnet_block_samples[:-1]
                else:
                    skip = skips[-1]
                skips = skips[:-1]

                hidden_states = block(
                    hidden_states,
                    temb=temb,
                    encoder_hidden_states=encoder_hidden_states,
                    image_rotary_emb=image_rotary_emb,
                    skip=skip,
                )  # (N, L, D)
            else:
                hidden_states = block(
                    hidden_states,
                    temb=temb,
                    encoder_hidden_states=encoder_hidden_states,
                    image_rotary_emb=image_rotary_emb,
                )  # (N, L, D)

            if layer < (self.config["num_layers"] // 2 - 1):
                skips.append(hidden_states)

        if controlnet_block_samples is not None and len(controlnet_block_samples) != 0:
            raise ValueError("The number of controls is not equal to the number of skip connections.")

        # final layer
        hidden_states = self.norm_out(hidden_states, temb.to(ms.float32))
        hidden_states = self.proj_out(hidden_states)
        # (N, L, patch_size ** 2 * out_channels)

        # unpatchify: (N, out_channels, H, W)
        patch_size = self.pos_embed.patch_size
        height = height // patch_size
        width = width // patch_size

        hidden_states = hidden_states.reshape(
            hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels
        )
        # hidden_states = ops.einsum("nhwpqc->nchpwq", hidden_states)
        hidden_states = hidden_states.transpose(0, 5, 1, 3, 2, 4)
        output = hidden_states.reshape(
            hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size
        )
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        Sets the attention processor to use [feed forward
        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # By default chunk size is 1
        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(module: nn.Cell, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.name_cells().values():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.name_cells().values():
            fn_recursive_feed_forward(module, chunk_size, dim)

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
    def disable_forward_chunking(self):
        def fn_recursive_feed_forward(module: nn.Cell, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.name_cells().values():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.name_cells().values():
            fn_recursive_feed_forward(module, None, 0)

mindone.diffusers.HunyuanDiT2DModel.attn_processors: Dict[str, AttentionProcessor] property

RETURNS DESCRIPTION
Dict[str, AttentionProcessor]

`dict` of attention processors: A dictionary containing all attention processors used in the model, indexed by weight name.

mindone.diffusers.HunyuanDiT2DModel.construct(hidden_states, timestep, encoder_hidden_states=None, text_embedding_mask=None, encoder_hidden_states_t5=None, text_embedding_mask_t5=None, image_meta_size=None, style=None, image_rotary_emb=None, controlnet_block_samples=None, return_dict=False)

The [HunyuanDiT2DModel] forward method.

PARAMETER DESCRIPTION
hidden_states

The input tensor.

TYPE: `ms.Tensor` of shape `(batch size, dim, height, width)`

timestep

Used to indicate the denoising step.

TYPE: `ms.Tensor`, *optional*

encoder_hidden_states

Conditional embeddings for the cross-attention layers. This is the output of `BertModel`.

TYPE: `ms.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional* DEFAULT: None

text_embedding_mask

An attention mask of shape `(batch, key_tokens)` applied to `encoder_hidden_states`. This is the output of `BertModel`.

TYPE: `ms.Tensor` DEFAULT: None

encoder_hidden_states_t5

Conditional embeddings for the cross-attention layers. This is the output of the T5 text encoder.

TYPE: `ms.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional* DEFAULT: None

text_embedding_mask_t5

An attention mask of shape `(batch, key_tokens)` applied to `encoder_hidden_states_t5`. This is the output of the T5 text encoder.

TYPE: `ms.Tensor` DEFAULT: None

image_meta_size

Conditional embedding indicating the image sizes.

TYPE: `ms.Tensor` DEFAULT: None

style

Conditional embedding indicating the style.

TYPE: `ms.Tensor` DEFAULT: None

image_rotary_emb

The image rotary embeddings to apply to the query and key tensors during attention calculation.

TYPE: `ms.Tensor` DEFAULT: None

return_dict

Whether to return a `Transformer2DModelOutput` instead of a plain tuple.

TYPE: `bool` DEFAULT: False
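
The forward pass uses U-Net-style long skip connections between early and late blocks. The sketch below reproduces only the bookkeeping of the `skips` list from `construct`, using layer indices in place of tensors. The function name `skip_schedule` is hypothetical, and the sketch assumes an even `num_layers` such as the default 28.

```python
def skip_schedule(num_layers):
    """Return (consumer_layer, producer_layer) pairs and any leftover stack."""
    stack = []   # stands in for the `skips` list of hidden states
    pairs = []
    for layer in range(num_layers):
        # Late blocks consume the most recently stored hidden state.
        if layer > num_layers // 2:
            pairs.append((layer, stack.pop()))
        # Early blocks store their output for a later block.
        if layer < num_layers // 2 - 1:
            stack.append(layer)
    return pairs, stack


pairs, leftover = skip_schedule(28)
print(pairs[0], pairs[-1], len(pairs), leftover)
```

For 28 layers, blocks 0 to 12 store their outputs, blocks 13 and 14 neither store nor consume, and blocks 15 to 27 consume the stored outputs in reverse order. Since `controlnet_block_samples` is consumed one entry per skip layer and any remainder raises a `ValueError`, it appears to expect exactly one sample per consuming block.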

Source code in mindone/diffusers/models/transformers/hunyuan_transformer_2d.py
def construct(
    self,
    hidden_states,
    timestep,
    encoder_hidden_states=None,
    text_embedding_mask=None,
    encoder_hidden_states_t5=None,
    text_embedding_mask_t5=None,
    image_meta_size=None,
    style=None,
    image_rotary_emb=None,
    controlnet_block_samples=None,
    return_dict=False,
):
    """
    The [`HunyuanDiT2DModel`] forward method.

    Args:
        hidden_states (`ms.Tensor` of shape `(batch size, dim, height, width)`):
            The input tensor.
        timestep (`ms.Tensor`, *optional*):
            Used to indicate the denoising step.
        encoder_hidden_states (`ms.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
            Conditional embeddings for the cross-attention layers. This is the output of `BertModel`.
        text_embedding_mask (`ms.Tensor`, *optional*):
            An attention mask of shape `(batch, key_tokens)` applied to `encoder_hidden_states`. This is the
            output of `BertModel`.
        encoder_hidden_states_t5 (`ms.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
            Conditional embeddings for the cross-attention layers. This is the output of the T5 text encoder.
        text_embedding_mask_t5 (`ms.Tensor`, *optional*):
            An attention mask of shape `(batch, key_tokens)` applied to `encoder_hidden_states_t5`. This is
            the output of the T5 text encoder.
        image_meta_size (`ms.Tensor`, *optional*):
            Conditional embedding indicating the image sizes.
        style (`ms.Tensor`, *optional*):
            Conditional embedding indicating the style.
        image_rotary_emb (`ms.Tensor`, *optional*):
            The image rotary embeddings to apply to the query and key tensors during attention calculation.
        return_dict (`bool`, *optional*, defaults to `False`):
            Whether to return a `Transformer2DModelOutput` instead of a plain tuple.
    """

    height, width = hidden_states.shape[-2:]

    hidden_states = self.pos_embed(hidden_states)

    temb = self.time_extra_emb(
        timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype
    )  # [B, D]

    # text projection
    batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
    encoder_hidden_states_t5 = self.text_embedder(
        encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
    )
    encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1)

    encoder_hidden_states = ops.cat([encoder_hidden_states, encoder_hidden_states_t5], axis=1)
    text_embedding_mask = ops.cat([text_embedding_mask, text_embedding_mask_t5], axis=-1)
    text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()

    encoder_hidden_states = ops.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)

    skips = []
    for layer, block in enumerate(self.blocks):
        if layer > self.config["num_layers"] // 2:
            if controlnet_block_samples is not None:
                skip = skips[-1] + controlnet_block_samples[-1]
                controlnet_block_samples = controlnet_block_samples[:-1]
            else:
                skip = skips[-1]
            skips = skips[:-1]

            hidden_states = block(
                hidden_states,
                temb=temb,
                encoder_hidden_states=encoder_hidden_states,
                image_rotary_emb=image_rotary_emb,
                skip=skip,
            )  # (N, L, D)
        else:
            hidden_states = block(
                hidden_states,
                temb=temb,
                encoder_hidden_states=encoder_hidden_states,
                image_rotary_emb=image_rotary_emb,
            )  # (N, L, D)

        if layer < (self.config["num_layers"] // 2 - 1):
            skips.append(hidden_states)

    if controlnet_block_samples is not None and len(controlnet_block_samples) != 0:
        raise ValueError("The number of controls is not equal to the number of skip connections.")

    # final layer
    hidden_states = self.norm_out(hidden_states, temb.to(ms.float32))
    hidden_states = self.proj_out(hidden_states)
    # (N, L, patch_size ** 2 * out_channels)

    # unpatchify: (N, out_channels, H, W)
    patch_size = self.pos_embed.patch_size
    height = height // patch_size
    width = width // patch_size

    hidden_states = hidden_states.reshape(
        hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels
    )
    # hidden_states = ops.einsum("nhwpqc->nchpwq", hidden_states)
    hidden_states = hidden_states.transpose(0, 5, 1, 3, 2, 4)
    output = hidden_states.reshape(
        hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size
    )
    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)
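
The closing reshape, transpose, reshape sequence (the commented-out einsum `nhwpqc->nchpwq`) turns a token sequence back into an image. As a rough illustration, the same index mapping can be written in plain Python for a single sample. The function `unpatchify` below is a hypothetical helper operating on nested lists, not the library's implementation.

```python
def unpatchify(tokens, height, width, patch_size, out_channels):
    """Rebuild one (C, H*P, W*P) image from H*W tokens of P*P*C features each.

    `height` and `width` count patches, matching the values after the
    integer division by `patch_size` in `construct`.
    """
    P, C = patch_size, out_channels
    image = [[[0] * (width * P) for _ in range(height * P)] for _ in range(C)]
    for h in range(height):
        for w in range(width):
            features = tokens[h * width + w]       # token order is row-major
            for i in range(P):                     # row inside the patch
                for j in range(P):                 # column inside the patch
                    for c in range(C):
                        value = features[(i * P + j) * C + c]
                        image[c][h * P + i][w * P + j] = value
    return image


# Two 2x2 patches side by side, one channel.
tokens = [[0, 1, 2, 3], [4, 5, 6, 7]]
image = unpatchify(tokens, height=1, width=2, patch_size=2, out_channels=1)
print(image[0])
```

Each token fills one `patch_size` by `patch_size` tile of every output channel, with the channel index varying fastest inside the token's feature vector.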

mindone.diffusers.HunyuanDiT2DModel.enable_forward_chunking(chunk_size=None, dim=0)

Sets the attention processor to use feed forward chunking.

PARAMETER DESCRIPTION
chunk_size

The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually over each tensor of dim=dim.

TYPE: `int`, *optional* DEFAULT: None

dim

The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) or dim=1 (sequence length).

TYPE: `int`, *optional*, defaults to `0` DEFAULT: 0
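
Feed-forward chunking relies on the layer being position-wise: running it on slices and concatenating the results gives the same output as one large call, while lowering peak memory. The following is a minimal plain-Python sketch of that idea along the batch dimension. Both `feed_forward` and `chunked_feed_forward` are hypothetical stand-ins and do not reflect how the library's blocks implement `set_chunk_feed_forward`.

```python
def feed_forward(rows):
    # Hypothetical position-wise layer: every row is transformed independently.
    return [[2 * x + 1 for x in row] for row in rows]


def chunked_feed_forward(fn, rows, chunk_size=None):
    # Mirrors `chunk_size = chunk_size or 1` from enable_forward_chunking.
    chunk_size = chunk_size or 1
    outputs = []
    for start in range(0, len(rows), chunk_size):
        # Only `chunk_size` rows are processed at a time.
        outputs.extend(fn(rows[start:start + chunk_size]))
    return outputs


batch = [[0, 1], [2, 3], [4, 5], [6, 7]]
full = feed_forward(batch)
chunked = chunked_feed_forward(feed_forward, batch, chunk_size=2)
print(chunked == full)
```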

Source code in mindone/diffusers/models/transformers/hunyuan_transformer_2d.py
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
    """
    Sets the attention processor to use [feed forward
    chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

    Parameters:
        chunk_size (`int`, *optional*):
            The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
            over each tensor of dim=`dim`.
        dim (`int`, *optional*, defaults to `0`):
            The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
            or dim=1 (sequence length).
    """
    if dim not in [0, 1]:
        raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

    # By default chunk size is 1
    chunk_size = chunk_size or 1

    def fn_recursive_feed_forward(module: nn.Cell, chunk_size: int, dim: int):
        if hasattr(module, "set_chunk_feed_forward"):
            module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

        for child in module.name_cells().values():
            fn_recursive_feed_forward(child, chunk_size, dim)

    for module in self.name_cells().values():
        fn_recursive_feed_forward(module, chunk_size, dim)

mindone.diffusers.HunyuanDiT2DModel.set_attn_processor(processor)

Sets the attention processor to use to compute attention.

PARAMETER DESCRIPTION
processor

The instantiated processor class or a dictionary of processor classes that will be set as the processor for all Attention layers.

If processor is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.

TYPE: `dict` of `AttentionProcessor` or only `AttentionProcessor`
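
When a dict is passed, its keys follow the same `"<path>.processor"` naming that `attn_processors` produces. The sketch below imitates the two recursive traversals on a tiny hand-built module tree so the key format is visible without MindSpore installed. `FakeCell`, `FakeAttention`, and the `attn1`/`attn2` child names are hypothetical stand-ins, and the length check performed by `set_attn_processor` is omitted.

```python
class FakeCell:
    """Hypothetical stand-in for `nn.Cell`; only `name_cells()` is mimicked."""

    def __init__(self, children=None):
        self._children = dict(children or {})

    def name_cells(self):
        return dict(self._children)


class FakeAttention(FakeCell):
    """Hypothetical attention layer exposing get_processor/set_processor."""

    def __init__(self):
        super().__init__()
        self.processor = "default"

    def get_processor(self):
        return self.processor

    def set_processor(self, processor):
        self.processor = processor


def collect_processors(root):
    # Same traversal as `attn_processors`: keys are "<path>.processor".
    found = {}

    def recurse(name, module):
        if hasattr(module, "get_processor"):
            found[f"{name}.processor"] = module.get_processor()
        for sub_name, child in module.name_cells().items():
            recurse(f"{name}.{sub_name}", child)

    for name, module in root.name_cells().items():
        recurse(name, module)
    return found


def assign_processors(root, processor):
    # Same traversal as `set_attn_processor`, without the length check.
    def recurse(name, module):
        if hasattr(module, "set_processor"):
            if isinstance(processor, dict):
                module.set_processor(processor.pop(f"{name}.processor"))
            else:
                module.set_processor(processor)
        for sub_name, child in module.name_cells().items():
            recurse(f"{name}.{sub_name}", child)

    for name, module in root.name_cells().items():
        recurse(name, module)


block = FakeCell({"attn1": FakeAttention(), "attn2": FakeAttention()})
model = FakeCell({"blocks": FakeCell({"0": block})})

keys = sorted(collect_processors(model))
assign_processors(model, {key: f"custom:{key}" for key in keys})
after = collect_processors(model)
print(keys)
```

Because entries are removed with `pop` as they are assigned, the dict passed in is emptied by the call; pass a copy if the mapping is needed afterwards.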

Source code in mindone/diffusers/models/transformers/hunyuan_transformer_2d.py
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors.

    """
    count = len(self.attn_processors.keys())

    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def fn_recursive_attn_processor(name: str, module: nn.Cell, processor):
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                module.set_processor(processor.pop(f"{name}.processor"))

        for sub_name, child in module.name_cells().items():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.name_cells().items():
        fn_recursive_attn_processor(name, module, processor)

mindone.diffusers.HunyuanDiT2DModel.set_default_attn_processor()

Disables custom attention processors and sets the default attention implementation.

Source code in mindone/diffusers/models/transformers/hunyuan_transformer_2d.py
def set_default_attn_processor(self):
    """
    Disables custom attention processors and sets the default attention implementation.
    """
    self.set_attn_processor(HunyuanAttnProcessor2_0())