```python
class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    """
    A 3D Transformer model for video-like data.

    Args:
        patch_size (`int`, defaults to `2`):
            The size of spatial patches to use in the patch embedding layer.
        patch_size_t (`int`, defaults to `1`):
            The size of temporal patches to use in the patch embedding layer.
        num_attention_heads (`int`, defaults to `24`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `96`):
            The number of channels in each head.
        in_channels (`int`, defaults to `4`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `4`):
            The number of channels in the output.
        num_layers (`int`, defaults to `32`):
            The number of layers of Transformer blocks to use.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        cross_attention_dim (`int`, defaults to `2304`):
            The dimension of the cross attention features.
        attention_bias (`bool`, defaults to `True`):
            Whether or not to use bias in the attention projection layers.
        sample_height (`int`, defaults to `90`):
            The height of the input latents.
        sample_width (`int`, defaults to `160`):
            The width of the input latents.
        sample_frames (`int`, defaults to `22`):
            The number of frames in the input latents.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        norm_elementwise_affine (`bool`, defaults to `False`):
            Whether or not to use elementwise affine in normalization layers.
        norm_eps (`float`, defaults to `1e-6`):
            The epsilon value to use in normalization layers.
        caption_channels (`int`, defaults to `4096`):
            Number of channels to use for projecting the caption embeddings.
        interpolation_scale_h (`float`, defaults to `2.0`):
            Scaling factor to apply in 3D positional embeddings across height dimension.
        interpolation_scale_w (`float`, defaults to `2.0`):
            Scaling factor to apply in 3D positional embeddings across width dimension.
        interpolation_scale_t (`float`, defaults to `2.2`):
            Scaling factor to apply in 3D positional embeddings across time dimension.
    """

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        patch_size_t: int = 1,
        num_attention_heads: int = 24,
        attention_head_dim: int = 96,
        in_channels: int = 4,
        out_channels: int = 4,
        num_layers: int = 32,
        dropout: float = 0.0,
        cross_attention_dim: int = 2304,
        attention_bias: bool = True,
        sample_height: int = 90,
        sample_width: int = 160,
        sample_frames: int = 22,
        activation_fn: str = "gelu-approximate",
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        caption_channels: int = 4096,
        interpolation_scale_h: float = 2.0,
        interpolation_scale_w: float = 2.0,
        interpolation_scale_t: float = 2.2,
    ):
        super().__init__()
        self.inner_dim = num_attention_heads * attention_head_dim

        interpolation_scale_t = (
            interpolation_scale_t
            if interpolation_scale_t is not None
            else ((sample_frames - 1) // 16 + 1) if sample_frames % 2 == 1 else sample_frames // 16
        )
        interpolation_scale_h = interpolation_scale_h if interpolation_scale_h is not None else sample_height / 30
        interpolation_scale_w = interpolation_scale_w if interpolation_scale_w is not None else sample_width / 40

        # 1. Patch embedding
        self.pos_embed = PatchEmbed(
            height=sample_height,
            width=sample_width,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=self.inner_dim,
            pos_embed_type=None,
        )

        # 2. Transformer blocks
        self.transformer_blocks = nn.CellList(
            [
                AllegroTransformerBlock(
                    self.inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        # 3. Output projection & norm
        self.norm_out = LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.scale_shift_table = ms.Parameter(
            ops.randn(2, self.inner_dim) / self.inner_dim**0.5, name="scale_shift_table"
        )
        self.proj_out = nn.Dense(self.inner_dim, patch_size * patch_size * out_channels)

        # 4. Timestep embeddings
        self.adaln_single = AdaLayerNormSingle(self.inner_dim, use_additional_conditions=False)

        # 5. Caption projection
        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=self.inner_dim)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

    def construct(
        self,
        hidden_states: ms.Tensor,
        encoder_hidden_states: ms.Tensor,
        timestep: ms.Tensor,
        attention_mask: Optional[ms.Tensor] = None,
        encoder_attention_mask: Optional[ms.Tensor] = None,
        image_rotary_emb: Optional[Tuple[ms.Tensor, ms.Tensor]] = None,
        return_dict: bool = False,
    ):
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t = self.config["patch_size_t"]
        p = self.config["patch_size"]

        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p
        post_patch_width = width // p

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch, 1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        attention_mask_vid, attention_mask_img = None, None
        if attention_mask is not None and attention_mask.ndim == 4:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0,     discard = -10000.0)
            # b, frame+use_image_num, h, w -> a video with images
            # b, 1, h, w -> only images
            attention_mask = attention_mask.to(hidden_states.dtype)
            attention_mask = attention_mask[:, :num_frames]  # [batch_size, num_frames, height, width]

            if attention_mask.numel() > 0:
                attention_mask = attention_mask.unsqueeze(1)  # [batch_size, 1, num_frames, height, width]
                attention_mask = ops.max_pool3d(attention_mask, kernel_size=(p_t, p, p), stride=(p_t, p, p))
                attention_mask = attention_mask.flatten(start_dim=1).view(batch_size, 1, -1)

            attention_mask = (
                (1 - attention_mask.bool().to(hidden_states.dtype)) * -10000.0 if attention_mask.numel() > 0 else None
            )

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Timestep embeddings
        timestep, embedded_timestep = self.adaln_single(
            timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )

        # 2. Patch embeddings
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(start_dim=0, end_dim=1)
        hidden_states = self.pos_embed(hidden_states)
        hidden_states = hidden_states.reshape(
            hidden_states.shape[:0] + (batch_size, -1) + hidden_states.shape[1:]
        ).flatten(start_dim=1, end_dim=2)

        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1])

        # 3. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=timestep,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                image_rotary_emb=image_rotary_emb,
            )

        # 4. Output normalization & projection
        shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, axis=1)
        hidden_states = self.norm_out(hidden_states)

        # Modulation
        hidden_states = hidden_states * (1 + scale) + shift
        hidden_states = self.proj_out(hidden_states)
        # If input is of shape (A×1×B), squeeze(input, 0) leaves the tensor unchanged. This function is not supported in MS.
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 5. Unpatchify
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.reshape(batch_size, -1, num_frames, height, width)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
```
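The sketch below instantiates a deliberately tiny configuration of the model and runs one forward pass with random latents, just to illustrate the expected tensor layouts: `(batch, in_channels, frames, height, width)` for the latent video and `(batch, text_tokens, caption_channels)` for the text embeddings. This is a minimal sketch, not an official example; the import path, the toy shapes, and the dummy inputs are assumptions, and a real checkpoint would use the defaults documented above (24 heads, 32 layers, and so on) together with precomputed rotary embeddings.

```python
# Minimal usage sketch (assumed import path and toy shapes; not an official example).
import mindspore as ms
from mindspore import ops

from mindone.diffusers.models.transformers.transformer_allegro import AllegroTransformer3DModel  # assumed path

model = AllegroTransformer3DModel(
    num_attention_heads=2,
    attention_head_dim=8,    # inner_dim = 2 * 8 = 16
    num_layers=1,
    cross_attention_dim=16,  # kept equal to inner_dim, mirroring the default 24 * 96 = 2304
    caption_channels=32,     # projected down to inner_dim by PixArtAlphaTextProjection
    sample_height=8,
    sample_width=8,
    sample_frames=2,
)

hidden_states = ops.randn(1, 4, 2, 8, 8)      # (batch, in_channels, frames, height, width)
encoder_hidden_states = ops.randn(1, 16, 32)  # (batch, text_tokens, caption_channels)
timestep = ms.Tensor([999], dtype=ms.int64)   # one diffusion timestep per batch element

# return_dict defaults to False, so construct returns a 1-tuple.
(sample,) = model(hidden_states, encoder_hidden_states, timestep)
print(sample.shape)  # (1, 4, 2, 8, 8): same layout as the input latents, with out_channels channels
```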
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability distributions for the unnoised latent pixels.

TYPE: `torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete
Source code in mindone/diffusers/models/modeling_outputs.py
```python
@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    The output of [`Transformer2DModel`].

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or
        `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    sample: "ms.Tensor"  # noqa: F821
```
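For reference, the sketch below shows how this output container relates to the tuple return path of `construct` above. It reuses the `model`, `hidden_states`, `encoder_hidden_states`, and `timestep` objects from the earlier usage sketch, so the same assumptions apply.

```python
# Tuple path (the default, return_dict=False): a 1-tuple containing the sample tensor.
(sample,) = model(hidden_states, encoder_hidden_states, timestep)

# Dataclass path: return_dict=True wraps the sample in a Transformer2DModelOutput.
output = model(hidden_states, encoder_hidden_states, timestep, return_dict=True)
print(output.sample.shape)  # (batch_size, out_channels, num_frames, height, width)
```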