Skip to content

Limitations

Due to differences in framework, some APIs & models will not be identical to huggingface/diffusers in the foreseeable future.

APIs

xxx.from_pretrained

  • torch_dtype is renamed to mindspore_dtype
  • device_map, max_memory, offload_folder, offload_state_dict, low_cpu_mem_usage will not be supported.

BaseOutput

  • Default value of return_dict is changed to False, for GRAPH_MODE does not allow to construct an instance of it.

Output of AutoencoderKL.encode

Unlike the output posterior = DiagonalGaussianDistribution(latent), which can do sampling by posterior.sample(). We can only output the latent and then do sampling through AutoencoderKL.diag_gauss_dist.sample(latent).

Models

The table below represents the current support in mindone/diffusers for each of those modules, whether they have support in Pynative fp16 mode, Graph fp16 mode, Pynative fp32 mode or Graph fp32 mode.

Names Pynative FP16 Pynative FP32 Graph FP16 Graph FP32 Description
StableCascadeUNet โŒ โœ… โŒ โœ… huggingface/diffusers output NaN when using float16.
nn.Conv3d โœ… โŒ โœ… โŒ FP32 is not supported on Ascend
TemporalConvLayer โœ… โŒ โœ… โŒ contains nn.Conv3d
TemporalResnetBlock โœ… โŒ โœ… โŒ contains nn.Conv3d
SpatioTemporalResBlock โœ… โŒ โœ… โŒ contains TemporalResnetBlock
UNetMidBlock3DCrossAttn โœ… โŒ โœ… โŒ contains TemporalConvLayer
CrossAttnDownBlock3D โœ… โŒ โœ… โŒ contains TemporalConvLayer
DownBlock3D โœ… โŒ โœ… โŒ contains TemporalConvLayer
CrossAttnUpBlock3D โœ… โŒ โœ… โŒ contains TemporalConvLayer
UpBlock3D โœ… โŒ โœ… โŒ contains TemporalConvLayer
MidBlockTemporalDecoder โœ… โŒ โœ… โŒ contains SpatioTemporalResBlock
UpBlockTemporalDecoder โœ… โŒ โœ… โŒ contains SpatioTemporalResBlock
UNetMidBlockSpatioTemporal โœ… โŒ โœ… โŒ contains SpatioTemporalResBlock
DownBlockSpatioTemporal โœ… โŒ โœ… โŒ contains SpatioTemporalResBlock
CrossAttnDownBlockSpatioTemporal โœ… โŒ โœ… โŒ contains SpatioTemporalResBlock
UpBlockSpatioTemporal โœ… โŒ โœ… โŒ contains SpatioTemporalResBlock
CrossAttnUpBlockSpatioTemporal โœ… โŒ โœ… โŒ contains SpatioTemporalResBlock
TemporalDecoder โœ… โŒ โœ… โŒ contains nn.Conv3d, MidBlockTemporalDecoder etc.
UNet3DConditionModel โœ… โŒ โœ… โŒ contains UNetMidBlock3DCrossAttn etc.
I2VGenXLUNet โœ… โŒ โœ… โŒ contains UNetMidBlock3DCrossAttn etc.
AutoencoderKLTemporalDecoder โœ… โŒ โœ… โŒ contains MidBlockTemporalDecoder etc.
UNetSpatioTemporalConditionModel โœ… โŒ โœ… โŒ contains UNetMidBlockSpatioTemporal etc.
FirUpsample2D โŒ โœ… โœ… โœ… ops.Conv2D has poor precision in fp16 and PyNative mode
FirDownsample2D โŒ โœ… โœ… โœ… ops.Conv2D has poor precision in fp16 and PyNative mode
AttnSkipUpBlock2D โŒ โœ… โœ… โœ… contains FirUpsample2D
SkipUpBlock2D โŒ โœ… โœ… โœ… contains FirUpsample2D
AttnSkipDownBlock2D โŒ โœ… โœ… โœ… contains FirDownsample2D
SkipDownBlock2D โŒ โœ… โœ… โœ… contains FirDownsample2D
ResnetBlock2D (kernel='fir') โŒ โœ… โœ… โœ… ops.Conv2D has poor precision in fp16 and PyNative mode