```python
class LatteTransformer3DModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    """
    A 3D Transformer model for video-like data, paper: https://arxiv.org/abs/2401.03048, official code:
    https://github.com/Vchitect/Latte

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input.
        out_channels (`int`, *optional*):
            The number of channels in the output.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
        sample_size (`int`, *optional*):
            The width of the latent images (specify if the input is **discrete**). This is fixed during training
            since it is used to learn a number of position embeddings.
        patch_size (`int`, *optional*):
            The size of the patches to use in the patch embedding layer.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
            added to the hidden states. During inference, you can denoise for up to but not more steps than
            `num_embeds_ada_norm`.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether or not to use elementwise affine in normalization layers.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers.
        caption_channels (`int`, *optional*):
            The number of channels in the caption embeddings.
        video_length (`int`, *optional*):
            The number of frames in the video-like data.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: int = 64,
        patch_size: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        caption_channels: int = None,
        video_length: int = 16,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Define input layers
        self.height = sample_size
        self.width = sample_size

        interpolation_scale = self.config.sample_size // 64
        interpolation_scale = max(interpolation_scale, 1)
        self.pos_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            interpolation_scale=interpolation_scale,
        )
        self.patch_size = self.config.patch_size

        # 2. Define spatial transformer blocks
        self.transformer_blocks = nn.CellList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for d in range(num_layers)
            ]
        )

        # 3. Define temporal transformer blocks
        self.temporal_transformer_blocks = nn.CellList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=None,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm_out = LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
        self.scale_shift_table = ms.Parameter(ops.randn((2, inner_dim)) / inner_dim**0.5)
        self.proj_out = nn.Dense(inner_dim, patch_size * patch_size * self.out_channels)

        # 5. Other Latte blocks
        self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)
        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

        # define temporal positional embedding
        temp_pos_embed = get_1d_sincos_pos_embed_from_grid(
            inner_dim, ops.arange(0, video_length).unsqueeze(1).numpy()
        )  # 1152 hidden size
        self.temp_pos_embed = ms.Tensor.from_numpy(temp_pos_embed).float().unsqueeze(0)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

    def construct(
        self,
        hidden_states: ms.Tensor,
        timestep: Optional[ms.Tensor] = None,
        encoder_hidden_states: Optional[ms.Tensor] = None,
        encoder_attention_mask: Optional[ms.Tensor] = None,
        enable_temporal_attentions: bool = True,
        return_dict: bool = False,
    ):
        """
        The [`LatteTransformer3DModel`] forward method.

        Args:
            hidden_states (`ms.Tensor` of shape `(batch size, channel, num_frame, height, width)`):
                Input `hidden_states`.
            timestep (`ms.Tensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            encoder_hidden_states (`ms.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            encoder_attention_mask (`ms.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            enable_temporal_attentions (`bool`, *optional*, defaults to `True`):
                Whether to enable temporal attentions.
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
"""# Reshape hidden statesbatch_size,channels,num_frame,height,width=hidden_states.shape# batch_size channels num_frame height width -> (batch_size * num_frame) channels height widthhidden_states=hidden_states.permute(0,2,1,3,4).reshape(-1,channels,height,width)# Inputheight,width=(hidden_states.shape[-2]//self.patch_size,hidden_states.shape[-1]//self.patch_size,)num_patches=height*widthhidden_states=self.pos_embed(hidden_states)# alrady add positional embeddingsadded_cond_kwargs={"resolution":None,"aspect_ratio":None}timestep,embedded_timestep=self.adaln_single(timestep,added_cond_kwargs=added_cond_kwargs,batch_size=batch_size,hidden_dtype=hidden_states.dtype)# Prepare text embeddings for spatial block# batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_sizeencoder_hidden_states=self.caption_projection(encoder_hidden_states)# 3 120 1152encoder_hidden_states_spatial=encoder_hidden_states.repeat_interleave(num_frame,dim=0).view(-1,encoder_hidden_states.shape[-2],encoder_hidden_states.shape[-1])# Prepare timesteps for spatial and temporal blocktimestep_spatial=timestep.repeat_interleave(num_frame,dim=0).view(-1,timestep.shape[-1])timestep_temp=timestep.repeat_interleave(num_patches,dim=0).view(-1,timestep.shape[-1])# Spatial and temporal transformer blocksfori,(spatial_block,temp_block)inenumerate(zip(self.transformer_blocks,self.temporal_transformer_blocks)):hidden_states=spatial_block(hidden_states,None,# attention_maskencoder_hidden_states_spatial,encoder_attention_mask,timestep_spatial,None,# cross_attention_kwargsNone,# class_labels)ifenable_temporal_attentions:# (batch_size * num_frame) num_tokens hidden_size -> (batch_size * num_tokens) num_frame hidden_sizehidden_states=hidden_states.reshape(batch_size,-1,hidden_states.shape[-2],hidden_states.shape[-1]).permute(0,2,1,3)hidden_states=hidden_states.reshape(-1,hidden_states.shape[-2],hidden_states.shape[-1])ifi==0andnum_frame>1:hidden_states=(hidden_states+self.temp_pos_embed).to(hidden_states.dtype)hidden_states=temp_block(hidden_states,None,# attention_maskNone,# encoder_hidden_statesNone,# encoder_attention_masktimestep_temp,None,# cross_attention_kwargsNone,# class_labels)# (batch_size * num_tokens) num_frame hidden_size -> (batch_size * num_frame) num_tokens hidden_sizehidden_states=hidden_states.reshape(batch_size,-1,hidden_states.shape[-2],hidden_states.shape[-1]).permute(0,2,1,3)hidden_states=hidden_states.reshape(-1,hidden_states.shape[-2],hidden_states.shape[-1])embedded_timestep=embedded_timestep.repeat_interleave(num_frame,dim=0).view(-1,embedded_timestep.shape[-1])shift,scale=(self.scale_shift_table[None]+embedded_timestep[:,None]).chunk(2,axis=1)hidden_states=self.norm_out(hidden_states)# Modulationhidden_states=hidden_states*(1+scale)+shifthidden_states=self.proj_out(hidden_states)# unpatchifyifself.adaln_singleisNone:height=width=int(hidden_states.shape[1]**0.5)hidden_states=hidden_states.reshape((-1,height,width,self.patch_size,self.patch_size,self.out_channels))hidden_states=hidden_states.transpose(0,5,1,3,2,4)output=hidden_states.reshape((-1,self.out_channels,height*self.patch_size,width*self.patch_size))output=output.reshape(batch_size,-1,output.shape[-3],output.shape[-2],output.shape[-1]).permute(0,2,1,3,4)ifnotreturn_dict:return(output,)returnTransformer2DModelOutput(sample=output)
| PARAMETER | DESCRIPTION | TYPE | DEFAULT |
| --- | --- | --- | --- |
| `encoder_attention_mask` | Cross-attention mask applied to `encoder_hidden_states`. Two formats are supported: a mask of shape `(batch, sequence_length)` (True = keep, False = discard) or a bias of shape `(batch, 1, sequence_length)` (0 = keep, -10000 = discard). If `ndim == 2`, the input is interpreted as a mask and converted into a bias consistent with the format above; this bias is added to the cross-attention scores (a conversion sketch follows this listing). | `ms.Tensor`, *optional* | `None` |
| `enable_temporal_attentions` | Whether to enable temporal attentions. | `bool`, *optional* | `True` |
| `return_dict` | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple. | `bool`, *optional* | `False` |

RETURNS: If `return_dict` is `True`, a [`~models.transformer_2d.Transformer2DModelOutput`] is returned; otherwise a `tuple` whose first element is the sample tensor.
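The `encoder_attention_mask` entry above says a 2-D keep/discard mask is converted into an additive bias before it reaches the attention scores. A minimal sketch of that convention follows; `mask_to_bias` is a hypothetical helper written here for illustration, not a function exposed by the library:

```python
import mindspore as ms

def mask_to_bias(mask: ms.Tensor) -> ms.Tensor:
    # Hypothetical helper: turn a (batch, seq_len) keep/discard mask into a
    # (batch, 1, seq_len) additive bias: 0 where a token is kept, -10000 where discarded.
    bias = (mask.astype(ms.float32) - 1.0) * 10000.0
    return bias.unsqueeze(1)

mask = ms.Tensor([[True, True, False]])
print(mask_to_bias(mask))  # 0 for the kept tokens, -10000 for the discarded one, shape (1, 1, 3)
```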
Source code in `mindone/diffusers/models/transformers/latte_transformer_3d.py`
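The inline comments in `construct` describe how the same activations are re-laid-out between the spatial blocks (one token sequence per frame) and the temporal blocks (one token sequence per patch location). The round trip is just a reshape/permute pair; the standalone sketch below, with assumed toy sizes, traces the shape bookkeeping:

```python
import mindspore as ms
from mindspore import ops

batch_size, num_frame, num_tokens, hidden_size = 2, 16, 64, 128

# Spatial layout: one sequence per frame -> (batch_size * num_frame, num_tokens, hidden_size)
spatial = ops.randn((batch_size * num_frame, num_tokens, hidden_size))

# (B * F, N, D) -> (B, F, N, D) -> (B, N, F, D) -> (B * N, F, D): one sequence per patch location
temporal = spatial.reshape(batch_size, -1, num_tokens, hidden_size).permute(0, 2, 1, 3)
temporal = temporal.reshape(-1, num_frame, hidden_size)
print(temporal.shape)  # (128, 16, 128) == (B * N, F, D)

# Inverse: (B * N, F, D) -> (B, N, F, D) -> (B, F, N, D) -> (B * F, N, D)
back = temporal.reshape(batch_size, -1, num_frame, hidden_size).permute(0, 2, 1, 3)
back = back.reshape(-1, num_tokens, hidden_size)
assert back.shape == spatial.shape
```

Only the batch dimension is factored and regrouped here, so each block simply attends over a differently grouped batch: the spatial blocks mix patches within a frame, while the temporal blocks mix the same patch position across frames.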