
AutoencoderKLAllegro

The 3D variational autoencoder (VAE) model with KL loss used in Allegro was introduced in Allegro: Open the Black Box of Commercial-Level Video Generation Model by RhymesAI.

The model can be loaded with the following code snippet.

import mindspore as ms

from mindone.diffusers import AutoencoderKLAllegro

vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", mindspore_dtype=ms.float32)
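
A minimal end-to-end sketch is shown below. The input is a hypothetical random video shaped (batch, channels, frames, height, width), chosen to match the model's default tile kernel of 24 frames at 320x320 resolution; tiling has to be enabled first because the non-tiled encode/decode paths are not implemented yet.

import mindspore as ms
import numpy as np

from mindone.diffusers import AutoencoderKLAllegro

vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", mindspore_dtype=ms.float32)
vae.enable_tiling()  # required: non-tiled encoding/decoding raises NotImplementedError

# hypothetical input: 1 video, 3 channels, 24 frames, 320x320 pixels
video = ms.Tensor(np.random.randn(1, 3, 24, 320, 320), ms.float32)

moments = vae.encode(video)[0]                 # concatenated mean and logvar of the latent distribution
latents = vae.diag_gauss_dist.sample(moments)  # draw latents from the posterior
reconstruction = vae.decode(latents)[0]        # decoded video with the same shape as the input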

mindone.diffusers.AutoencoderKLAllegro

Bases: ModelMixin, ConfigMixin

A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in Allegro.

This model inherits from [ModelMixin]. Check the superclass documentation for its generic methods implemented for all models (such as downloading or saving).

PARAMETER DESCRIPTION
in_channels

Number of channels in the input image.

TYPE: `int`, defaults to `3` DEFAULT: 3

out_channels

Number of channels in the output.

TYPE: `int`, defaults to `3` DEFAULT: 3

down_block_types

Tuple of strings denoting which types of down blocks to use.

TYPE: `Tuple[str, ...]`, defaults to `("AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D")` DEFAULT: ('AllegroDownBlock3D', 'AllegroDownBlock3D', 'AllegroDownBlock3D', 'AllegroDownBlock3D')

up_block_types

Tuple of strings denoting which types of up blocks to use.

TYPE: `Tuple[str, ...]`, defaults to `("AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D")` DEFAULT: ('AllegroUpBlock3D', 'AllegroUpBlock3D', 'AllegroUpBlock3D', 'AllegroUpBlock3D')

block_out_channels

Tuple of integers denoting number of output channels in each block.

TYPE: `Tuple[int, ...]`, defaults to `(128, 256, 512, 512)` DEFAULT: (128, 256, 512, 512)

temporal_downsample_blocks

Tuple of booleans denoting which blocks to enable temporal downsampling in.

TYPE: `Tuple[bool, ...]`, defaults to `(True, True, False, False)` DEFAULT: (True, True, False, False)

latent_channels

Number of channels in latents.

TYPE: `int`, defaults to `4` DEFAULT: 4

layers_per_block

Number of ResNet, attention, or temporal convolution layers per down/up block.

TYPE: `int`, defaults to `2` DEFAULT: 2

act_fn

The activation function to use.

TYPE: `str`, defaults to `"silu"` DEFAULT: 'silu'

norm_num_groups

Number of groups to use in normalization layers.

TYPE: `int`, defaults to `32` DEFAULT: 32

temporal_compression_ratio

Ratio by which the temporal dimension of samples is compressed.

TYPE: `int`, defaults to `4` DEFAULT: 4

sample_size

Default latent size.

TYPE: `int`, defaults to `320` DEFAULT: 320

scaling_factor

The component-wise standard deviation of the trained latent space computed using the first batch of the training set. This is used to scale the latent space to have unit variance when training the diffusion model. The latents are scaled with the formula z = z * scaling_factor before being passed to the diffusion model. When decoding, the latents are scaled back to the original scale with the formula: z = 1 / scaling_factor * z. For more details, refer to sections 4.3.2 and D.1 of the High-Resolution Image Synthesis with Latent Diffusion Models paper.

TYPE: `float`, defaults to `0.13235` DEFAULT: 0.13

force_upcast

If enabled, the VAE will run in float32 for high image resolution pipelines, such as SD-XL. The VAE can be fine-tuned / trained to a lower range without losing too much precision, in which case force_upcast can be set to False - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix

TYPE: `bool`, defaults to `True` DEFAULT: True

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_allegro.py
class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
    r"""
    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
    [Allegro](https://github.com/rhymes-ai/Allegro).

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (int, defaults to `3`):
            Number of channels in the input image.
        out_channels (int, defaults to `3`):
            Number of channels in the output.
        down_block_types (`Tuple[str, ...]`, defaults to `("AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D")`):
            Tuple of strings denoting which types of down blocks to use.
        up_block_types (`Tuple[str, ...]`, defaults to `("AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D")`):
            Tuple of strings denoting which types of up blocks to use.
        block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
            Tuple of integers denoting number of output channels in each block.
        temporal_downsample_blocks (`Tuple[bool, ...]`, defaults to `(True, True, False, False)`):
            Tuple of booleans denoting which blocks to enable temporal downsampling in.
        latent_channels (`int`, defaults to `4`):
            Number of channels in latents.
        layers_per_block (`int`, defaults to `2`):
            Number of ResNet, attention, or temporal convolution layers per down/up block.
        act_fn (`str`, defaults to `"silu"`):
            The activation function to use.
        norm_num_groups (`int`, defaults to `32`):
            Number of groups to use in normalization layers.
        temporal_compression_ratio (`int`, defaults to `4`):
            Ratio by which the temporal dimension of samples is compressed.
        sample_size (`int`, defaults to `320`):
            Default latent size.
        scaling_factor (`float`, defaults to `0.13235`):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
        force_upcast (`bool`, defaults to `True`):
            If enabled, the VAE will run in float32 for high image resolution pipelines, such as SD-XL. The VAE
            can be fine-tuned / trained to a lower range without losing too much precision, in which case
            `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = (
            "AllegroDownBlock3D",
            "AllegroDownBlock3D",
            "AllegroDownBlock3D",
            "AllegroDownBlock3D",
        ),
        up_block_types: Tuple[str, ...] = (
            "AllegroUpBlock3D",
            "AllegroUpBlock3D",
            "AllegroUpBlock3D",
            "AllegroUpBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        temporal_downsample_blocks: Tuple[bool, ...] = (True, True, False, False),
        temporal_upsample_blocks: Tuple[bool, ...] = (False, True, True, False),
        latent_channels: int = 4,
        layers_per_block: int = 2,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        temporal_compression_ratio: float = 4,
        sample_size: int = 320,
        scaling_factor: float = 0.13,
        force_upcast: bool = True,
    ) -> None:
        super().__init__()

        self.encoder = AllegroEncoder3D(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            temporal_downsample_blocks=temporal_downsample_blocks,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
        )
        self.decoder = AllegroDecoder3D(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            temporal_upsample_blocks=temporal_upsample_blocks,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
        )
        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1, has_bias=True, pad_mode="valid")
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1, has_bias=True, pad_mode="valid")
        self.diag_gauss_dist = DiagonalGaussianDistribution()

        # TODO(aryan): For the 1.0.0 refactor, `temporal_compression_ratio` can be inferred directly and we don't need
        # to use a specific parameter here or in other VAEs.

        self.use_slicing = False
        self.use_tiling = False

        self.spatial_compression_ratio = 2 ** (len(block_out_channels) - 1)
        self.tile_overlap_t = 8
        self.tile_overlap_h = 120
        self.tile_overlap_w = 80
        sample_frames = 24

        self.kernel = (sample_frames, sample_size, sample_size)
        self.stride = (
            sample_frames - self.tile_overlap_t,
            sample_size - self.tile_overlap_h,
            sample_size - self.tile_overlap_w,
        )

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)):
            module.gradient_checkpointing = value

    def enable_tiling(self) -> None:
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.use_tiling = True

    def disable_tiling(self) -> None:
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_tiling = False

    def enable_slicing(self) -> None:
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self) -> None:
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    def _encode(self, x: ms.Tensor) -> ms.Tensor:
        # TODO(aryan)
        # if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
        if self.use_tiling:
            return self.tiled_encode(x)

        raise NotImplementedError("Encoding without tiling has not been implemented yet.")

    def encode(
        self, x: ms.Tensor, return_dict: bool = False
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        r"""
        Encode a batch of videos into latents.

        Args:
            x (`ms.Tensor`):
                Input batch of videos.
            return_dict (`bool`, defaults to `False`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
                The latent representations of the encoded videos. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
            h = ops.cat(encoded_slices)
        else:
            h = self._encode(x)

        # we cannot use class in graph mode, even for jit_class or subclass of Tensor. :-(
        # posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (h,)
        return AutoencoderKLOutput(latent_dist=h)

    def _decode(self, z: ms.Tensor) -> ms.Tensor:
        # TODO(aryan): refactor tiling implementation
        # if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
        if self.use_tiling:
            return self.tiled_decode(z)

        raise NotImplementedError("Decoding without tiling has not been implemented yet.")

    def decode(self, z: ms.Tensor, return_dict: bool = False) -> Union[DecoderOutput, ms.Tensor]:
        """
        Decode a batch of videos.

        Args:
            z (`ms.Tensor`):
                Input batch of latent vectors.
            return_dict (`bool`, defaults to `False`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
            decoded = ops.cat(decoded_slices)
        else:
            decoded = self._decode(z)

        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    def tiled_encode(self, x: ms.Tensor) -> ms.Tensor:
        local_batch_size = 1
        rs = self.spatial_compression_ratio
        rt = self.config["temporal_compression_ratio"]

        batch_size, num_channels, num_frames, height, width = x.shape

        output_num_frames = math.floor((num_frames - self.kernel[0]) / self.stride[0]) + 1
        output_height = math.floor((height - self.kernel[1]) / self.stride[1]) + 1
        output_width = math.floor((width - self.kernel[2]) / self.stride[2]) + 1

        count = 0
        output_latent = x.new_zeros(
            (
                output_num_frames * output_height * output_width,
                2 * self.config["latent_channels"],
                self.kernel[0] // rt,
                self.kernel[1] // rs,
                self.kernel[2] // rs,
            ),
            dtype=x.dtype,
        )
        vae_batch_input = x.new_zeros(
            (local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2]), dtype=x.dtype
        )

        for i in range(output_num_frames):
            for j in range(output_height):
                for k in range(output_width):
                    n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0]
                    h_start, h_end = j * self.stride[1], j * self.stride[1] + self.kernel[1]
                    w_start, w_end = k * self.stride[2], k * self.stride[2] + self.kernel[2]

                    video_cube = x[:, :, n_start:n_end, h_start:h_end, w_start:w_end]
                    vae_batch_input[count % local_batch_size] = video_cube

                    if (
                        count % local_batch_size == local_batch_size - 1
                        or count == output_num_frames * output_height * output_width - 1
                    ):
                        latent = self.encoder(vae_batch_input)

                        if (
                            count == output_num_frames * output_height * output_width - 1
                            and count % local_batch_size != local_batch_size - 1
                        ):
                            output_latent[count - count % local_batch_size :] = latent[: count % local_batch_size + 1]
                        else:
                            output_latent[count - local_batch_size + 1 : count + 1] = latent

                        vae_batch_input = x.new_zeros(
                            (local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2]),
                            dtype=x.dtype,
                        )

                    count += 1

        latent = x.new_zeros(
            (batch_size, 2 * self.config["latent_channels"], num_frames // rt, height // rs, width // rs),
            dtype=x.dtype,
        )
        output_kernel = self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs
        output_stride = self.stride[0] // rt, self.stride[1] // rs, self.stride[2] // rs
        output_overlap = (
            output_kernel[0] - output_stride[0],
            output_kernel[1] - output_stride[1],
            output_kernel[2] - output_stride[2],
        )

        for i in range(output_num_frames):
            n_start, n_end = i * output_stride[0], i * output_stride[0] + output_kernel[0]
            for j in range(output_height):
                h_start, h_end = j * output_stride[1], j * output_stride[1] + output_kernel[1]
                for k in range(output_width):
                    w_start, w_end = k * output_stride[2], k * output_stride[2] + output_kernel[2]
                    latent_mean = _prepare_for_blend(
                        (i, output_num_frames, output_overlap[0]),
                        (j, output_height, output_overlap[1]),
                        (k, output_width, output_overlap[2]),
                        output_latent[i * output_height * output_width + j * output_width + k].unsqueeze(0),
                    )
                    latent[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += latent_mean

        latent = latent.permute(0, 2, 1, 3, 4).flatten(start_dim=0, end_dim=1)
        latent = self.quant_conv(latent)
        latent = latent.reshape(latent.shape[:0] + (batch_size, -1) + latent.shape[1:]).permute(0, 2, 1, 3, 4)
        return latent

    def tiled_decode(self, z: ms.Tensor) -> ms.Tensor:
        local_batch_size = 1
        rs = self.spatial_compression_ratio
        rt = self.config["temporal_compression_ratio"]

        latent_kernel = self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs
        latent_stride = self.stride[0] // rt, self.stride[1] // rs, self.stride[2] // rs

        batch_size, num_channels, num_frames, height, width = z.shape

        # post quant conv (a mapping)
        z = z.permute(0, 2, 1, 3, 4).flatten(start_dim=0, end_dim=1)
        z = self.post_quant_conv(z)
        z = z.reshape(z.shape[:0] + (batch_size, -1) + z.shape[1:]).permute(0, 2, 1, 3, 4)

        output_num_frames = math.floor((num_frames - latent_kernel[0]) / latent_stride[0]) + 1
        output_height = math.floor((height - latent_kernel[1]) / latent_stride[1]) + 1
        output_width = math.floor((width - latent_kernel[2]) / latent_stride[2]) + 1

        count = 0
        decoded_videos = z.new_zeros(
            (
                output_num_frames * output_height * output_width,
                self.config["out_channels"],
                self.kernel[0],
                self.kernel[1],
                self.kernel[2],
            ),
            dtype=z.dtype,
        )
        vae_batch_input = z.new_zeros(
            (local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2]), dtype=z.dtype
        )

        for i in range(output_num_frames):
            for j in range(output_height):
                for k in range(output_width):
                    n_start, n_end = i * latent_stride[0], i * latent_stride[0] + latent_kernel[0]
                    h_start, h_end = j * latent_stride[1], j * latent_stride[1] + latent_kernel[1]
                    w_start, w_end = k * latent_stride[2], k * latent_stride[2] + latent_kernel[2]

                    current_latent = z[:, :, n_start:n_end, h_start:h_end, w_start:w_end]
                    # In MS, if x.shape == y.shape, x[0] = y will result in an error, so we add `squeeze`.
                    vae_batch_input[count % local_batch_size] = current_latent.squeeze()

                    if (
                        count % local_batch_size == local_batch_size - 1
                        or count == output_num_frames * output_height * output_width - 1
                    ):
                        current_video = self.decoder(vae_batch_input)

                        if (
                            count == output_num_frames * output_height * output_width - 1
                            and count % local_batch_size != local_batch_size - 1
                        ):
                            decoded_videos[count - count % local_batch_size :] = current_video[
                                : count % local_batch_size + 1
                            ]
                        else:
                            decoded_videos[count - local_batch_size + 1 : count + 1] = current_video

                        vae_batch_input = z.new_zeros(
                            (local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2]),
                            dtype=z.dtype,
                        )

                    count += 1

        video = z.new_zeros(
            (batch_size, self.config["out_channels"], num_frames * rt, height * rs, width * rs), dtype=z.dtype
        )
        video_overlap = (
            self.kernel[0] - self.stride[0],
            self.kernel[1] - self.stride[1],
            self.kernel[2] - self.stride[2],
        )

        for i in range(output_num_frames):
            n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0]
            for j in range(output_height):
                h_start, h_end = j * self.stride[1], j * self.stride[1] + self.kernel[1]
                for k in range(output_width):
                    w_start, w_end = k * self.stride[2], k * self.stride[2] + self.kernel[2]
                    out_video_blend = _prepare_for_blend(
                        (i, output_num_frames, video_overlap[0]),
                        (j, output_height, video_overlap[1]),
                        (k, output_width, video_overlap[2]),
                        decoded_videos[i * output_height * output_width + j * output_width + k].unsqueeze(0),
                    )
                    video[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += out_video_blend

        video = video.permute(0, 2, 1, 3, 4).contiguous()
        return video

    def construct(
        self,
        sample: ms.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = False,
        generator: Optional[np.random.Generator] = None,
    ) -> Union[DecoderOutput, ms.Tensor]:
        r"""
        Args:
            sample (`ms.Tensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
            generator (`np.random.Generator`, *optional*):
                NumPy random number generator.
        """
        x = sample
        latent = self.encode(x)[0]
        if sample_posterior:
            z = self.diag_gauss_dist.sample(latent, generator=generator)
        else:
            z = self.diag_gauss_dist.mode(latent)
        dec = self.decode(z)[0]

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

mindone.diffusers.AutoencoderKLAllegro.construct(sample, sample_posterior=False, return_dict=False, generator=None)

PARAMETER DESCRIPTION
sample

Input sample.

TYPE: `ms.Tensor`

sample_posterior

Whether to sample from the posterior.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

return_dict

Whether or not to return a [DecoderOutput] instead of a plain tuple.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

generator

NumPy random number generator.

TYPE: `np.random.Generator`, *optional* DEFAULT: None

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_allegro.py
def construct(
    self,
    sample: ms.Tensor,
    sample_posterior: bool = False,
    return_dict: bool = False,
    generator: Optional[np.random.Generator] = None,
) -> Union[DecoderOutput, ms.Tensor]:
    r"""
    Args:
        sample (`ms.Tensor`): Input sample.
        sample_posterior (`bool`, *optional*, defaults to `False`):
            Whether to sample from the posterior.
        return_dict (`bool`, *optional*, defaults to `False`):
            Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        generator (`np.random.Generator`, *optional*):
            NumPy random number generator.
    """
    x = sample
    latent = self.encode(x)[0]
    if sample_posterior:
        z = self.diag_gauss_dist.sample(latent, generator=generator)
    else:
        z = self.diag_gauss_dist.mode(latent)
    dec = self.decode(z)[0]

    if not return_dict:
        return (dec,)

    return DecoderOutput(sample=dec)
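
Because the model is a MindSpore Cell, construct runs when the instance is called directly. A short sketch, assuming the vae and the hypothetical video tensor from the loading example at the top of the page, with tiling enabled:

import numpy as np

# encode then decode in one call; sample_posterior=True samples latents from the posterior
reconstruction = vae(video, sample_posterior=True, generator=np.random.default_rng(0))[0]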

mindone.diffusers.AutoencoderKLAllegro.decode(z, return_dict=False)

Decode a batch of videos.

PARAMETER DESCRIPTION
z

Input batch of latent vectors.

TYPE: `ms.Tensor`

return_dict

Whether to return a [~models.vae.DecoderOutput] instead of a plain tuple.

TYPE: `bool`, defaults to `False` DEFAULT: False

RETURNS DESCRIPTION
Union[DecoderOutput, Tensor]

[~models.vae.DecoderOutput] or tuple: If return_dict is True, a [~models.vae.DecoderOutput] is returned, otherwise a plain tuple is returned.

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_allegro.py
def decode(self, z: ms.Tensor, return_dict: bool = False) -> Union[DecoderOutput, ms.Tensor]:
    """
    Decode a batch of videos.

    Args:
        z (`ms.Tensor`):
            Input batch of latent vectors.
        return_dict (`bool`, defaults to `False`):
            Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

    Returns:
        [`~models.vae.DecoderOutput`] or `tuple`:
            If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
            returned.
    """
    if self.use_slicing and z.shape[0] > 1:
        decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
        decoded = ops.cat(decoded_slices)
    else:
        decoded = self._decode(z)

    if not return_dict:
        return (decoded,)
    return DecoderOutput(sample=decoded)
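
A short usage sketch, assuming tiling is enabled and the latents come from the encode example at the top of the page:

frames = vae.decode(latents)[0]                        # plain tuple by default
frames = vae.decode(latents, return_dict=True).sample  # or unpack a DecoderOutput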

mindone.diffusers.AutoencoderKLAllegro.disable_slicing()

Disable sliced VAE decoding. If enable_slicing was previously enabled, this method will go back to computing decoding in one step.

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_allegro.py
def disable_slicing(self) -> None:
    r"""
    Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
    decoding in one step.
    """
    self.use_slicing = False

mindone.diffusers.AutoencoderKLAllegro.disable_tiling()

Disable tiled VAE decoding. If enable_tiling was previously enabled, this method will go back to computing decoding in one step.

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_allegro.py
def disable_tiling(self) -> None:
    r"""
    Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
    decoding in one step.
    """
    self.use_tiling = False

mindone.diffusers.AutoencoderKLAllegro.enable_slicing()

Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_allegro.py
def enable_slicing(self) -> None:
    r"""
    Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
    compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
    """
    self.use_slicing = True
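
Slicing only changes behaviour for batch sizes greater than 1: each batch element is encoded or decoded separately and the results are concatenated. A brief sketch, where batched_latents is a hypothetical latent tensor with batch size greater than 1 and tiling is enabled:

vae.enable_slicing()
frames = vae.decode(batched_latents)[0]  # each batch element is decoded on its own, reducing peak memory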

mindone.diffusers.AutoencoderKLAllegro.enable_tiling()

Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images.

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_allegro.py
def enable_tiling(self) -> None:
    r"""
    Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
    compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
    processing larger images.
    """
    self.use_tiling = True
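
Note that tiling is effectively mandatory in this implementation: the non-tiled paths in _encode and _decode raise NotImplementedError, so enable_tiling should be called before encode or decode. For example, with a loaded vae and the video tensor from the loading example:

vae.enable_tiling()
latents = vae.encode(video)[0]  # without enable_tiling() this call would raise NotImplementedError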

mindone.diffusers.AutoencoderKLAllegro.encode(x, return_dict=False)

Encode a batch of videos into latents.

PARAMETER DESCRIPTION
x

Input batch of videos.

TYPE: `ms.Tensor`

return_dict

Whether to return a [~models.autoencoder_kl.AutoencoderKLOutput] instead of a plain tuple.

TYPE: `bool`, defaults to `False` DEFAULT: False

RETURNS DESCRIPTION
Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]

The latent representations of the encoded videos. If return_dict is True, a [~models.autoencoder_kl.AutoencoderKLOutput] is returned, otherwise a plain tuple is returned.

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_allegro.py
def encode(
    self, x: ms.Tensor, return_dict: bool = False
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
    r"""
    Encode a batch of videos into latents.

    Args:
        x (`ms.Tensor`):
            Input batch of videos.
        return_dict (`bool`, defaults to `False`):
            Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

    Returns:
            The latent representations of the encoded videos. If `return_dict` is True, a
            [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
    """
    if self.use_slicing and x.shape[0] > 1:
        encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
        h = ops.cat(encoded_slices)
    else:
        h = self._encode(x)

    # we cannot use class in graph mode, even for jit_class or subclass of Tensor. :-(
    # posterior = DiagonalGaussianDistribution(h)

    if not return_dict:
        return (h,)
    return AutoencoderKLOutput(latent_dist=h)
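
With the default configuration, the returned tensor stacks the mean and logvar along the channel axis and is downsampled 4x in time and 8x spatially. A short sketch, assuming tiling is enabled and video is the (batch, 3, frames, height, width) tensor from the loading example:

moments = vae.encode(video)[0]                 # shape (batch, 2 * latent_channels, frames / 4, height / 8, width / 8)
latents = vae.diag_gauss_dist.sample(moments)  # stochastic sample; use vae.diag_gauss_dist.mode(moments) for the mean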

mindone.diffusers.models.autoencoders.autoencoder_kl.AutoencoderKLOutput dataclass

Bases: BaseOutput

Output of AutoencoderKL encoding method.

PARAMETER DESCRIPTION
latent

Encoded outputs of Encoder represented as the mean and logvar of DiagonalGaussianDistribution. DiagonalGaussianDistribution allows for sampling latents from the distribution.

TYPE: `ms.Tensor`

Source code in mindone/diffusers/models/modeling_outputs.py
@dataclass
class AutoencoderKLOutput(BaseOutput):
    """
    Output of AutoencoderKL encoding method.

    Args:
        latent (`ms.Tensor`):
            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
    """

    latent: ms.Tensor

mindone.diffusers.models.autoencoders.vae.DecoderOutput dataclass

Bases: BaseOutput

Output of decoding method.

PARAMETER DESCRIPTION
sample

The decoded output sample from the last layer of the model.

TYPE: `ms.Tensor` of shape `(batch_size, num_channels, height, width)`

Source code in mindone/diffusers/models/autoencoders/vae.py
@dataclass
class DecoderOutput(BaseOutput):
    r"""
    Output of decoding method.

    Args:
        sample (`ms.Tensor` of shape `(batch_size, num_channels, height, width)`):
            The decoded output sample from the last layer of the model.
    """

    sample: ms.Tensor
    commit_loss: Optional[ms.Tensor] = None