TransformerTemporalModel

A Transformer model for video-like data.

mindone.diffusers.models.transformers.transformer_temporal.TransformerTemporalModel

Bases: ModelMixin, ConfigMixin

A Transformer model for video-like data.

PARAMETER DESCRIPTION
num_attention_heads

The number of heads to use for multi-head attention.

TYPE: `int`, *optional*, defaults to 16 DEFAULT: 16

attention_head_dim

The number of channels in each head.

TYPE: `int`, *optional*, defaults to 88 DEFAULT: 88

in_channels

The number of channels in the input and output (specify if the input is continuous).

TYPE: `int`, *optional* DEFAULT: None

num_layers

The number of layers of Transformer blocks to use.

TYPE: `int`, *optional*, defaults to 1 DEFAULT: 1

dropout

The dropout probability to use.

TYPE: `float`, *optional*, defaults to 0.0 DEFAULT: 0.0

cross_attention_dim

The number of encoder_hidden_states dimensions to use.

TYPE: `int`, *optional* DEFAULT: None

attention_bias

Configure if the TransformerBlock attention should contain a bias parameter.

TYPE: `bool`, *optional* DEFAULT: False

sample_size

The width of the latent images (specify if the input is discrete). This is fixed during training since it is used to learn a number of position embeddings.

TYPE: `int`, *optional* DEFAULT: None

activation_fn

Activation function to use in feed-forward. See diffusers.models.activations.get_activation for supported activation functions.

TYPE: `str`, *optional*, defaults to `"geglu"` DEFAULT: 'geglu'

norm_elementwise_affine

Configure if the TransformerBlock should use learnable elementwise affine parameters for normalization.

TYPE: `bool`, *optional* DEFAULT: True

double_self_attention

Configure if each TransformerBlock should contain two self-attention layers.

TYPE: `bool`, *optional* DEFAULT: True

positional_embeddings

The type of positional embeddings to apply to the sequence input before use.

TYPE: Optional[str] DEFAULT: None

num_positional_embeddings

The maximum length of the sequence over which to apply positional embeddings.

TYPE: Optional[int] DEFAULT: None
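As a quick sanity check on the dimensions above: the attention inner dimension is `num_attention_heads * attention_head_dim`, and the input/output projections map between `in_channels` and that inner dimension. A minimal sketch of this shape arithmetic (the `in_channels` value of 320 is an assumed example, not a documented default):

```python
# Shape arithmetic implied by the parameters above (a sketch, not the real model).
num_attention_heads = 16   # documented default
attention_head_dim = 88    # documented default
in_channels = 320          # assumption: a typical UNet block channel count

# The model's inner dimension is heads * head_dim ...
inner_dim = num_attention_heads * attention_head_dim

# ... and proj_in / proj_out map between in_channels and inner_dim.
proj_in_shape = (in_channels, inner_dim)
proj_out_shape = (inner_dim, in_channels)

print(inner_dim)       # 1408
print(proj_in_shape)   # (320, 1408)
```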

Source code in mindone/diffusers/models/transformers/transformer_temporal.py
class TransformerTemporalModel(ModelMixin, ConfigMixin):
    """
    A Transformer model for video-like data.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlock` attention should contain a bias parameter.
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            This is fixed during training since it is used to learn a number of position embeddings.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
            activation functions.
        norm_elementwise_affine (`bool`, *optional*):
            Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
        double_self_attention (`bool`, *optional*):
            Configure if each `TransformerBlock` should contain two self-attention layers.
        positional_embeddings (`str`, *optional*):
            The type of positional embeddings to apply to the sequence input before use.
        num_positional_embeddings (`int`, *optional*):
            The maximum length of the sequence over which to apply positional embeddings.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        activation_fn: str = "geglu",
        norm_elementwise_affine: bool = True,
        double_self_attention: bool = True,
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        self.in_channels = in_channels

        self.norm = GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Dense(in_channels, inner_dim)

        # 3. Define transformers blocks
        self.transformer_blocks = nn.CellList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    double_self_attention=double_self_attention,
                    norm_elementwise_affine=norm_elementwise_affine,
                    positional_embeddings=positional_embeddings,
                    num_positional_embeddings=num_positional_embeddings,
                )
                for d in range(num_layers)
            ]
        )

        self.proj_out = nn.Dense(inner_dim, in_channels)

    def construct(
        self,
        hidden_states: ms.Tensor,
        encoder_hidden_states: Optional[ms.Tensor] = None,
        timestep: Optional[ms.Tensor] = None,
        class_labels: ms.Tensor = None,
        num_frames: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = False,
    ) -> TransformerTemporalModelOutput:
        """
        The [`TransformerTemporalModel`] forward method.

        Args:
            hidden_states (`ms.Tensor` of shape `(batch size, num latent pixels)` if discrete,
            `ms.Tensor` of shape `(batch size, channel, height, width)` if continuous):
                Input hidden_states.
            encoder_hidden_states ( `ms.Tensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `ms.Tensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `ms.Tensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            num_frames (`int`, *optional*, defaults to 1):
                The number of frames to be processed per batch. This is used to reshape the hidden states.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`]
                instead of a plain tuple.

        Returns:
            [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
                If `return_dict` is True, an
                [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] is returned, otherwise a
                `tuple` where the first element is the sample tensor.
        """
        # 1. Input
        batch_frames, channel, height, width = hidden_states.shape
        batch_size = batch_frames // num_frames

        residual = hidden_states

        hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)

        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)

        hidden_states = self.proj_in(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        hidden_states = self.proj_out(hidden_states)
        hidden_states = (
            hidden_states[None, None, :].reshape(batch_size, height, width, num_frames, channel).permute(0, 3, 4, 1, 2)
        )
        hidden_states = hidden_states.reshape(batch_frames, channel, height, width)

        output = hidden_states + residual

        if not return_dict:
            return (output,)

        return TransformerTemporalModelOutput(sample=output)

mindone.diffusers.models.transformers.transformer_temporal.TransformerTemporalModel.construct(hidden_states, encoder_hidden_states=None, timestep=None, class_labels=None, num_frames=1, cross_attention_kwargs=None, return_dict=False)

The [TransformerTemporalModel] forward method.

PARAMETER DESCRIPTION
hidden_states

Input hidden_states.

TYPE: `ms.Tensor` of shape `(batch size, num latent pixels)` if discrete, or of shape `(batch size, channel, height, width)` if continuous

encoder_hidden_states

Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention.

TYPE: `ms.Tensor` of shape `(batch size, encoder_hidden_states dim)`, *optional* DEFAULT: None

timestep

Used to indicate denoising step. Optional timestep to be applied as an embedding in AdaLayerNorm.

TYPE: `ms.Tensor`, *optional* DEFAULT: None

class_labels

Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in AdaLayerZeroNorm.

TYPE: `ms.Tensor` of shape `(batch size, num classes)`, *optional* DEFAULT: None

num_frames

The number of frames to be processed per batch. This is used to reshape the hidden states.

TYPE: `int`, *optional*, defaults to 1 DEFAULT: 1

cross_attention_kwargs

A kwargs dictionary that if specified is passed along to the AttentionProcessor as defined under self.processor in diffusers.models.attention_processor.

TYPE: `dict`, *optional* DEFAULT: None

return_dict

Whether or not to return a [~models.transformers.transformer_temporal.TransformerTemporalModelOutput] instead of a plain tuple.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

RETURNS DESCRIPTION
TransformerTemporalModelOutput

[~models.transformers.transformer_temporal.TransformerTemporalModelOutput] or tuple: If return_dict is True, an [~models.transformers.transformer_temporal.TransformerTemporalModelOutput] is returned, otherwise a tuple where the first element is the sample tensor.
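The reshaping performed in `construct` treats every spatial location as its own temporal sequence of length `num_frames`, runs attention over the frame axis, then restores the original layout so the residual connection lines up. A NumPy-only sketch of that round trip, with toy dimensions chosen purely for illustration:

```python
import numpy as np

batch_size, num_frames, channel, height, width = 2, 8, 4, 3, 5
batch_frames = batch_size * num_frames

x = np.arange(batch_frames * channel * height * width, dtype=np.float32).reshape(
    batch_frames, channel, height, width
)
residual = x

# (B*F, C, H, W) -> (B, F, C, H, W) -> (B, C, F, H, W)
h = x.reshape(batch_size, num_frames, channel, height, width).transpose(0, 2, 1, 3, 4)
# -> (B, H, W, F, C) -> (B*H*W, F, C): each pixel position becomes a temporal sequence
h = h.transpose(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
assert h.shape == (batch_size * height * width, num_frames, channel)

# ... proj_in, the transformer blocks, and proj_out would run here ...

# Inverse: (B*H*W, F, C) -> (B, H, W, F, C) -> (B, F, C, H, W) -> (B*F, C, H, W)
h = h.reshape(batch_size, height, width, num_frames, channel).transpose(0, 3, 4, 1, 2)
h = h.reshape(batch_frames, channel, height, width)

# The reshuffle is lossless, so the residual add is element-wise consistent.
assert np.array_equal(h, residual)
```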


mindone.diffusers.models.transformers.transformer_temporal.TransformerTemporalModelOutput dataclass

Bases: BaseOutput

The output of [TransformerTemporalModel].

PARAMETER DESCRIPTION
sample

The hidden states output conditioned on encoder_hidden_states input.

TYPE: `ms.Tensor` of shape `(batch_size x num_frames, num_channels, height, width)`

Source code in mindone/diffusers/models/transformers/transformer_temporal.py
@dataclass
class TransformerTemporalModelOutput(BaseOutput):
    """
    The output of [`TransformerTemporalModel`].

    Args:
        sample (`ms.Tensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input.
    """

    sample: ms.Tensor
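The `return_dict` handling at the end of `construct` can be mimicked with a plain dataclass. This stand-in is for illustration only and is not the real `BaseOutput` subclass; note that `return_dict` defaults to `False`, so callers get a tuple unless they opt in:

```python
from dataclasses import dataclass

# Minimal stand-in for the real BaseOutput-derived output class (illustration only).
@dataclass
class TransformerTemporalModelOutput:
    sample: object

def finish(output, return_dict: bool = False):
    # Mirrors the tail of construct(): a 1-tuple when return_dict is False.
    if not return_dict:
        return (output,)
    return TransformerTemporalModelOutput(sample=output)

print(finish("sample_tensor"))                    # ('sample_tensor',)
print(finish("sample_tensor", True).sample)       # sample_tensor
```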