
HunyuanDiT2DControlNetModel

HunyuanDiT2DControlNetModel is an implementation of ControlNet for Hunyuan-DiT.

ControlNet was introduced in Adding Conditional Control to Text-to-Image Diffusion Models by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.

With a ControlNet model, you can provide an additional control image to condition and control Hunyuan-DiT generation. For example, if you provide a depth map, the ControlNet model generates an image that preserves the spatial information from the depth map. This offers a more flexible and accurate way to control the image generation process.

The abstract from the paper is:

We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, e.g., edges, depth, segmentation, human pose, etc., with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.

This code was implemented by the Tencent Hunyuan Team. You can find pre-trained checkpoints for Hunyuan-DiT ControlNets on Tencent Hunyuan.
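The "zero convolutions" from the abstract show up in this model as zero-initialized `nn.Dense` layers: `input_block` and every entry of `controlnet_blocks` are created with `weight_init="zeros", bias_init="zeros"`. The pure-Python sketch below illustrates why that is safe at initialization. It does not use MindSpore; the `dense` helper is a hypothetical stand-in for `nn.Dense`, and the sizes and values are toy assumptions. With all-zero weights and biases, the control branch adds exactly zero to the base hidden states, so the pretrained behavior is untouched until training moves the weights away from zero.

```python
from typing import List


def dense(x: List[float], weight: List[List[float]], bias: List[float]) -> List[float]:
    """Hypothetical stand-in for nn.Dense on one vector: y = W x + b."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b_i for row, b_i in zip(weight, bias)]


hidden_size = 4
# Zero-initialized layer, mirroring weight_init="zeros", bias_init="zeros".
zero_weight = [[0.0] * hidden_size for _ in range(hidden_size)]
zero_bias = [0.0] * hidden_size

hidden_states = [0.5, -1.0, 2.0, 0.25]    # toy base-model activations
control_features = [3.0, 1.0, -2.0, 7.0]  # toy embedded control image

# Same pattern as `hidden_states + self.input_block(self.pos_embed(controlnet_cond))`.
out = [h + c for h, c in zip(hidden_states, dense(control_features, zero_weight, zero_bias))]
print(out)  # [0.5, -1.0, 2.0, 0.25]
```

The weights still receive gradients during training, so the control branch can grow from zero rather than starting from random noise.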

Example For Loading HunyuanDiT2DControlNetModel

from mindone.diffusers import HunyuanDiT2DControlNetModel
import mindspore as ms
controlnet = HunyuanDiT2DControlNetModel.from_pretrained("Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Pose", mindspore_dtype=ms.float16)

mindone.diffusers.HunyuanDiT2DControlNetModel

Bases: ModelMixin, ConfigMixin

Source code in mindone/diffusers/models/controlnet_hunyuan.py
class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        conditioning_channels: int = 3,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "gelu-approximate",
        sample_size=32,
        hidden_size=1152,
        transformer_num_layers: int = 40,
        mlp_ratio: float = 4.0,
        cross_attention_dim: int = 1024,
        cross_attention_dim_t5: int = 2048,
        pooled_projection_dim: int = 1024,
        text_len: int = 77,
        text_len_t5: int = 256,
        use_style_cond_and_image_meta_size: bool = True,
    ):
        super().__init__()
        self.num_heads = num_attention_heads
        self.inner_dim = num_attention_heads * attention_head_dim

        self.text_embedder = PixArtAlphaTextProjection(
            in_features=cross_attention_dim_t5,
            hidden_size=cross_attention_dim_t5 * 4,
            out_features=cross_attention_dim,
            act_fn="silu_fp32",
        )

        self.text_embedding_padding = ms.Parameter(
            ops.randn((text_len + text_len_t5, cross_attention_dim), dtype=ms.float32)
        )

        self.pos_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            in_channels=in_channels,
            embed_dim=hidden_size,
            patch_size=patch_size,
            pos_embed_type=None,
        )

        self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
            hidden_size,
            pooled_projection_dim=pooled_projection_dim,
            seq_len=text_len_t5,
            cross_attention_dim=cross_attention_dim_t5,
            use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size,
        )

        # controlnet_blocks
        controlnet_blocks = []

        # HunyuanDiT Blocks
        self.blocks = nn.CellList(
            [
                HunyuanDiTBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    activation_fn=activation_fn,
                    ff_inner_dim=int(self.inner_dim * mlp_ratio),
                    cross_attention_dim=cross_attention_dim,
                    qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
                    skip=False,  # always False as it is the first half of the model
                )
                for layer in range(transformer_num_layers // 2 - 1)
            ]
        )
        self.input_block = nn.Dense(hidden_size, hidden_size, weight_init="zeros", bias_init="zeros")
        for _ in range(len(self.blocks)):
            controlnet_block = nn.Dense(hidden_size, hidden_size, weight_init="zeros", bias_init="zeros")
            controlnet_blocks.append(controlnet_block)
        self.controlnet_blocks = nn.CellList(controlnet_blocks)

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: nn.Cell, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

            for sub_name, child in module.name_cells().items():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.name_cells().items():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the
                corresponding cross attention processor. This is strongly recommended when setting trainable attention
                processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: nn.Cell, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.name_cells().items():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.name_cells().items():
            fn_recursive_attn_processor(name, module, processor)

    @classmethod
    def from_transformer(
        cls, transformer, conditioning_channels=3, transformer_num_layers=None, load_weights_from_transformer=True
    ):
        config = transformer.config
        activation_fn = config.activation_fn
        attention_head_dim = config.attention_head_dim
        cross_attention_dim = config.cross_attention_dim
        cross_attention_dim_t5 = config.cross_attention_dim_t5
        hidden_size = config.hidden_size
        in_channels = config.in_channels
        mlp_ratio = config.mlp_ratio
        num_attention_heads = config.num_attention_heads
        patch_size = config.patch_size
        sample_size = config.sample_size
        text_len = config.text_len
        text_len_t5 = config.text_len_t5

        conditioning_channels = conditioning_channels
        transformer_num_layers = transformer_num_layers or config.transformer_num_layers

        controlnet = cls(
            conditioning_channels=conditioning_channels,
            transformer_num_layers=transformer_num_layers,
            activation_fn=activation_fn,
            attention_head_dim=attention_head_dim,
            cross_attention_dim=cross_attention_dim,
            cross_attention_dim_t5=cross_attention_dim_t5,
            hidden_size=hidden_size,
            in_channels=in_channels,
            mlp_ratio=mlp_ratio,
            num_attention_heads=num_attention_heads,
            patch_size=patch_size,
            sample_size=sample_size,
            text_len=text_len,
            text_len_t5=text_len_t5,
        )
        if load_weights_from_transformer:
            key = ms.load_param_into_net(controlnet, transformer.parameters_dict(), strict_load=True)
            logger.warning(f"controlnet load from Hunyuan-DiT. missing_keys: {key[0]}")
        return controlnet

    def construct(
        self,
        hidden_states,
        timestep,
        controlnet_cond: ms.Tensor,
        conditioning_scale: float = 1.0,
        encoder_hidden_states=None,
        text_embedding_mask=None,
        encoder_hidden_states_t5=None,
        text_embedding_mask_t5=None,
        image_meta_size=None,
        style=None,
        image_rotary_emb=None,
        return_dict=False,
    ):
        """
        The [`HunyuanDiT2DControlNetModel`] forward method.

        Args:
            hidden_states (`ms.Tensor` of shape `(batch size, dim, height, width)`):
                The input tensor.
            timestep (`ms.Tensor`):
                The current denoising timestep.
            controlnet_cond (`ms.Tensor`):
                The conditioning input to ControlNet. It is embedded with the same `pos_embed` as `hidden_states`,
                so it must have the same number of channels and spatial size.
            conditioning_scale (`float`, defaults to `1.0`):
                Factor that every ControlNet residual is multiplied by.
            encoder_hidden_states (`ms.Tensor` of shape `(batch size, sequence len, embed dims)`):
                Conditional embeddings for the cross-attention layers. This is the output of `BertModel`.
            text_embedding_mask (`ms.Tensor`):
                An attention mask of shape `(batch, key_tokens)` applied to `encoder_hidden_states`. It accompanies
                the `BertModel` output.
            encoder_hidden_states_t5 (`ms.Tensor` of shape `(batch size, sequence len, embed dims)`):
                Conditional embeddings for the cross-attention layers. This is the output of the T5 text encoder.
            text_embedding_mask_t5 (`ms.Tensor`):
                An attention mask of shape `(batch, key_tokens)` applied to `encoder_hidden_states_t5`. It
                accompanies the T5 text encoder output.
            image_meta_size (`ms.Tensor`):
                Conditional embedding that encodes the image sizes.
            style (`ms.Tensor`):
                Conditional embedding that encodes the style.
            image_rotary_emb (`ms.Tensor`):
                The image rotary embeddings to apply to the query and key tensors during attention calculation.
            return_dict (`bool`, defaults to `False`):
                Whether to return a `HunyuanControlNetOutput` instead of a plain tuple.
        """

        height, width = hidden_states.shape[-2:]

        hidden_states = self.pos_embed(hidden_states)  # b,c,H,W -> b, N, C

        # 2. pre-process
        hidden_states = hidden_states + self.input_block(self.pos_embed(controlnet_cond))

        temb = self.time_extra_emb(
            timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype
        )  # [B, D]

        # text projection
        batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
        encoder_hidden_states_t5 = self.text_embedder(
            encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
        )
        encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1)

        encoder_hidden_states = ops.cat([encoder_hidden_states, encoder_hidden_states_t5], axis=1)
        text_embedding_mask = ops.cat([text_embedding_mask, text_embedding_mask_t5], axis=-1)
        text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()

        encoder_hidden_states = ops.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)

        block_res_samples = ()
        for layer, block in enumerate(self.blocks):
            hidden_states = block(
                hidden_states,
                temb=temb,
                encoder_hidden_states=encoder_hidden_states,
                image_rotary_emb=image_rotary_emb,
            )  # (N, L, D)

            block_res_samples = block_res_samples + (hidden_states,)

        controlnet_block_res_samples = ()
        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
            block_res_sample = controlnet_block(block_res_sample)
            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)

        # 6. scaling
        controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]

        if not return_dict:
            return (controlnet_block_res_samples,)

        return HunyuanControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
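The text-conditioning step in `construct` concatenates the Bert embeddings with the projected T5 embeddings along the token axis, concatenates their masks the same way, and then uses `ops.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)` to replace every masked-out token with the matching row of a learned padding table. The sketch below reproduces that logic for a single sample in pure Python. The helper name, the toy sizes, and the values are hypothetical, and the T5 rows are assumed to have already gone through `text_embedder`.

```python
from typing import List

Matrix = List[List[float]]


def merge_text_embeddings(
    bert_emb: Matrix,
    bert_mask: List[int],
    t5_emb: Matrix,
    t5_mask: List[int],
    padding: Matrix,
) -> Matrix:
    """Hypothetical single-sample version of the concat + ops.where step."""
    tokens = bert_emb + t5_emb  # concatenate along the sequence axis
    mask = bert_mask + t5_mask  # concatenate masks along the token axis
    assert len(tokens) == len(mask) == len(padding)
    # Keep real tokens; substitute the learned padding row where the mask is 0.
    return [tok if keep else pad for tok, keep, pad in zip(tokens, mask, padding)]


bert = [[1.0, 1.0], [2.0, 2.0]]
t5 = [[3.0, 3.0], [4.0, 4.0], [5.0, 5.0]]  # assumed already projected to the Bert width
pad = [[-9.0, -9.0]] * 5                    # toy stand-in for text_embedding_padding
merged = merge_text_embeddings(bert, [1, 0], t5, [1, 1, 0], pad)
print(merged)  # [[1.0, 1.0], [-9.0, -9.0], [3.0, 3.0], [4.0, 4.0], [-9.0, -9.0]]
```

In the real model `text_embedding_padding` has `text_len + text_len_t5` rows, so the concatenated sequence must be exactly that long for the broadcast in `ops.where` to line up.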

mindone.diffusers.HunyuanDiT2DControlNetModel.attn_processors: Dict[str, AttentionProcessor] property

RETURNS DESCRIPTION
Dict[str, AttentionProcessor]

A dictionary containing all attention processors used in the model, indexed by their weight names.
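The keys of this dictionary come from a recursive walk over `name_cells()`: every cell that has a `get_processor` method contributes an entry named `"<dotted path>.processor"`. The sketch below reproduces that traversal on a toy tree. `ToyCell` and `ToyAttention` are hypothetical stand-ins for `nn.Cell` and an attention layer, and the processor values are placeholder strings rather than real processor objects.

```python
from typing import Dict


class ToyCell:
    """Hypothetical stand-in for mindspore.nn.Cell: only exposes name_cells()."""

    def __init__(self, **children: "ToyCell"):
        self._children = children

    def name_cells(self) -> Dict[str, "ToyCell"]:
        return dict(self._children)


class ToyAttention(ToyCell):
    """Hypothetical attention layer; detected via hasattr(module, "get_processor")."""

    def __init__(self, processor: str):
        super().__init__()
        self.processor = processor

    def get_processor(self, return_deprecated_lora: bool = False) -> str:
        return self.processor


def collect_processors(root: ToyCell) -> Dict[str, str]:
    processors: Dict[str, str] = {}

    def recurse(name: str, module: ToyCell) -> None:
        if hasattr(module, "get_processor"):
            processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
        for sub_name, child in module.name_cells().items():
            recurse(f"{name}.{sub_name}", child)

    for name, module in root.name_cells().items():
        recurse(name, module)
    return processors


block = ToyCell(attn1=ToyAttention("toy-processor"), attn2=ToyAttention("toy-processor"))
model = ToyCell(blocks=ToyCell(**{"0": block}))
print(collect_processors(model))
# {'blocks.0.attn1.processor': 'toy-processor', 'blocks.0.attn2.processor': 'toy-processor'}
```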

mindone.diffusers.HunyuanDiT2DControlNetModel.construct(hidden_states, timestep, controlnet_cond, conditioning_scale=1.0, encoder_hidden_states=None, text_embedding_mask=None, encoder_hidden_states_t5=None, text_embedding_mask_t5=None, image_meta_size=None, style=None, image_rotary_emb=None, return_dict=False)

The HunyuanDiT2DControlNetModel forward method.

PARAMETER DESCRIPTION
hidden_states (ms.Tensor of shape (batch size, dim, height, width)): The input tensor.
timestep (ms.Tensor): The current denoising timestep.
controlnet_cond (ms.Tensor): The conditioning input to ControlNet. It is embedded with the same pos_embed as hidden_states, so it must have the same number of channels and spatial size.
conditioning_scale (float, defaults to 1.0): Factor that every ControlNet residual is multiplied by.
encoder_hidden_states (ms.Tensor of shape (batch size, sequence len, embed dims)): Conditional embeddings for the cross-attention layers. This is the output of BertModel.
text_embedding_mask (ms.Tensor): An attention mask of shape (batch, key_tokens) applied to encoder_hidden_states. It accompanies the BertModel output.
encoder_hidden_states_t5 (ms.Tensor of shape (batch size, sequence len, embed dims)): Conditional embeddings for the cross-attention layers. This is the output of the T5 text encoder.
text_embedding_mask_t5 (ms.Tensor): An attention mask of shape (batch, key_tokens) applied to encoder_hidden_states_t5. It accompanies the T5 text encoder output.
image_meta_size (ms.Tensor): Conditional embedding that encodes the image sizes.
style (ms.Tensor): Conditional embedding that encodes the style.
image_rotary_emb (ms.Tensor): The image rotary embeddings to apply to the query and key tensors during attention calculation.
return_dict (bool, defaults to False): Whether to return a HunyuanControlNetOutput instead of a plain tuple.
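To make the output structure concrete: `__init__` builds `transformer_num_layers // 2 - 1` HunyuanDiT blocks, each paired with one zero-initialized `controlnet_block`, and `construct` multiplies every resulting residual by `conditioning_scale`. With `return_dict=False` (the default in this signature) the residuals come back as a list wrapped in a one-element tuple. The sketch below mirrors that bookkeeping with nested Python lists standing in for `ms.Tensor`; both function names are hypothetical.

```python
from typing import List, Tuple


def num_controlnet_blocks(transformer_num_layers: int) -> int:
    # Mirrors `range(transformer_num_layers // 2 - 1)` in __init__.
    return transformer_num_layers // 2 - 1


def scale_residuals(
    residuals: List[List[float]], conditioning_scale: float = 1.0
) -> Tuple[List[List[float]]]:
    """Hypothetical mimic of the '# 6. scaling' step with return_dict=False."""
    scaled = [[v * conditioning_scale for v in sample] for sample in residuals]
    return (scaled,)


print(num_controlnet_blocks(40))  # 19 with the default transformer_num_layers
(out,) = scale_residuals([[1.0, 2.0], [4.0, -2.0]], conditioning_scale=0.5)
print(out)  # [[0.5, 1.0], [2.0, -1.0]]
```

Setting `conditioning_scale` to `0.0` zeroes every residual, which is a quick way to compare against generation without the control signal.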


mindone.diffusers.HunyuanDiT2DControlNetModel.set_attn_processor(processor)

Sets the attention processor to use to compute attention.

PARAMETER DESCRIPTION
processor

The instantiated processor class or a dictionary of processor classes that will be set as the processor for all Attention layers. If processor is a dict, each key must give the path to the corresponding cross-attention processor, and the dict must contain exactly one entry per attention layer. Passing a dict is strongly recommended when setting trainable attention processors.

TYPE: `dict` of `AttentionProcessor` or only `AttentionProcessor`
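The dispatch in `set_attn_processor` between a single shared processor and a per-layer dict can be sketched as follows. For brevity, this hypothetical helper takes a flat list of attention-layer names instead of walking `name_cells()` recursively, and it uses strings in place of processor objects.

```python
from typing import Dict, List, Union


def assign_processors(layer_names: List[str], processor: Union[str, Dict[str, str]]) -> Dict[str, str]:
    """Hypothetical flat version of set_attn_processor's validation and dispatch."""
    count = len(layer_names)
    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(f"expected {count} processors, got {len(processor)}")
    assigned: Dict[str, str] = {}
    for name in layer_names:
        if not isinstance(processor, dict):
            assigned[name] = processor  # one processor shared by every layer
        else:
            assigned[name] = processor.pop(f"{name}.processor")  # consumes the dict
    return assigned


layers = ["blocks.0.attn1", "blocks.0.attn2"]
shared = assign_processors(layers, "shared")
mapping = {"blocks.0.attn1.processor": "p1", "blocks.0.attn2.processor": "p2"}
per_layer = assign_processors(layers, mapping)
print(per_layer)  # {'blocks.0.attn1': 'p1', 'blocks.0.attn2': 'p2'}
print(mapping)    # {}
```

Because the real method also calls `processor.pop(...)`, a dict you pass in is emptied as a side effect; pass a copy if you need to reuse it. A dict with the right length but a wrong key fails with a `KeyError` from `pop` rather than the `ValueError`.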
