
OmniGenTransformer2DModel

A Transformer model that accepts multimodal instructions to generate images for OmniGen.

The abstract from the paper is:

The emergence of Large Language Models (LLMs) has unified language generation tasks and revolutionized human-machine interaction. However, in the realm of image generation, a unified model capable of handling various tasks within a single framework remains largely unexplored. In this work, we introduce OmniGen, a new diffusion model for unified image generation. OmniGen is characterized by the following features: 1) Unification: OmniGen not only demonstrates text-to-image generation capabilities but also inherently supports various downstream tasks, such as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of OmniGen is highly simplified, eliminating the need for additional plugins. Moreover, compared to existing diffusion models, it is more user-friendly and can complete complex tasks end-to-end through instructions without the need for extra intermediate steps, greatly simplifying the image generation workflow. 3) Knowledge Transfer: Benefit from learning in a unified format, OmniGen effectively transfers knowledge across different tasks, manages unseen tasks and domains, and exhibits novel capabilities. We also explore the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism. This work represents the first attempt at a general-purpose image generation model, and we will release our resources at https://github.com/VectorSpaceLab/OmniGen to foster future advancements.

The model can be loaded with the following code snippet.

import mindspore
from mindone.diffusers import OmniGenTransformer2DModel

transformer = OmniGenTransformer2DModel.from_pretrained("Shitao/OmniGen-v1-diffusers", subfolder="transformer", mindspore_dtype=mindspore.bfloat16)

mindone.diffusers.OmniGenTransformer2DModel

Bases: ModelMixin, ConfigMixin

The Transformer model introduced in OmniGen (https://arxiv.org/pdf/2409.11340).

Parameters:

in_channels (`int`, defaults to `4`):
    The number of channels in the input.

patch_size (`int`, defaults to `2`):
    The size of the spatial patches to use in the patch embedding layer.

hidden_size (`int`, defaults to `3072`):
    The dimensionality of the hidden layers in the model.

rms_norm_eps (`float`, defaults to `1e-5`):
    Epsilon used by the RMSNorm layers.

num_attention_heads (`int`, defaults to `32`):
    The number of heads to use for multi-head attention.

num_key_value_heads (`int`, defaults to `32`):
    The number of heads to use for keys and values in multi-head attention.

intermediate_size (`int`, defaults to `8192`):
    Dimension of the hidden layer in the FeedForward layers.

num_layers (`int`, defaults to `32`):
    The number of transformer blocks to use.

pad_token_id (`int`, defaults to `32000`):
    The id of the padding token.

vocab_size (`int`, defaults to `32064`):
    The size of the token embedding vocabulary.

rope_base (`int`, defaults to `10000`):
    The default theta value to use when creating RoPE.

rope_scaling (`Dict`, *optional*, defaults to `None`):
    The scaling factors for RoPE. Must contain `short_factor` and `long_factor`.

pos_embed_max_size (`int`, defaults to `192`):
    The maximum size of the positional embeddings.

time_step_dim (`int`, defaults to `256`):
    Output dimension of the timestep embeddings.

flip_sin_to_cos (`bool`, defaults to `True`):
    Whether to flip the sin and cos in the positional embeddings when preparing timestep embeddings.

downscale_freq_shift (`int`, defaults to `0`):
    The frequency shift to use when downscaling the timestep embeddings.

timestep_activation_fn (`str`, defaults to `"silu"`):
    The activation function to use for the timestep embeddings.
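
Since these values are registered with `register_to_config`, they can be read back from the loaded model's `config`. A minimal sketch, assuming the `transformer` loaded in the snippet above:

print(transformer.config.hidden_size)   # 3072
print(transformer.config.num_layers)    # 32
print(transformer.config.patch_size)    # 2
print(transformer.config.rope_scaling)  # dict containing `short_factor` and `long_factor`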

Source code in mindone/diffusers/models/transformers/transformer_omnigen.py
class OmniGenTransformer2DModel(ModelMixin, ConfigMixin):
    """
    The Transformer model introduced in OmniGen (https://arxiv.org/pdf/2409.11340).

    Parameters:
        in_channels (`int`, defaults to `4`):
            The number of channels in the input.
        patch_size (`int`, defaults to `2`):
            The size of the spatial patches to use in the patch embedding layer.
        hidden_size (`int`, defaults to `3072`):
            The dimensionality of the hidden layers in the model.
        rms_norm_eps (`float`, defaults to `1e-5`):
            Eps for RMSNorm layer.
        num_attention_heads (`int`, defaults to `32`):
            The number of heads to use for multi-head attention.
        num_key_value_heads (`int`, defaults to `32`):
            The number of heads to use for keys and values in multi-head attention.
        intermediate_size (`int`, defaults to `8192`):
            Dimension of the hidden layer in FeedForward layers.
        num_layers (`int`, defaults to `32`):
            The number of transformer blocks to use.
        pad_token_id (`int`, defaults to `32000`):
            The id of the padding token.
        vocab_size (`int`, defaults to `32064`):
            The size of the token embedding vocabulary.
        rope_base (`int`, defaults to `10000`):
            The default theta value to use when creating RoPE.
        rope_scaling (`Dict`, optional):
            The scaling factors for the RoPE. Must contain `short_factor` and `long_factor`.
        pos_embed_max_size (`int`, defaults to `192`):
            The maximum size of the positional embeddings.
        time_step_dim (`int`, defaults to `256`):
            Output dimension of timestep embeddings.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin and cos in the positional embeddings when preparing timestep embeddings.
        downscale_freq_shift (`int`, defaults to `0`):
            The frequency shift to use when downscaling the timestep embeddings.
        timestep_activation_fn (`str`, defaults to `silu`):
            The activation function to use for the timestep embeddings.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["OmniGenBlock"]
    _skip_layerwise_casting_patterns = ["patch_embedding", "embed_tokens", "norm"]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        patch_size: int = 2,
        hidden_size: int = 3072,
        rms_norm_eps: float = 1e-5,
        num_attention_heads: int = 32,
        num_key_value_heads: int = 32,
        intermediate_size: int = 8192,
        num_layers: int = 32,
        pad_token_id: int = 32000,
        vocab_size: int = 32064,
        max_position_embeddings: int = 131072,
        original_max_position_embeddings: int = 4096,
        rope_base: int = 10000,
        rope_scaling: Dict = None,
        pos_embed_max_size: int = 192,
        time_step_dim: int = 256,
        flip_sin_to_cos: bool = True,
        downscale_freq_shift: int = 0,
        timestep_activation_fn: str = "silu",
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels

        self.patch_embedding = OmniGenPatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=hidden_size,
            pos_embed_max_size=pos_embed_max_size,
        )

        self.time_proj = Timesteps(time_step_dim, flip_sin_to_cos, downscale_freq_shift)
        self.time_token = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)
        self.t_embedder = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)

        self.embed_tokens = mint.nn.Embedding(vocab_size, hidden_size, pad_token_id)
        self.rope = OmniGenSuScaledRotaryEmbedding(
            hidden_size // num_attention_heads,
            max_position_embeddings=max_position_embeddings,
            original_max_position_embeddings=original_max_position_embeddings,
            base=rope_base,
            rope_scaling=rope_scaling,
        )

        self.layers = nn.CellList(
            [
                OmniGenBlock(hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, rms_norm_eps)
                for _ in range(num_layers)
            ]
        )

        self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.norm_out = AdaLayerNorm(hidden_size, norm_elementwise_affine=False, norm_eps=1e-6, chunk_dim=1)
        self.proj_out = nn.Dense(hidden_size, patch_size * patch_size * self.out_channels, has_bias=True)
        self.p = self.config.patch_size

        self.gradient_checkpointing = False

    def _get_multimodal_embeddings(
        self, input_ids: ms.Tensor, input_img_latents: List[ms.Tensor], input_image_sizes: Dict
    ) -> Optional[ms.Tensor]:
        if input_ids is None:
            return None

        input_img_latents = [x.to(self.dtype) for x in input_img_latents]
        condition_tokens = self.embed_tokens(input_ids)
        input_img_inx = 0
        input_image_tokens = self.patch_embedding(input_img_latents, is_input_image=True)
        for b_inx in input_image_sizes.keys():
            for start_inx, end_inx in input_image_sizes[b_inx]:
                # replace the placeholder in text tokens with the image embedding.
                # TODO tensor index setitem will support value broadcast at mindspore 2.7
                condition_tokens[b_inx, start_inx:end_inx] = input_image_tokens[input_img_inx][0].to(
                    condition_tokens.dtype
                )
                input_img_inx += 1
        return condition_tokens

    def construct(
        self,
        hidden_states: ms.Tensor,
        timestep: Union[int, float, ms.Tensor],
        input_ids: ms.Tensor,
        input_img_latents: List[ms.Tensor],
        input_image_sizes: Dict[int, List[int]],
        attention_mask: ms.Tensor,
        position_ids: ms.Tensor,
        return_dict: bool = False,
    ) -> Union[Transformer2DModelOutput, Tuple[ms.Tensor]]:
        batch_size, num_channels, height, width = hidden_states.shape
        p = self.p
        post_patch_height, post_patch_width = height // p, width // p

        # 1. Patch & Timestep & Conditional Embedding
        hidden_states = self.patch_embedding(hidden_states, is_input_image=False)
        num_tokens_for_output_image = hidden_states.shape[1]

        timestep_proj = self.time_proj(timestep).type_as(hidden_states)
        time_token = self.time_token(timestep_proj).unsqueeze(1)
        temb = self.t_embedder(timestep_proj)

        condition_tokens = self._get_multimodal_embeddings(input_ids, input_img_latents, input_image_sizes)
        if condition_tokens is not None:
            hidden_states = mint.cat([condition_tokens, time_token, hidden_states], dim=1)
        else:
            hidden_states = mint.cat([time_token, hidden_states], dim=1)

        seq_length = hidden_states.shape[1]
        position_ids = position_ids.view(-1, seq_length).long()

        # 2. Attention mask preprocessing
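        # A 3D {0, 1} mask is turned into an additive bias: positions with mask value 0 are set to the dtype minimum so attention ignores them.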
        if attention_mask is not None and attention_mask.dim() == 3:
            dtype = hidden_states.dtype
            min_dtype = dtype_to_min(dtype)
            attention_mask = (1 - attention_mask) * min_dtype
            attention_mask = attention_mask.unsqueeze(1).type_as(hidden_states)

        # 3. Rotary position embedding
        image_rotary_emb = self.rope(hidden_states, position_ids)

        # 4. Transformer blocks
        for block in self.layers:
            hidden_states = block(hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb)

        # 5. Output norm & projection
        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states[:, -num_tokens_for_output_image:]
        hidden_states = self.norm_out(hidden_states, temb=temb)
        hidden_states = self.proj_out(hidden_states)
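        # Unpatchify: fold the (patch_size x patch_size) patches back into a (batch, out_channels, height, width) latent.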
        hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, p, p, -1)
        output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
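
For reference, a minimal, hypothetical sketch of calling the transformer directly with dummy, unconditional inputs (reusing the `transformer` loaded in the snippet at the top; the shapes are assumptions, not values taken from the pipeline): 64x64 latents with `patch_size=2` give 32 * 32 = 1024 image tokens, plus one timestep token.

import mindspore
from mindspore import mint

latents = mint.randn(1, 4, 64, 64).to(mindspore.bfloat16)  # (batch, in_channels, height, width)
timestep = mindspore.Tensor([999])
position_ids = mint.arange(1 + 32 * 32).unsqueeze(0)       # one timestep token + 1024 image tokens

(sample,) = transformer(
    hidden_states=latents,
    timestep=timestep,
    input_ids=None,        # no text conditioning in this sketch
    input_img_latents=[],  # no reference-image latents
    input_image_sizes={},
    attention_mask=None,
    position_ids=position_ids,
)
print(sample.shape)  # (1, 4, 64, 64)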