Skip to content

AutoencoderKLWan

The 3D variational autoencoder (VAE) model with KL loss used in Wan 2.1 by the Alibaba Wan Team.

The model can be loaded with the following code snippet.

# `mindspore` must be imported as `ms` for the `mindspore_dtype` argument below.
import mindspore as ms

from mindone.diffusers import AutoencoderKLWan

vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", mindspore_dtype=ms.float32)

mindone.diffusers.AutoencoderKLWan

Bases: ModelMixin, ConfigMixin, FromOriginalModelMixin

A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Introduced in [Wan 2.1].

This model inherits from [ModelMixin]. Check the superclass documentation for its generic methods implemented for all models (such as downloading or saving).

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_wan.py (lines 666–865):
class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
    Introduced in [Wan 2.1].

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Args:
        base_dim (`int`, defaults to 96):
            Base channel width forwarded to both `WanEncoder3d` and `WanDecoder3d`.
        z_dim (`int`, defaults to 16):
            Number of latent channels. The encoder emits `2 * z_dim` channels (posterior mean followed by
            log-variance); the decoder consumes `z_dim` channels.
        dim_mult (`List[int]`, defaults to `[1, 2, 4, 4]`):
            Per-stage channel multipliers forwarded to the encoder and decoder.
        num_res_blocks (`int`, defaults to 2):
            Residual blocks per stage, forwarded to the encoder and decoder.
        attn_scales (`List[float]`, defaults to `[]`):
            Forwarded unchanged to the encoder and decoder.
        temperal_downsample (`List[bool]`, defaults to `[False, True, True]`):
            Per-stage temporal-downsampling flags for the encoder; the decoder receives the reversed list as its
            temporal-upsampling flags.
        dropout (`float`, defaults to 0.0):
            Dropout rate forwarded to the encoder and decoder.
        latents_mean (`List[float]`), latents_std (`List[float]`):
            Only recorded in the config by `register_to_config`; not referenced by this class's code. Presumably
            per-channel latent normalization statistics consumed by pipelines — confirm against callers.
    """

    _supports_gradient_checkpointing = False

    # NOTE: the list defaults below are intentionally kept (rather than `None`) because `register_to_config`
    # records them as the serialized model config. They are not mutated in this class.
    @register_to_config
    def __init__(
        self,
        base_dim: int = 96,
        z_dim: int = 16,
        dim_mult: Tuple[int] = [1, 2, 4, 4],
        num_res_blocks: int = 2,
        attn_scales: List[float] = [],
        temperal_downsample: List[bool] = [False, True, True],
        dropout: float = 0.0,
        latents_mean: List[float] = [
            -0.7571,
            -0.7089,
            -0.9113,
            0.1075,
            -0.1745,
            0.9653,
            -0.1517,
            1.5508,
            0.4134,
            -0.0715,
            0.5517,
            -0.3632,
            -0.1922,
            -0.9497,
            0.2503,
            -0.2921,
        ],
        latents_std: List[float] = [
            2.8184,
            1.4541,
            2.3275,
            2.6558,
            1.2196,
            1.7708,
            2.6052,
            2.0743,
            3.2687,
            2.1526,
            2.8652,
            1.5579,
            1.6382,
            1.1253,
            2.8251,
            1.9160,
        ],
    ) -> None:
        super().__init__()

        self.z_dim = z_dim
        self.temperal_downsample = temperal_downsample
        # The decoder mirrors the encoder, so its temporal stages run in reverse order.
        self.temperal_upsample = temperal_downsample[::-1]

        self.encoder = WanEncoder3d(
            base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
        )
        self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)

        self.decoder = WanDecoder3d(
            base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
        )

        # Stateless helper: `encode` returns a raw moments tensor (a distribution class cannot be used in graph
        # mode), and sampling / taking the mode is delegated to this object.
        self.diag_gauss_dist = DiagonalGaussianDistribution()

    def clear_cache(self):
        """Reset the causal-conv feature caches of both the encoder and the decoder.

        One `None` slot is allocated per `WanCausalConv3d` found in each sub-network, and the running layer index
        (a one-element list, so sub-modules can advance it in place) is reset to 0.
        """

        def _count_conv3d(model):
            count = 0
            for _, m in model.cells_and_names():
                if isinstance(m, WanCausalConv3d):
                    count += 1
            return count

        self._conv_num = _count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # cache encode
        self._enc_conv_num = _count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num

    def _encode(self, x: ms.Tensor) -> ms.Tensor:
        """Encode `x` chunk-wise along dim 2 (time) and return the raw posterior moments.

        The first frame is encoded on its own, and the remaining frames are encoded in chunks of 4. The causal-conv
        feature cache is carried from one chunk to the next.

        Returns:
            `ms.Tensor` with `2 * z_dim` channels: posterior mean in the first `z_dim` channels, log-variance in the
            rest.

        Raises:
            ValueError: if `x` has no frames along dim 2.
        """
        num_frames = x.shape[2]
        if num_frames < 1:
            # Without this check the loop below never runs and `out` is referenced unbound.
            raise ValueError("AutoencoderKLWan._encode expects at least one frame along dim 2 of the input.")

        self.clear_cache()
        num_chunks = 1 + (num_frames - 1) // 4
        for i in range(num_chunks):
            # Each chunk walks the encoder's causal convs from the first one again.
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx,
                )
                out = mint.cat([out, out_], 2)

        # quant_conv maps 2*z_dim -> 2*z_dim channels: [mean | log-variance] along dim 1.
        enc = self.quant_conv(out)
        self.clear_cache()
        return enc

    def encode(
        self, x: ms.Tensor, return_dict: bool = False
    ) -> Union[AutoencoderKLOutput, Tuple[ms.Tensor]]:
        r"""
        Encode a batch of videos into latent posterior moments.

        Args:
            x (`ms.Tensor`): Input batch of videos; frames are indexed along dim 2.
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            The raw posterior moments tensor (mean and log-variance concatenated along the channel dimension), not a
            `DiagonalGaussianDistribution` instance. If `return_dict` is True, it is wrapped as the `latent_dist`
            field of an [`~models.autoencoder_kl.AutoencoderKLOutput`]; otherwise a one-element `tuple` is returned.
            Use `self.diag_gauss_dist.sample(...)` / `.mode(...)` to obtain latents.
        """
        h = self._encode(x)

        # we cannot use class in graph mode, even for jit_class or subclass of Tensor. :-(
        # posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (h,)
        return AutoencoderKLOutput(latent_dist=h)

    def _decode(self, z: ms.Tensor, return_dict: bool = False) -> Union[DecoderOutput, Tuple[ms.Tensor]]:
        """Decode latents frame by frame along dim 2, carrying the causal-conv cache between frames.

        The output is clamped to [-1, 1].

        Raises:
            ValueError: if `z` has no frames along dim 2.
        """
        num_frames = z.shape[2]
        if num_frames < 1:
            # Without this check the loop below never runs and `out` is referenced unbound.
            raise ValueError("AutoencoderKLWan._decode expects at least one latent frame along dim 2.")

        self.clear_cache()
        x = self.post_quant_conv(z)
        for i in range(num_frames):
            # Each latent frame walks the decoder's causal convs from the first one again.
            self._conv_idx = [0]
            frame_out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
            if i == 0:
                out = frame_out
            else:
                out = mint.cat([out, frame_out], 2)

        out = mint.clamp(out, min=-1.0, max=1.0)
        self.clear_cache()
        if not return_dict:
            return (out,)

        return DecoderOutput(sample=out)

    def decode(self, z: ms.Tensor, return_dict: bool = False) -> Union[DecoderOutput, Tuple[ms.Tensor]]:
        r"""
        Decode a batch of latent videos.

        Args:
            z (`ms.Tensor`): Input batch of latents with `z_dim` channels; frames are indexed along dim 2.
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a one-element `tuple`
                holding the decoded tensor (clamped to [-1, 1]) is returned.
        """
        decoded = self._decode(z)[0]
        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def construct(
        self,
        sample: ms.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = False,
        generator: Optional[np.random.Generator] = None,
    ) -> Union[DecoderOutput, Tuple[ms.Tensor]]:
        """
        Encode `sample`, pick latents from the posterior, and decode them back.

        Args:
            sample (`ms.Tensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                If True, draw latents from the posterior with `generator`; otherwise use the posterior mode.
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
            generator (`np.random.Generator`, *optional*):
                Random generator used only when `sample_posterior` is True.
        """
        x = sample
        posterior = self.encode(x)[0]
        if sample_posterior:
            z = self.diag_gauss_dist.sample(posterior, generator=generator)
        else:
            z = self.diag_gauss_dist.mode(posterior)
        dec = self.decode(z, return_dict=return_dict)
        return dec

mindone.diffusers.AutoencoderKLWan.construct(sample, sample_posterior=False, return_dict=False, generator=None)

PARAMETER DESCRIPTION
sample

Input sample.

TYPE: `ms.Tensor`

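sample_posterior

If True, draw latents from the posterior using generator; otherwise use the posterior mode.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

return_dict

Whether or not to return a [DecoderOutput] instead of a plain tuple.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

generator

Random generator used only when sample_posterior is True.

TYPE: `np.random.Generator`, *optional* DEFAULT: None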

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_wan.py (lines 845–865):
def construct(
    self,
    sample: ms.Tensor,
    sample_posterior: bool = False,
    return_dict: bool = False,
    generator: Optional[np.random.Generator] = None,
) -> Union[DecoderOutput, Tuple[ms.Tensor]]:
    """
    Encode `sample`, pick latents from the posterior, and decode them back.

    Args:
        sample (`ms.Tensor`): Input sample.
        sample_posterior (`bool`, *optional*, defaults to `False`):
            If True, draw latents from the posterior with `generator`; otherwise use the posterior mode.
        return_dict (`bool`, *optional*, defaults to `False`):
            Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        generator (`np.random.Generator`, *optional*):
            Random generator used only when `sample_posterior` is True.
    """
    x = sample
    posterior = self.encode(x)[0]
    if sample_posterior:
        z = self.diag_gauss_dist.sample(posterior, generator=generator)
    else:
        z = self.diag_gauss_dist.mode(posterior)
    dec = self.decode(z, return_dict=return_dict)
    return dec

mindone.diffusers.AutoencoderKLWan.decode(z, return_dict=False)

Decode a batch of latent videos.

PARAMETER DESCRIPTION
z

Input batch of latent vectors.

TYPE: `ms.Tensor`

return_dict

Whether to return a [~models.vae.DecoderOutput] instead of a plain tuple.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

RETURNS DESCRIPTION
Union[DecoderOutput, Tensor]

[~models.vae.DecoderOutput] or tuple: If return_dict is True, a [~models.vae.DecoderOutput] is returned, otherwise a plain tuple is returned.

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_wan.py (lines 825–843):
def decode(self, z: ms.Tensor, return_dict: bool = False) -> Union[DecoderOutput, Tuple[ms.Tensor]]:
    r"""
    Decode a batch of latent videos.

    Args:
        z (`ms.Tensor`): Input batch of latents with `z_dim` channels; frames are indexed along dim 2.
        return_dict (`bool`, *optional*, defaults to `False`):
            Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

    Returns:
        [`~models.vae.DecoderOutput`] or `tuple`:
            If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a one-element `tuple`
            holding the decoded tensor (clamped to [-1, 1]) is returned.
    """
    decoded = self._decode(z)[0]
    if not return_dict:
        return (decoded,)

    return DecoderOutput(sample=decoded)

mindone.diffusers.AutoencoderKLWan.encode(x, return_dict=False)

Encode a batch of videos into latent posterior moments.

PARAMETER DESCRIPTION
x

Input batch of images.

TYPE: `ms.Tensor`

return_dict

Whether to return a [~models.autoencoder_kl.AutoencoderKLOutput] instead of a plain tuple.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

RETURNS DESCRIPTION
Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]

The latent representations of the encoded videos. If return_dict is True, a [~models.autoencoder_kl.AutoencoderKLOutput] is returned, otherwise a plain tuple is returned. Note that the returned value (and `latent_dist`) is the raw posterior moments tensor — mean and log-variance concatenated along the channel dimension — not a DiagonalGaussianDistribution instance.

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_wan.py (lines 781–803):
def encode(
    self, x: ms.Tensor, return_dict: bool = False
) -> Union[AutoencoderKLOutput, Tuple[ms.Tensor]]:
    r"""
    Encode a batch of videos into latent posterior moments.

    Args:
        x (`ms.Tensor`): Input batch of videos; frames are indexed along dim 2.
        return_dict (`bool`, *optional*, defaults to `False`):
            Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

    Returns:
        The raw posterior moments tensor (mean and log-variance concatenated along the channel dimension), not a
        `DiagonalGaussianDistribution` instance. If `return_dict` is True, it is wrapped as the `latent_dist`
        field of an [`~models.autoencoder_kl.AutoencoderKLOutput`]; otherwise a one-element `tuple` is returned.
        Use `self.diag_gauss_dist.sample(...)` / `.mode(...)` to obtain latents.
    """
    h = self._encode(x)

    # we cannot use class in graph mode, even for jit_class or subclass of Tensor. :-(
    # posterior = DiagonalGaussianDistribution(h)

    if not return_dict:
        return (h,)
    return AutoencoderKLOutput(latent_dist=h)

mindone.diffusers.models.autoencoders.vae.DecoderOutput dataclass

Bases: BaseOutput

Output of decoding method.

PARAMETER DESCRIPTION
sample

The decoded output sample from the last layer of the model.

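TYPE: `ms.Tensor` of shape `(batch_size, num_channels, height, width)` for image VAEs, or `(batch_size, num_channels, num_frames, height, width)` for video VAEs such as AutoencoderKLWan

commit_loss

Optional auxiliary loss tensor; defaults to None and is not set by AutoencoderKLWan.

TYPE: `ms.Tensor`, *optional* DEFAULT: None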

Source code in mindone/diffusers/models/autoencoders/vae.py (lines 43–54):
@dataclass
class DecoderOutput(BaseOutput):
    r"""
    Output of decoding method.

    Args:
        sample (`ms.Tensor`):
            The decoded output sample from the last layer of the model. Image VAEs produce a tensor of shape
            `(batch_size, num_channels, height, width)`; video VAEs such as `AutoencoderKLWan` produce
            `(batch_size, num_channels, num_frames, height, width)`.
        commit_loss (`ms.Tensor`, *optional*, defaults to `None`):
            Optional auxiliary loss tensor attached by the decoder, left as `None` when the model does not compute
            one. Presumably a quantization commitment loss for VQ-style decoders — confirm against producers.
    """

    sample: ms.Tensor
    commit_loss: Optional[ms.Tensor] = None