
AutoencoderKLCosmos

Cosmos Tokenizers are the causal video autoencoders used in NVIDIA [Cosmos](https://huggingface.co/papers/2501.03575).

Supported models:

- nvidia/Cosmos-1.0-Tokenizer-CV8x8x8

The model can be loaded with the following code snippet.

from mindone.diffusers import AutoencoderKLCosmos

vae = AutoencoderKLCosmos.from_pretrained("nvidia/Cosmos-1.0-Tokenizer-CV8x8x8", subfolder="vae")
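
Continuing from the loading snippet above, the autoencoder can encode a video into latents and decode them back. The sketch below is illustrative only: the 5D input layout `(batch, channels, frames, height, width)`, the toy shape, and the random input are assumptions for demonstration, not values taken from the model card.

import mindspore as ms
import numpy as np

# Toy video tensor; real inputs would be preprocessed frames, typically scaled to [-1, 1].
video = ms.Tensor(np.random.randn(1, 3, 9, 64, 64), dtype=ms.float32)

# `encode` and `decode` return plain tuples because `return_dict` defaults to False.
latents = vae.encode(video)[0]
reconstruction = vae.decode(latents)[0]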

mindone.diffusers.models.autoencoders.autoencoder_kl_cosmos.AutoencoderKLCosmos

Bases: ModelMixin, ConfigMixin

Autoencoder used in Cosmos.

PARAMETERS

in_channels (`int`, defaults to `3`):
    Number of input channels.

out_channels (`int`, defaults to `3`):
    Number of output channels.

latent_channels (`int`, defaults to `16`):
    Number of latent channels.

encoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
    Number of output channels for each encoder down block.

decode_block_out_channels (`Tuple[int, ...]`, defaults to `(256, 512, 512, 512)`):
    Number of output channels for each decoder up block.

attention_resolutions (`Tuple[int, ...]`, defaults to `(32,)`):
    List of image/video resolutions at which to apply attention.

resolution (`int`, defaults to `1024`):
    Base image/video resolution used for computing whether a block should have attention layers.

num_layers (`int`, defaults to `2`):
    Number of resnet blocks in each encoder/decoder block.

patch_size (`int`, defaults to `4`):
    Patch size used for patching the input image/video.

patch_type (`str`, defaults to `haar`):
    Patch type used for patching the input image/video. Can be either `haar` or `rearrange`.

scaling_factor (`float`, defaults to `1.0`):
    The component-wise standard deviation of the trained latent space, computed using the first batch of the
    training set. This is used to scale the latent space to have unit variance when training the diffusion model.
    The latents are scaled with `z = z * scaling_factor` before being passed to the diffusion model and scaled
    back to the original scale with `z = 1 / scaling_factor * z` when decoding. For more details, refer to
    sections 4.3.2 and D.1 of the [High-Resolution Image Synthesis with Latent Diffusion
    Models](https://huggingface.co/papers/2112.10752) paper. Not applicable in Cosmos, but the default of `1.0`
    is kept for consistency.

spatial_compression_ratio (`int`, defaults to `8`):
    The spatial compression ratio to apply in the VAE. The number of downsample blocks is determined using this.

temporal_compression_ratio (`int`, defaults to `8`):
    The temporal compression ratio to apply in the VAE. The number of downsample blocks is determined using this.
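
As a reference for how these configuration arguments fit together, the sketch below instantiates the model directly from its config using the documented defaults (loading pretrained weights via `from_pretrained`, as shown above, remains the usual path).

from mindone.diffusers import AutoencoderKLCosmos

# Build a randomly initialized autoencoder from config; the values below are the documented defaults.
vae = AutoencoderKLCosmos(
    in_channels=3,
    out_channels=3,
    latent_channels=16,
    encoder_block_out_channels=(128, 256, 512, 512),
    decode_block_out_channels=(256, 512, 512, 512),
    attention_resolutions=(32,),
    resolution=1024,
    num_layers=2,
    patch_size=4,
    patch_type="haar",
    scaling_factor=1.0,
    spatial_compression_ratio=8,
    temporal_compression_ratio=8,
)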

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
class AutoencoderKLCosmos(ModelMixin, ConfigMixin):
    r"""
    Autoencoder used in [Cosmos](https://huggingface.co/papers/2501.03575).

    Args:
        in_channels (`int`, defaults to `3`):
            Number of input channels.
        out_channels (`int`, defaults to `3`):
            Number of output channels.
        latent_channels (`int`, defaults to `16`):
            Number of latent channels.
        encoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
            Number of output channels for each encoder down block.
        decode_block_out_channels (`Tuple[int, ...]`, defaults to `(256, 512, 512, 512)`):
            Number of output channels for each decoder up block.
        attention_resolutions (`Tuple[int, ...]`, defaults to `(32,)`):
            List of image/video resolutions at which to apply attention.
        resolution (`int`, defaults to `1024`):
            Base image/video resolution used for computing whether a block should have attention layers.
        num_layers (`int`, defaults to `2`):
            Number of resnet blocks in each encoder/decoder block.
        patch_size (`int`, defaults to `4`):
            Patch size used for patching the input image/video.
        patch_type (`str`, defaults to `haar`):
            Patch type used for patching the input image/video. Can be either `haar` or `rearrange`.
        scaling_factor (`float`, defaults to `1.0`):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. Not applicable in
            Cosmos, but we default to 1.0 for consistency.
        spatial_compression_ratio (`int`, defaults to `8`):
            The spatial compression ratio to apply in the VAE. The number of downsample blocks is determined using
            this.
        temporal_compression_ratio (`int`, defaults to `8`):
            The temporal compression ratio to apply in the VAE. The number of downsample blocks is determined using
            this.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        latent_channels: int = 16,
        encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        decode_block_out_channels: Tuple[int, ...] = (256, 512, 512, 512),
        attention_resolutions: Tuple[int, ...] = (32,),
        resolution: int = 1024,
        num_layers: int = 2,
        patch_size: int = 4,
        patch_type: str = "haar",
        scaling_factor: float = 1.0,
        spatial_compression_ratio: int = 8,
        temporal_compression_ratio: int = 8,
        latents_mean: Optional[List[float]] = LATENTS_MEAN,
        latents_std: Optional[List[float]] = LATENTS_STD,
    ) -> None:
        super().__init__()

        self.encoder = CosmosEncoder3d(
            in_channels=in_channels,
            out_channels=latent_channels,
            block_out_channels=encoder_block_out_channels,
            num_resnet_blocks=num_layers,
            attention_resolutions=attention_resolutions,
            resolution=resolution,
            patch_size=patch_size,
            patch_type=patch_type,
            spatial_compression_ratio=spatial_compression_ratio,
            temporal_compression_ratio=temporal_compression_ratio,
        )
        self.decoder = CosmosDecoder3d(
            in_channels=latent_channels,
            out_channels=out_channels,
            block_out_channels=decode_block_out_channels,
            num_resnet_blocks=num_layers,
            attention_resolutions=attention_resolutions,
            resolution=resolution,
            patch_size=patch_size,
            patch_type=patch_type,
            spatial_compression_ratio=spatial_compression_ratio,
            temporal_compression_ratio=temporal_compression_ratio,
        )

        self.quant_conv = CosmosCausalConv3d(latent_channels, latent_channels, kernel_size=1, padding=0)
        self.post_quant_conv = CosmosCausalConv3d(latent_channels, latent_channels, kernel_size=1, padding=0)
        self.identity_dist = IdentityDistribution()

        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
        # to perform decoding of a single video latent at a time.
        self.use_slicing = False

        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
        # frames spatially into smaller tiles and performing multiple construct passes for decoding, and then blending the
        # intermediate tiles together, the memory requirement can be lowered.
        self.use_tiling = False

        # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
        # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered.
        self.use_framewise_encoding = False
        self.use_framewise_decoding = False

        # This can be configured based on the amount of GPU memory available.
        # `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs.
        # Setting it to higher values results in higher memory usage.
        self.num_sample_frames_batch_size = 16
        self.num_latent_frames_batch_size = 2

        # The minimal tile height and width for spatial tiling to be used
        self.tile_sample_min_height = 512
        self.tile_sample_min_width = 512
        self.tile_sample_min_num_frames = 16

        # The minimal distance between two spatial tiles
        self.tile_sample_stride_height = 448
        self.tile_sample_stride_width = 448
        self.tile_sample_stride_num_frames = 8

    def enable_tiling(
        self,
        tile_sample_min_height: Optional[int] = None,
        tile_sample_min_width: Optional[int] = None,
        tile_sample_min_num_frames: Optional[int] = None,
        tile_sample_stride_height: Optional[float] = None,
        tile_sample_stride_width: Optional[float] = None,
        tile_sample_stride_num_frames: Optional[float] = None,
    ) -> None:
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.

        Args:
            tile_sample_min_height (`int`, *optional*):
                The minimum height required for a sample to be separated into tiles across the height dimension.
            tile_sample_min_width (`int`, *optional*):
                The minimum width required for a sample to be separated into tiles across the width dimension.
            tile_sample_stride_height (`int`, *optional*):
                The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
                no tiling artifacts produced across the height dimension.
            tile_sample_stride_width (`int`, *optional*):
                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
                artifacts produced across the width dimension.
        """
        self.use_tiling = True
        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
        self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
        self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames

    def disable_tiling(self) -> None:
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_tiling = False

    def enable_slicing(self) -> None:
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self) -> None:
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    def _encode(self, x: ms.tensor) -> ms.tensor:
        x = self.encoder(x)
        enc = self.quant_conv(x)
        return enc

    def encode(self, x: ms.tensor, return_dict: bool = False) -> ms.tensor:
        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
            h = mint.cat(encoded_slices)
        else:
            h = self._encode(x)

        if not return_dict:
            return (h,)
        return AutoencoderKLOutput(latent_dist=h)

    def _decode(self, z: ms.tensor, return_dict: bool = False) -> Union[DecoderOutput, Tuple[ms.tensor]]:
        z = self.post_quant_conv(z)
        dec = self.decoder(z)

        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)

    def decode(self, z: ms.tensor, return_dict: bool = False) -> Union[DecoderOutput, Tuple[ms.tensor]]:
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice)[0] for z_slice in z.split(1)]
            decoded = mint.cat(decoded_slices)
        else:
            decoded = self._decode(z)[0]

        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    def construct(
        self,
        sample: ms.tensor,
        sample_posterior: bool = False,
        return_dict: bool = False,
        generator: Optional[np.random.Generator] = None,
    ) -> Union[Tuple[ms.tensor], DecoderOutput]:
        x = sample
        posterior = self.encode(x)[0]
        if sample_posterior:
            z = self.identity_dist.sample(posterior, generator=generator)
        else:
            z = self.identity_dist.mode(posterior)
        dec = self.decode(z)[0]
        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)

mindone.diffusers.models.autoencoders.autoencoder_kl_cosmos.AutoencoderKLCosmos.disable_slicing()

Disable sliced VAE decoding. If enable_slicing was previously enabled, this method will go back to computing decoding in one step.

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
def disable_slicing(self) -> None:
    r"""
    Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
    decoding in one step.
    """
    self.use_slicing = False

mindone.diffusers.models.autoencoders.autoencoder_kl_cosmos.AutoencoderKLCosmos.disable_tiling()

Disable tiled VAE decoding. If enable_tiling was previously enabled, this method will go back to computing decoding in one step.

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
def disable_tiling(self) -> None:
    r"""
    Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
    decoding in one step.
    """
    self.use_tiling = False

mindone.diffusers.models.autoencoders.autoencoder_kl_cosmos.AutoencoderKLCosmos.enable_slicing()

Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
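
A minimal sketch of the call pattern, assuming `batched_latents` is a hypothetical batch of latents (batch size > 1) produced by an earlier `encode` call:

# Decode one latent at a time across the batch dimension to lower peak memory.
vae.enable_slicing()
frames = vae.decode(batched_latents)[0]

# Restore single-pass decoding.
vae.disable_slicing()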

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
def enable_slicing(self) -> None:
    r"""
    Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
    compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
    """
    self.use_slicing = True

mindone.diffusers.models.autoencoders.autoencoder_kl_cosmos.AutoencoderKLCosmos.enable_tiling(tile_sample_min_height=None, tile_sample_min_width=None, tile_sample_min_num_frames=None, tile_sample_stride_height=None, tile_sample_stride_width=None, tile_sample_stride_num_frames=None)

Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images.

PARAMETERS

tile_sample_min_height (`int`, *optional*):
    The minimum height required for a sample to be separated into tiles across the height dimension.

tile_sample_min_width (`int`, *optional*):
    The minimum width required for a sample to be separated into tiles across the width dimension.

tile_sample_stride_height (`int`, *optional*):
    The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are no
    tiling artifacts produced across the height dimension.

tile_sample_stride_width (`int`, *optional*):
    The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling artifacts
    produced across the width dimension.
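
A minimal sketch of the call pattern, assuming `latents` come from an earlier `encode` call; the tile sizes below are illustrative assumptions, not tuned recommendations:

# Split spatially large latents into overlapping tiles during decoding to lower peak memory.
vae.enable_tiling(
    tile_sample_min_height=256,
    tile_sample_min_width=256,
    tile_sample_stride_height=192,
    tile_sample_stride_width=192,
)
video = vae.decode(latents)[0]

# Restore single-pass decoding.
vae.disable_tiling()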

Source code in mindone/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
def enable_tiling(
    self,
    tile_sample_min_height: Optional[int] = None,
    tile_sample_min_width: Optional[int] = None,
    tile_sample_min_num_frames: Optional[int] = None,
    tile_sample_stride_height: Optional[float] = None,
    tile_sample_stride_width: Optional[float] = None,
    tile_sample_stride_num_frames: Optional[float] = None,
) -> None:
    r"""
    Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
    compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
    processing larger images.

    Args:
        tile_sample_min_height (`int`, *optional*):
            The minimum height required for a sample to be separated into tiles across the height dimension.
        tile_sample_min_width (`int`, *optional*):
            The minimum width required for a sample to be separated into tiles across the width dimension.
        tile_sample_stride_height (`int`, *optional*):
            The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
            no tiling artifacts produced across the height dimension.
        tile_sample_stride_width (`int`, *optional*):
            The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
            artifacts produced across the width dimension.
    """
    self.use_tiling = True
    self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
    self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
    self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
    self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
    self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
    self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames

mindone.diffusers.models.autoencoders.autoencoder_kl.AutoencoderKLOutput dataclass

Bases: BaseOutput

Output of AutoencoderKL encoding method.

PARAMETERS

latent (`ms.Tensor`):
    Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
    `DiagonalGaussianDistribution` allows for sampling latents from the distribution.

Source code in mindone/diffusers/models/modeling_outputs.py
@dataclass
class AutoencoderKLOutput(BaseOutput):
    """
    Output of AutoencoderKL encoding method.

    Args:
        latent (`ms.Tensor`):
            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
    """

    latent: ms.Tensor

mindone.diffusers.models.autoencoders.vae.DecoderOutput dataclass

Bases: BaseOutput

Output of decoding method.

PARAMETERS

sample (`ms.Tensor` of shape `(batch_size, num_channels, height, width)`):
    The decoded output sample from the last layer of the model.

Source code in mindone/diffusers/models/autoencoders/vae.py
@dataclass
class DecoderOutput(BaseOutput):
    r"""
    Output of decoding method.

    Args:
        sample (`ms.Tensor` of shape `(batch_size, num_channels, height, width)`):
            The decoded output sample from the last layer of the model.
    """

    sample: ms.Tensor
    commit_loss: Optional[ms.Tensor] = None