AutoencoderKLMagvit

The 3D variational autoencoder (VAE) model with KL loss used in EasyAnimate was introduced by Alibaba PAI.

The model can be loaded with the following code snippet.

import mindspore as ms

from mindone.diffusers import AutoencoderKLMagvit

vae = AutoencoderKLMagvit.from_pretrained("alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="vae", mindspore_dtype=ms.float16)
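
Continuing from the snippet above, a minimal encode/decode round trip might look like the following sketch; the input shape, frame count, and pixel range [-1, 1] are illustrative assumptions rather than requirements of the checkpoint.

import numpy as np

# Toy video batch with shape (batch, channels, frames, height, width); values assumed normalized to [-1, 1].
video = ms.Tensor(np.random.uniform(-1, 1, (1, 3, 9, 256, 256)), ms.float16)

moments = vae.encode(video)[0]                 # concatenated mean and logvar of the latent posterior
latents = vae.diag_gauss_dist.sample(moments)  # draw a stochastic latent sample
reconstruction = vae.decode(latents)[0]        # back to pixel space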

mindone.diffusers.AutoencoderKLMagvit

Bases: ModelMixin, ConfigMixin

A VAE model with KL loss for encoding images into latents and decoding latent representations into images. This model is used in EasyAnimate.

This model inherits from [ModelMixin]. Check the superclass documentation for its generic methods implemented for all models (such as downloading or saving).

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_magvit.py
class AutoencoderKLMagvit(ModelMixin, ConfigMixin):
    r"""
    A VAE model with KL loss for encoding images into latents and decoding latent representations into images. This
    model is used in [EasyAnimate](https://arxiv.org/abs/2405.18991).

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        latent_channels: int = 16,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = [128, 256, 512, 512],
        down_block_types: Tuple[str, ...] = [
            "SpatialDownBlock3D",
            "SpatialTemporalDownBlock3D",
            "SpatialTemporalDownBlock3D",
            "SpatialTemporalDownBlock3D",
        ],
        up_block_types: Tuple[str, ...] = [
            "SpatialUpBlock3D",
            "SpatialTemporalUpBlock3D",
            "SpatialTemporalUpBlock3D",
            "SpatialTemporalUpBlock3D",
        ],
        layers_per_block: int = 2,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        scaling_factor: float = 0.7125,
        spatial_group_norm: bool = True,
    ):
        super().__init__()

        # Initialize the encoder
        self.encoder = EasyAnimateEncoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            double_z=True,
            spatial_group_norm=spatial_group_norm,
        )

        # Initialize the decoder
        self.decoder = EasyAnimateDecoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            spatial_group_norm=spatial_group_norm,
        )

        # Initialize convolution layers for quantization and post-quantization
        self.quant_conv = mint.nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
        self.post_quant_conv = mint.nn.Conv3d(latent_channels, latent_channels, kernel_size=1)

        self.diag_gauss_dist = DiagonalGaussianDistribution()

        self.spatial_compression_ratio = 2 ** (len(block_out_channels) - 1)
        self.temporal_compression_ratio = 2 ** (len(block_out_channels) - 2)

        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
        # to perform decoding of a single video latent at a time.
        self.use_slicing = False

        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
        # intermediate tiles together, the memory requirement can be lowered.
        self.use_tiling = False

        # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
        # at a fixed frame batch size (based on `self.num_latent_frames_batch_size`), the memory requirement can be lowered.
        self.use_framewise_encoding = False
        self.use_framewise_decoding = False

        # Assign mini-batch sizes for encoder and decoder
        self.num_sample_frames_batch_size = 4
        self.num_latent_frames_batch_size = 1

        # The minimal tile height and width for spatial tiling to be used
        self.tile_sample_min_height = 512
        self.tile_sample_min_width = 512
        self.tile_sample_min_num_frames = 4

        # The minimal distance between two spatial tiles
        self.tile_sample_stride_height = 448
        self.tile_sample_stride_width = 448
        self.tile_sample_stride_num_frames = 8

    def _clear_conv_cache(self):
        # Clear cache for convolutional layers if needed
        for name, module in self.name_cells().items():
            if isinstance(module, EasyAnimateCausalConv3d):
                module._clear_conv_cache()
            if isinstance(module, EasyAnimateUpsampler3D):
                module._clear_conv_cache()

    def enable_tiling(
        self,
        tile_sample_min_height: Optional[int] = None,
        tile_sample_min_width: Optional[int] = None,
        tile_sample_min_num_frames: Optional[int] = None,
        tile_sample_stride_height: Optional[float] = None,
        tile_sample_stride_width: Optional[float] = None,
        tile_sample_stride_num_frames: Optional[float] = None,
    ) -> None:
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.

        Args:
            tile_sample_min_height (`int`, *optional*):
                The minimum height required for a sample to be separated into tiles across the height dimension.
            tile_sample_min_width (`int`, *optional*):
                The minimum width required for a sample to be separated into tiles across the width dimension.
            tile_sample_stride_height (`int`, *optional*):
                The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
                no tiling artifacts produced across the height dimension.
            tile_sample_stride_width (`int`, *optional*):
                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
                artifacts produced across the width dimension.
        """
        self.use_tiling = True
        self.use_framewise_decoding = True
        self.use_framewise_encoding = True
        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
        self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
        self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames

    def disable_tiling(self) -> None:
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_tiling = False

    def enable_slicing(self) -> None:
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self) -> None:
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    def _encode(
        self, x: ms.Tensor, return_dict: bool = False
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of images into latents.

        Args:
            x (`ms.Tensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
                The latent representations of the encoded images. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_height or x.shape[-2] > self.tile_sample_min_width):
            return self.tiled_encode(x, return_dict=return_dict)

        first_frames = self.encoder(x[:, :, :1, :, :])
        h = [first_frames]
        for i in range(1, x.shape[2], self.num_sample_frames_batch_size):
            next_frames = self.encoder(x[:, :, i : i + self.num_sample_frames_batch_size, :, :])
            h.append(next_frames)
        h = mint.cat(h, dim=2)
        moments = self.quant_conv(h)

        self._clear_conv_cache()
        return moments

    def encode(
        self, x: ms.Tensor, return_dict: bool = False
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of images into latents.

        Args:
            x (`ms.Tensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
                The latent representations of the encoded videos. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
            h = mint.cat(encoded_slices)
        else:
            h = self._encode(x)

        # we cannot use class in graph mode, even for jit_class or subclass of Tensor. :-(
        # posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (h,)
        # return AutoencoderKLOutput(latent_dist=h)
        return AutoencoderKLOutput(latent=h)

    def _decode(self, z: ms.Tensor, return_dict: bool = False) -> Union[DecoderOutput, ms.Tensor]:
        batch_size, num_channels, num_frames, height, width = z.shape
        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio

        if self.use_tiling and (z.shape[-1] > tile_latent_min_height or z.shape[-2] > tile_latent_min_width):
            return self.tiled_decode(z, return_dict=return_dict)

        z = self.post_quant_conv(z)

        # Process the first frame and save the result
        first_frames = self.decoder(z[:, :, :1, :, :])
        # Initialize the list to store the processed frames, starting with the first frame
        dec = [first_frames]
        # Process the remaining frames, with the number of frames processed at a time determined by mini_batch_decoder
        for i in range(1, z.shape[2], self.num_latent_frames_batch_size):
            next_frames = self.decoder(z[:, :, i : i + self.num_latent_frames_batch_size, :, :])
            dec.append(next_frames)
        # Concatenate all processed frames along the frame (temporal) dimension
        dec = mint.cat(dec, dim=2)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def decode(self, z: ms.Tensor, return_dict: bool = False) -> Union[DecoderOutput, ms.Tensor]:
        """
        Decode a batch of images.

        Args:
            z (`ms.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice)[0] for z_slice in z.split(1)]
            decoded = mint.cat(decoded_slices)
        else:
            decoded = self._decode(z)[0]

        self._clear_conv_cache()
        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    def blend_v(self, a: ms.Tensor, b: ms.Tensor, blend_extent: int) -> ms.Tensor:
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
                y / blend_extent
            )
        return b

    def blend_h(self, a: ms.Tensor, b: ms.Tensor, blend_extent: int) -> ms.Tensor:
        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
                x / blend_extent
            )
        return b

    def tiled_encode(self, x: ms.Tensor, return_dict: bool = False) -> AutoencoderKLOutput:
        batch_size, num_channels, num_frames, height, width = x.shape
        latent_height = height // self.spatial_compression_ratio
        latent_width = width // self.spatial_compression_ratio

        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

        blend_height = tile_latent_min_height - tile_latent_stride_height
        blend_width = tile_latent_min_width - tile_latent_stride_width

        # Split the image into 512x512 tiles and encode them separately.
        rows = []
        for i in range(0, height, self.tile_sample_stride_height):
            row = []
            for j in range(0, width, self.tile_sample_stride_width):
                tile = x[
                    :,
                    :,
                    :,
                    i : i + self.tile_sample_min_height,
                    j : j + self.tile_sample_min_width,
                ]

                first_frames = self.encoder(tile[:, :, 0:1, :, :])
                tile_h = [first_frames]
                for k in range(1, num_frames, self.num_sample_frames_batch_size):
                    next_frames = self.encoder(tile[:, :, k : k + self.num_sample_frames_batch_size, :, :])
                    tile_h.append(next_frames)
                tile = mint.cat(tile_h, dim=2)
                tile = self.quant_conv(tile)
                self._clear_conv_cache()
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                result_row.append(tile[:, :, :, :latent_height, :latent_width])
            result_rows.append(mint.cat(result_row, dim=4))

        moments = mint.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
        return moments

    def tiled_decode(self, z: ms.Tensor, return_dict: bool = False) -> Union[DecoderOutput, ms.Tensor]:
        batch_size, num_channels, num_frames, height, width = z.shape
        sample_height = height * self.spatial_compression_ratio
        sample_width = width * self.spatial_compression_ratio

        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width

        # Split z into overlapping 64x64 tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, tile_latent_stride_height):
            row = []
            for j in range(0, width, tile_latent_stride_width):
                tile = z[
                    :,
                    :,
                    :,
                    i : i + tile_latent_min_height,
                    j : j + tile_latent_min_width,
                ]
                tile = self.post_quant_conv(tile)

                # Process the first frame and save the result
                first_frames = self.decoder(tile[:, :, :1, :, :])
                # Initialize the list to store the processed frames, starting with the first frame
                tile_dec = [first_frames]
                # Process the remaining frames, with the number of frames processed at a time determined by mini_batch_decoder
                for k in range(1, num_frames, self.num_latent_frames_batch_size):
                    next_frames = self.decoder(tile[:, :, k : k + self.num_latent_frames_batch_size, :, :])
                    tile_dec.append(next_frames)
                # Concatenate all processed frames along the frame (temporal) dimension
                decoded = mint.cat(tile_dec, dim=2)
                self._clear_conv_cache()
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
            result_rows.append(mint.cat(result_row, dim=4))

        dec = mint.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def construct(
        self,
        sample: ms.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = False,
        generator: Optional[np.random.Generator] = None,
    ) -> Union[DecoderOutput, ms.Tensor]:
        r"""
        Args:
            sample (`ms.Tensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        x = sample
        posterior = self.encode(x)[0]
        if sample_posterior:
            z = self.diag_gauss_dist.sample(posterior, generator=generator)
        else:
            z = self.diag_gauss_dist.mode(posterior)
        dec = self.decode(z)[0]

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

mindone.diffusers.AutoencoderKLMagvit.construct(sample, sample_posterior=False, return_dict=False, generator=None)

PARAMETER DESCRIPTION
sample

Input sample.

TYPE: `ms.Tensor`

sample_posterior

Whether to sample from the posterior.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

return_dict

Whether or not to return a [DecoderOutput] instead of a plain tuple.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_magvit.py
def construct(
    self,
    sample: ms.Tensor,
    sample_posterior: bool = False,
    return_dict: bool = False,
    generator: Optional[np.random.Generator] = None,
) -> Union[DecoderOutput, ms.Tensor]:
    r"""
    Args:
        sample (`ms.Tensor`): Input sample.
        sample_posterior (`bool`, *optional*, defaults to `False`):
            Whether to sample from the posterior.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
    """
    x = sample
    posterior = self.encode(x)[0]
    if sample_posterior:
        z = self.diag_gauss_dist.sample(posterior, generator=generator)
    else:
        z = self.diag_gauss_dist.mode(posterior)
    dec = self.decode(z)[0]

    if not return_dict:
        return (dec,)

    return DecoderOutput(sample=dec)
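
Since the model is a MindSpore Cell, calling the instance directly runs `construct`, which encodes and immediately decodes a sample in one pass. A minimal sketch, where `video` is a placeholder pixel-space tensor of shape (batch, channels, frames, height, width):

# `video` is a placeholder; sample_posterior=True draws a latent sample instead of using the posterior mean.
reconstruction = vae(video, sample_posterior=True)[0]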

mindone.diffusers.AutoencoderKLMagvit.decode(z, return_dict=False)

Decode a batch of images.

PARAMETER DESCRIPTION
z

Input batch of latent vectors.

TYPE: `ms.Tensor`

return_dict

Whether to return a [~models.vae.DecoderOutput] instead of a plain tuple.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

RETURNS DESCRIPTION
Union[DecoderOutput, Tensor]

[~models.vae.DecoderOutput] or tuple: If return_dict is True, a [~models.vae.DecoderOutput] is returned, otherwise a plain tuple is returned.

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_magvit.py
def decode(self, z: ms.Tensor, return_dict: bool = False) -> Union[DecoderOutput, ms.Tensor]:
    """
    Decode a batch of images.

    Args:
        z (`ms.Tensor`): Input batch of latent vectors.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

    Returns:
        [`~models.vae.DecoderOutput`] or `tuple`:
            If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
            returned.
    """
    if self.use_slicing and z.shape[0] > 1:
        decoded_slices = [self._decode(z_slice)[0] for z_slice in z.split(1)]
        decoded = mint.cat(decoded_slices)
    else:
        decoded = self._decode(z)[0]

    self._clear_conv_cache()
    if not return_dict:
        return (decoded,)
    return DecoderOutput(sample=decoded)
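
When decoding latents produced by a diffusion pipeline, the latents are usually rescaled by the VAE's configured `scaling_factor` first; this mirrors the convention of other diffusers VAEs and is an assumption here, so check it against the pipeline you use. A short sketch with a placeholder `latents` tensor:

# `latents` is a placeholder of shape (batch, latent_channels, latent_frames, latent_height, latent_width).
latents = latents / vae.config.scaling_factor  # assumed scaling convention
frames = vae.decode(latents)[0]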

mindone.diffusers.AutoencoderKLMagvit.disable_slicing()

Disable sliced VAE decoding. If enable_slicing was previously enabled, this method will go back to computing decoding in one step.

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_magvit.py
def disable_slicing(self) -> None:
    r"""
    Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
    decoding in one step.
    """
    self.use_slicing = False

mindone.diffusers.AutoencoderKLMagvit.disable_tiling()

Disable tiled VAE decoding. If enable_tiling was previously enabled, this method will go back to computing decoding in one step.

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_magvit.py
def disable_tiling(self) -> None:
    r"""
    Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
    decoding in one step.
    """
    self.use_tiling = False

mindone.diffusers.AutoencoderKLMagvit.enable_slicing()

Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_magvit.py
def enable_slicing(self) -> None:
    r"""
    Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
    compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
    """
    self.use_slicing = True
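
Slicing changes only how a batch is traversed, not the result: each sample in the batch is decoded separately and the outputs are concatenated. A short sketch, where `latent_batch` is a placeholder latent tensor with batch size greater than 1:

vae.enable_slicing()                   # process one sample at a time along the batch axis
decoded = vae.decode(latent_batch)[0]  # `latent_batch` is a placeholder with batch size > 1
vae.disable_slicing()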

mindone.diffusers.AutoencoderKLMagvit.enable_tiling(tile_sample_min_height=None, tile_sample_min_width=None, tile_sample_min_num_frames=None, tile_sample_stride_height=None, tile_sample_stride_width=None, tile_sample_stride_num_frames=None)

Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images.

PARAMETER DESCRIPTION
tile_sample_min_height

The minimum height required for a sample to be separated into tiles across the height dimension.

TYPE: `int`, *optional* DEFAULT: None

tile_sample_min_width

The minimum width required for a sample to be separated into tiles across the width dimension.

TYPE: `int`, *optional* DEFAULT: None

tile_sample_stride_height

The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are no tiling artifacts produced across the height dimension.

TYPE: `int`, *optional* DEFAULT: None

tile_sample_stride_width

The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling artifacts produced across the width dimension.

TYPE: `int`, *optional* DEFAULT: None

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_magvit.py
def enable_tiling(
    self,
    tile_sample_min_height: Optional[int] = None,
    tile_sample_min_width: Optional[int] = None,
    tile_sample_min_num_frames: Optional[int] = None,
    tile_sample_stride_height: Optional[float] = None,
    tile_sample_stride_width: Optional[float] = None,
    tile_sample_stride_num_frames: Optional[float] = None,
) -> None:
    r"""
    Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
    compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
    processing larger images.

    Args:
        tile_sample_min_height (`int`, *optional*):
            The minimum height required for a sample to be separated into tiles across the height dimension.
        tile_sample_min_width (`int`, *optional*):
            The minimum width required for a sample to be separated into tiles across the width dimension.
        tile_sample_stride_height (`int`, *optional*):
            The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
            no tiling artifacts produced across the height dimension.
        tile_sample_stride_width (`int`, *optional*):
            The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
            artifacts produced across the width dimension.
    """
    self.use_tiling = True
    self.use_framewise_decoding = True
    self.use_framewise_encoding = True
    self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
    self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
    self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
    self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
    self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
    self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
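
Tiling trades a small amount of blending overhead at tile borders for a much lower peak memory footprint. A sketch with smaller-than-default tiles; the sizes below are arbitrary illustrations (strides should stay below the matching minimum tile sizes so that neighbouring tiles overlap), and `latents` is a placeholder latent tensor:

vae.enable_tiling(
    tile_sample_min_height=256,
    tile_sample_min_width=256,
    tile_sample_stride_height=192,
    tile_sample_stride_width=192,
)
frames = vae.decode(latents)[0]
vae.disable_tiling()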

mindone.diffusers.AutoencoderKLMagvit.encode(x, return_dict=False)

Encode a batch of images into latents.

PARAMETER DESCRIPTION
x

Input batch of images.

TYPE: `ms.Tensor`

return_dict

Whether to return a [~models.autoencoder_kl.AutoencoderKLOutput] instead of a plain tuple.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

RETURNS DESCRIPTION
Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]

The latent representations of the encoded videos. If return_dict is True, a [~models.autoencoder_kl.AutoencoderKLOutput] is returned, otherwise a plain tuple is returned.

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_magvit.py
def encode(
    self, x: ms.Tensor, return_dict: bool = False
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
    """
    Encode a batch of images into latents.

    Args:
        x (`ms.Tensor`): Input batch of images.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

    Returns:
            The latent representations of the encoded videos. If `return_dict` is True, a
            [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
    """
    if self.use_slicing and x.shape[0] > 1:
        encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
        h = mint.cat(encoded_slices)
    else:
        h = self._encode(x)

    # we cannot use class in graph mode, even for jit_class or subclass of Tensor. :-(
    # posterior = DiagonalGaussianDistribution(h)

    if not return_dict:
        return (h,)
    # return AutoencoderKLOutput(latent_dist=h)
    return AutoencoderKLOutput(latent=h)
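
Because a distribution object cannot be returned in graph mode, `encode` returns the raw moments tensor (mean and logvar concatenated along the channel axis). Sampling or taking the mode is done through the model's `diag_gauss_dist` helper, as in this sketch (`video` is a placeholder pixel-space tensor):

moments = vae.encode(video)[0]                  # concatenated mean and logvar
z_sample = vae.diag_gauss_dist.sample(moments)  # stochastic latent sample
z_mode = vae.diag_gauss_dist.mode(moments)      # deterministic latent (posterior mean)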

mindone.diffusers.models.autoencoders.autoencoder_kl.AutoencoderKLOutput dataclass

Bases: BaseOutput

Output of AutoencoderKL encoding method.

PARAMETER DESCRIPTION
latent

Encoded outputs of Encoder represented as the mean and logvar of DiagonalGaussianDistribution. DiagonalGaussianDistribution allows for sampling latents from the distribution.

TYPE: `ms.Tensor`

Source code in mindone/diffusers/models/modeling_outputs.py
@dataclass
class AutoencoderKLOutput(BaseOutput):
    """
    Output of AutoencoderKL encoding method.

    Args:
        latent (`ms.Tensor`):
            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
    """

    latent: ms.Tensor

mindone.diffusers.models.autoencoders.vae.DecoderOutput dataclass

Bases: BaseOutput

Output of decoding method.

PARAMETER DESCRIPTION
sample

The decoded output sample from the last layer of the model.

TYPE: `ms.Tensor` of shape `(batch_size, num_channels, height, width)`

Source code in mindone/diffusers/models/autoencoders/vae.py
@dataclass
class DecoderOutput(BaseOutput):
    r"""
    Output of decoding method.

    Args:
        sample (`ms.Tensor` of shape `(batch_size, num_channels, height, width)`):
            The decoded output sample from the last layer of the model.
    """

    sample: ms.Tensor
    commit_loss: Optional[ms.Tensor] = None