
Tiny AutoEncoder

Tiny AutoEncoder for Stable Diffusion (TAESD) was introduced in madebyollin/taesd by Ollin Boer Bohan. It is a tiny distilled version of Stable Diffusion's VAE that can decode the latents in a StableDiffusionPipeline or StableDiffusionXLPipeline almost instantly.

To use with Stable Diffusion v2.1:

import mindspore as ms
from mindone.diffusers import DiffusionPipeline, AutoencoderTiny

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", mindspore_dtype=ms.float16
)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", mindspore_dtype=ms.float16)

prompt = "slice of delicious New York-style berry cheesecake"
image = pipe(prompt, num_inference_steps=25)[0][0]
image  # in a notebook, this displays the generated PIL image

To use with Stable Diffusion XL 1.0:

import mindspore as ms
from mindone.diffusers import DiffusionPipeline, AutoencoderTiny

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", mindspore_dtype=ms.float16
)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", mindspore_dtype=ms.float16)

prompt = "slice of delicious New York-style berry cheesecake"
image = pipe(prompt, num_inference_steps=25)[0][0]
image  # in a notebook, this displays the generated PIL image

mindone.diffusers.AutoencoderTiny

Bases: ModelMixin, ConfigMixin

A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.

[AutoencoderTiny] is a wrapper around the original implementation of TAESD.

This model inherits from [ModelMixin]. Check the superclass documentation for its generic methods implemented for all models (such as downloading or saving).

PARAMETER DESCRIPTION
in_channels

Number of channels in the input image.

TYPE: `int`, *optional* DEFAULT: 3

out_channels

Number of channels in the output.

TYPE: `int`, *optional* DEFAULT: 3

encoder_block_out_channels

Tuple of integers giving the number of output channels of each encoder stage. Its length must equal the length of `num_encoder_blocks`; otherwise the constructor raises a `ValueError`.

TYPE: `Tuple[int]`, *optional* DEFAULT: (64, 64, 64, 64)

decoder_block_out_channels

Tuple of integers giving the number of output channels of each decoder stage. Its length must equal the length of `num_decoder_blocks`; otherwise the constructor raises a `ValueError`.

TYPE: `Tuple[int]`, *optional* DEFAULT: (64, 64, 64, 64)

act_fn

Activation function to be used throughout the model.

TYPE: `str`, *optional* DEFAULT: 'relu'

latent_channels

Number of channels in the latent representation. The latent space acts as a compressed representation of the input image.

TYPE: `int`, *optional* DEFAULT: 4

upsampling_scaling_factor

Factor by which each upsampling layer in the decoder enlarges its input. The decoder upsamples between consecutive stages, so this factor, together with the number of decoder stages, determines the size of the output image.

TYPE: `int`, *optional* DEFAULT: 2

num_encoder_blocks

Tuple of integers giving the number of encoder blocks in each stage of the encoder. The length of the tuple is the number of encoder stages, and each stage can have a different number of blocks.

TYPE: `Tuple[int]`, *optional* DEFAULT: (1, 3, 3, 3)

num_decoder_blocks

Tuple of integers giving the number of decoder blocks in each stage of the decoder. The length of the tuple is the number of decoder stages, and each stage can have a different number of blocks.

TYPE: `Tuple[int]`, *optional* DEFAULT: (3, 3, 3, 1)

latent_magnitude

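Magnitude of the raw latent values. `scale_latents` divides the latents by `2 * latent_magnitude` before adding `latent_shift`, so with the default shift of 0.5, raw values in `[-latent_magnitude, latent_magnitude]` map to `[0, 1]` and values outside that range are clipped.

TYPE: `float`, *optional* DEFAULT: 3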

latent_shift

Shift added to the latents in `scale_latents` (and subtracted in `unscale_latents`). It sets the point in `[0, 1]` to which a raw latent value of 0 is mapped, i.e. the center of the scaled latent range.

TYPE: `float`, *optional* DEFAULT: 0.5

scaling_factor

The component-wise standard deviation of the trained latent space computed using the first batch of the training set. This is used to scale the latent space to have unit variance when training the diffusion model. The latents are scaled with the formula z = z * scaling_factor before being passed to the diffusion model. When decoding, the latents are scaled back to the original scale with the formula: z = 1 / scaling_factor * z. For more details, refer to sections 4.3.2 and D.1 of the High-Resolution Image Synthesis with Latent Diffusion Models paper. For this Autoencoder, however, no such scaling factor was used, hence the value of 1.0 as the default.

TYPE: `float`, *optional* DEFAULT: 1.0

force_upcast

If enabled, forces the VAE to run in float32 for high-resolution pipelines such as SD-XL. A VAE can be fine-tuned or trained to a lower range without losing much precision, in which case `force_upcast` can be set to `False` (see this fp16-friendly AutoEncoder). Note that `AutoencoderTiny.__init__` always registers `force_upcast=False` in the model config, regardless of the value passed.

TYPE: `bool`, *optional* DEFAULT: False
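The block tuples interact: the constructor rejects configurations in which `encoder_block_out_channels` and `num_encoder_blocks` (or their decoder counterparts) differ in length, and the number of stages sets how strongly the image is downsampled and upsampled. The sketch below is a hypothetical helper, `describe_tiny_vae`, which is not part of mindone. It mirrors the length checks and computes the implied spatial factors, assuming (as in the upstream diffusers `EncoderTiny`/`DecoderTiny`) that every encoder stage after the first starts with a stride-2 convolution and every decoder stage except the last ends with an upsampling layer.

```python
from typing import Sequence, Tuple


def describe_tiny_vae(
    encoder_block_out_channels: Sequence[int] = (64, 64, 64, 64),
    num_encoder_blocks: Sequence[int] = (1, 3, 3, 3),
    decoder_block_out_channels: Sequence[int] = (64, 64, 64, 64),
    num_decoder_blocks: Sequence[int] = (3, 3, 3, 1),
    upsampling_scaling_factor: int = 2,
) -> Tuple[int, int]:
    """Hypothetical helper (not part of mindone): validate the block tuples and
    return the (encoder downsampling, decoder upsampling) factors they imply."""
    # Same length checks that AutoencoderTiny.__init__ performs.
    if len(encoder_block_out_channels) != len(num_encoder_blocks):
        raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.")
    if len(decoder_block_out_channels) != len(num_decoder_blocks):
        raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.")

    # Assumption: every encoder stage after the first starts with a stride-2 conv.
    downsample = 2 ** (len(num_encoder_blocks) - 1)
    # Assumption: every decoder stage except the last ends with an upsampling layer.
    upsample = upsampling_scaling_factor ** (len(num_decoder_blocks) - 1)
    return downsample, upsample


down, up = describe_tiny_vae()
print(down, up)  # 8 8: a 512x512 image gives 64x64 latents, which decode back to 512x512
```

With the defaults, both factors are 8, which matches the `2**out_channels` value that the constructor stores in `spatial_scale_factor` for the default `out_channels=3`.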

Source code in mindone/diffusers/models/autoencoders/autoencoder_tiny.py
class AutoencoderTiny(ModelMixin, ConfigMixin):
    r"""
    A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.

    [`AutoencoderTiny`] is a wrapper around the original implementation of `TAESD`.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
    all models (such as downloading or saving).

    Parameters:
        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (`int`,  *optional*, defaults to 3): Number of channels in the output.
        encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
            Tuple of integers representing the number of output channels for each encoder block. The length of the
            tuple should be equal to the number of encoder blocks.
        decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
            Tuple of integers representing the number of output channels for each decoder block. The length of the
            tuple should be equal to the number of decoder blocks.
        act_fn (`str`, *optional*, defaults to `"relu"`):
            Activation function to be used throughout the model.
        latent_channels (`int`, *optional*, defaults to 4):
            Number of channels in the latent representation. The latent space acts as a compressed representation of
            the input image.
        upsampling_scaling_factor (`int`, *optional*, defaults to 2):
            Scaling factor for upsampling in the decoder. It determines the size of the output image during the
            upsampling process.
        num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
            Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
            length of the tuple should be equal to the number of stages in the encoder. Each stage has a different
            number of encoder blocks.
        num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
            Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
            length of the tuple should be equal to the number of stages in the decoder. Each stage has a different
            number of decoder blocks.
        latent_magnitude (`float`, *optional*, defaults to 3.0):
            Magnitude of the latent representation. This parameter scales the latent representation values to control
            the extent of information preservation.
        latent_shift (`float`, *optional*, defaults to 0.5):
            Shift applied to the latent representation. This parameter controls the center of the latent space.
        scaling_factor (`float`, *optional*, defaults to 1.0):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this Autoencoder,
            however, no such scaling factor was used, hence the value of 1.0 as the default.
        force_upcast (`bool`, *optional*, defaults to `False`):
            If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
            can be fine-tuned / trained to a lower range without losing too much precision, in which case
            `force_upcast` can be set to `False` (see this fp16-friendly
            [AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
        decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
        act_fn: str = "relu",
        upsample_fn: str = "nearest",
        latent_channels: int = 4,
        upsampling_scaling_factor: int = 2,
        num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
        num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
        latent_magnitude: int = 3,
        latent_shift: float = 0.5,
        force_upcast: bool = False,
        scaling_factor: float = 1.0,
        shift_factor: float = 0.0,
    ):
        super().__init__()

        if len(encoder_block_out_channels) != len(num_encoder_blocks):
            raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.")
        if len(decoder_block_out_channels) != len(num_decoder_blocks):
            raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.")

        self.encoder = EncoderTiny(
            in_channels=in_channels,
            out_channels=latent_channels,
            num_blocks=num_encoder_blocks,
            block_out_channels=encoder_block_out_channels,
            act_fn=act_fn,
        )

        self.decoder = DecoderTiny(
            in_channels=latent_channels,
            out_channels=out_channels,
            num_blocks=num_decoder_blocks,
            block_out_channels=decoder_block_out_channels,
            upsampling_scaling_factor=upsampling_scaling_factor,
            act_fn=act_fn,
            upsample_fn=upsample_fn,
        )

        self.latent_magnitude = latent_magnitude
        self.latent_shift = latent_shift
        self.scaling_factor = scaling_factor

        self.use_slicing = False
        self.use_tiling = False

        # only relevant if vae tiling is enabled
        self.spatial_scale_factor = 2**out_channels
        self.tile_overlap_factor = 0.125
        self.tile_sample_min_size = 512
        self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor

        self.register_to_config(block_out_channels=decoder_block_out_channels)
        self.register_to_config(force_upcast=False)

    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
        if isinstance(module, (EncoderTiny, DecoderTiny)):
            module.gradient_checkpointing = value

    def scale_latents(self, x: ms.Tensor) -> ms.Tensor:
        """raw latents -> [0, 1]"""
        return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)

    def unscale_latents(self, x: ms.Tensor) -> ms.Tensor:
        """[0, 1] -> raw latents"""
        return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)

    def enable_slicing(self) -> None:
        r"""
        Enable sliced VAE encoding and decoding. When this option is enabled, the VAE will split the input batch into
        slices to compute encoding and decoding in several steps. This saves some memory and allows larger batch
        sizes.
        """
        self.use_slicing = True

    def disable_slicing(self) -> None:
        r"""
        Disable sliced VAE encoding and decoding. If `enable_slicing` was previously called, this method will go back
        to computing encoding and decoding in one step.
        """
        self.use_slicing = False

    def enable_tiling(self, use_tiling: bool = True) -> None:
        r"""
        Enable tiled VAE encoding and decoding. When this option is enabled, the VAE will split the input tensor into
        tiles to compute encoding and decoding in several steps. This saves a large amount of memory and allows
        processing larger images.
        """
        self.use_tiling = use_tiling

    def disable_tiling(self) -> None:
        r"""
        Disable tiled VAE encoding and decoding. If `enable_tiling` was previously called, this method will go back to
        computing encoding and decoding in one step.
        """
        self.enable_tiling(False)

    def _tiled_encode(self, x: ms.Tensor) -> ms.Tensor:
        r"""Encode a batch of images using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
        tiles overlap and are blended together to form a smooth output.

        Args:
            x (`ms.Tensor`): Input batch of images.

        Returns:
            `ms.Tensor`: Encoded batch of images.
        """
        # scale of encoder output relative to input
        sf = self.spatial_scale_factor
        tile_size = self.tile_sample_min_size

        # number of pixels to blend and to traverse between tiles
        blend_size = int(tile_size * self.tile_overlap_factor)
        traverse_size = tile_size - blend_size

        # tiles index (up/left)
        ti = range(0, x.shape[-2], traverse_size)
        tj = range(0, x.shape[-1], traverse_size)

        # mask for blending
        blend_masks = ops.stack(ops.meshgrid([ops.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij"))
        blend_masks = blend_masks.clamp(0, 1)

        # output array
        out = ops.zeros((x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf))
        for i in ti:
            for j in tj:
                tile_in = x[..., i : i + tile_size, j : j + tile_size]
                # current contents of `out` in the region covered by this tile
                tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf]
                tile = self.encoder(tile_in)
                h, w = tile.shape[-2], tile.shape[-1]
                # blend tile result into output
                blend_mask_i = ops.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
                blend_mask_j = ops.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
                blend_mask = blend_mask_i * blend_mask_j
                tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w]
                # write the blend back into `out`; rebinding `tile_out` alone would leave `out` unchanged
                out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf] = (
                    blend_mask * tile + (1 - blend_mask) * tile_out
                )
        return out

    def _tiled_decode(self, x: ms.Tensor) -> ms.Tensor:
        r"""Decode a batch of latents using a tiled decoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
        steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
        tiles overlap and are blended together to form a smooth output.

        Args:
            x (`ms.Tensor`): Input batch of latents.

        Returns:
            `ms.Tensor`: Decoded batch of images.
        """
        # scale of decoder output relative to input
        sf = self.spatial_scale_factor
        tile_size = self.tile_latent_min_size

        # number of pixels to blend and to traverse between tiles
        blend_size = int(tile_size * self.tile_overlap_factor)
        traverse_size = tile_size - blend_size

        # tiles index (up/left)
        ti = range(0, x.shape[-2], traverse_size)
        tj = range(0, x.shape[-1], traverse_size)

        # mask for blending
        blend_masks = ops.stack(ops.meshgrid([ops.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij"))
        blend_masks = blend_masks.clamp(0, 1)

        # output array
        out = ops.zeros((x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf))
        for i in ti:
            for j in tj:
                tile_in = x[..., i : i + tile_size, j : j + tile_size]
                # current contents of `out` in the region covered by this tile
                tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf]
                tile = self.decoder(tile_in)
                h, w = tile.shape[-2], tile.shape[-1]
                # blend tile result into output
                blend_mask_i = ops.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
                blend_mask_j = ops.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
                blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w]
                # write the blend back into `out`; rebinding `tile_out` alone would leave `out` unchanged
                out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf] = (
                    blend_mask * tile + (1 - blend_mask) * tile_out
                )
        return out

    def encode(self, x: ms.Tensor, return_dict: bool = False) -> Union[AutoencoderTinyOutput, Tuple[ms.Tensor]]:
        if self.use_slicing and x.shape[0] > 1:
            output = [
                self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)
            ]
            output = ops.cat(output)
        else:
            output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)

        if not return_dict:
            return (output,)

        return AutoencoderTinyOutput(latents=output)

    def decode(
        self, x: ms.Tensor, generator: Optional[np.random.Generator] = None, return_dict: bool = False
    ) -> Union[DecoderOutput, Tuple[ms.Tensor]]:
        if self.use_slicing and x.shape[0] > 1:
            output = [
                self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)
            ]
            output = ops.cat(output)
        else:
            output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)

        if not return_dict:
            return (output,)

        return DecoderOutput(sample=output)

    def construct(
        self,
        sample: ms.Tensor,
        return_dict: bool = False,
    ) -> Union[DecoderOutput, Tuple[ms.Tensor]]:
        r"""
        Args:
            sample (`ms.Tensor`): Input sample.
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        hidden_dtype = sample.dtype
        enc = self.encode(sample)[0]

        # scale latents to be in [0, 1], then quantize latents to a byte tensor,
        # as if we were storing the latents in an RGBA uint8 image.
        scaled_enc = ops.round(self.scale_latents(enc).mul(255)).to(ms.uint8)

        # unquantize latents back into [0, 1], then unscale latents back to their original range,
        # as if we were loading the latents from an RGBA uint8 image.
        unscaled_enc = self.unscale_latents(scaled_enc / 255.0).to(hidden_dtype)

        # Unlike diffusers, `decode` returns a tuple by default here, so take its first element.
        dec = self.decode(unscaled_enc)[0]

        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)

mindone.diffusers.AutoencoderTiny.construct(sample, return_dict=False)

PARAMETER DESCRIPTION
sample

Input sample.

TYPE: `ms.Tensor`

return_dict

Whether or not to return a [DecoderOutput] instead of a plain tuple.

TYPE: `bool`, *optional* DEFAULT: False

Source code in mindone/diffusers/models/autoencoders/autoencoder_tiny.py
def construct(
    self,
    sample: ms.Tensor,
    return_dict: bool = False,
) -> Union[DecoderOutput, Tuple[ms.Tensor]]:
    r"""
    Args:
        sample (`ms.Tensor`): Input sample.
        return_dict (`bool`, *optional*, defaults to `False`):
            Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
    """
    hidden_dtype = sample.dtype
    enc = self.encode(sample)[0]

    # scale latents to be in [0, 1], then quantize latents to a byte tensor,
    # as if we were storing the latents in an RGBA uint8 image.
    scaled_enc = ops.round(self.scale_latents(enc).mul(255)).to(ms.uint8)

    # unquantize latents back into [0, 1], then unscale latents back to their original range,
    # as if we were loading the latents from an RGBA uint8 image.
    unscaled_enc = self.unscale_latents(scaled_enc / 255.0).to(hidden_dtype)

    # Unlike diffusers, `decode` returns a tuple by default here, so take its first element.
    dec = self.decode(unscaled_enc)[0]

    if not return_dict:
        return (dec,)
    return DecoderOutput(sample=dec)
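`construct` simulates storing the latents as an 8-bit image: it scales them into [0, 1], rounds to `uint8`, and maps them back before decoding. The NumPy sketch below reproduces that round trip outside MindSpore with a hypothetical helper, `quantize_roundtrip`, that is not part of mindone. For latents inside `[-latent_magnitude, latent_magnitude]` the reconstruction error is at most half of one `uint8` step, i.e. `2 * latent_magnitude / 255 / 2`. NumPy's `np.round` rounds exact halves to even, which may differ from `ops.round` on ties, but the error bound is the same.

```python
import numpy as np


def quantize_roundtrip(z: np.ndarray, latent_magnitude: float = 3.0, latent_shift: float = 0.5) -> np.ndarray:
    """Hypothetical helper mirroring the quantization steps in AutoencoderTiny.construct."""
    # raw latents -> [0, 1] (same arithmetic as scale_latents)
    scaled = np.clip(z / (2 * latent_magnitude) + latent_shift, 0.0, 1.0)
    # store as bytes, as if writing the latents into a uint8 image
    as_bytes = np.round(scaled * 255).astype(np.uint8)
    # bytes -> [0, 1] -> raw latents (same arithmetic as unscale_latents)
    return (as_bytes / 255.0 - latent_shift) * (2 * latent_magnitude)


rng = np.random.default_rng(0)
z = rng.uniform(-3.0, 3.0, size=(1, 4, 8, 8))
z_hat = quantize_roundtrip(z)
max_err = np.abs(z - z_hat).max()
half_step = (2 * 3.0) / 255 / 2  # half of one uint8 step, in raw-latent units
print(f"max error {max_err:.4f}, bound {half_step:.4f}")
```

Latents outside the range are additionally clipped, so their error can be much larger than this bound.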

mindone.diffusers.AutoencoderTiny.disable_slicing()

Disable sliced VAE encoding and decoding. If enable_slicing was previously called, this method will go back to computing encoding and decoding in one step.

Source code in mindone/diffusers/models/autoencoders/autoencoder_tiny.py
def disable_slicing(self) -> None:
    r"""
    Disable sliced VAE encoding and decoding. If `enable_slicing` was previously called, this method will go back
    to computing encoding and decoding in one step.
    """
    self.use_slicing = False

mindone.diffusers.AutoencoderTiny.disable_tiling()

Disable tiled VAE encoding and decoding. If enable_tiling was previously called, this method will go back to computing encoding and decoding in one step.

Source code in mindone/diffusers/models/autoencoders/autoencoder_tiny.py
def disable_tiling(self) -> None:
    r"""
    Disable tiled VAE encoding and decoding. If `enable_tiling` was previously called, this method will go back to
    computing encoding and decoding in one step.
    """
    self.enable_tiling(False)

mindone.diffusers.AutoencoderTiny.enable_slicing()

Enable sliced VAE encoding and decoding. When this option is enabled, the VAE will split the input batch into slices to compute encoding and decoding in several steps. This saves some memory and allows larger batch sizes.

Source code in mindone/diffusers/models/autoencoders/autoencoder_tiny.py
def enable_slicing(self) -> None:
    r"""
    Enable sliced VAE encoding and decoding. When this option is enabled, the VAE will split the input batch into
    slices to compute encoding and decoding in several steps. This saves some memory and allows larger batch sizes.
    """
    self.use_slicing = True
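A minimal NumPy sketch of what `use_slicing` does, with NumPy standing in for MindSpore and a hypothetical `toy_decoder` standing in for the real decoder: the batch is split into size-1 slices, each slice is processed on its own, and the results are concatenated along the batch axis. Each slice, not the whole batch, must be passed to the per-sample function. For a model that treats samples independently the output is unchanged, while peak activation memory roughly scales with one sample instead of the whole batch.

```python
import numpy as np


def toy_decoder(x: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for self.decoder: per-sample 2x nearest-neighbour upsampling.
    return x.repeat(2, axis=-2).repeat(2, axis=-1)


def sliced_apply(fn, x: np.ndarray) -> np.ndarray:
    # Mirrors the use_slicing branch: split the batch into size-1 slices,
    # run fn on each slice, then concatenate the results along the batch axis.
    slices = np.split(x, x.shape[0], axis=0)
    return np.concatenate([fn(s) for s in slices], axis=0)


latents = np.random.default_rng(0).standard_normal((4, 4, 8, 8))
full = toy_decoder(latents)
sliced = sliced_apply(toy_decoder, latents)
print(full.shape, np.allclose(full, sliced))  # (4, 4, 16, 16) True
```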

mindone.diffusers.AutoencoderTiny.enable_tiling(use_tiling=True)

Enable tiled VAE encoding and decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute encoding and decoding in several steps. This saves a large amount of memory and allows processing larger images.

Source code in mindone/diffusers/models/autoencoders/autoencoder_tiny.py
def enable_tiling(self, use_tiling: bool = True) -> None:
    r"""
    Enable tiled VAE encoding and decoding. When this option is enabled, the VAE will split the input tensor into
    tiles to compute encoding and decoding in several steps. This saves a large amount of memory and allows
    processing larger images.
    """
    self.use_tiling = use_tiling
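The overlap-and-blend scheme used for tiled decoding can be illustrated with a NumPy sketch modelled on `_tiled_decode`; it is not the mindone implementation. Tiles of `tile_size` latent pixels start every `tile_size - blend_size` pixels, and every tile except those in the first row or column fades in over its leading `blend_size * sf` output pixels through a linear ramp, so it blends with what the previous tiles already wrote. The hypothetical `toy_decoder` (nearest-neighbour upsampling) is chosen so that the tiled result can be checked against decoding the whole latent at once; a real decoder's receptive field crosses tile borders, which is why the overlap is blended rather than simply cropped. The defaults (`sf=8`, `tile_size=64`, `overlap=0.125`) match the values set in the constructor.

```python
import numpy as np


def toy_decoder(z: np.ndarray, sf: int) -> np.ndarray:
    # Hypothetical stand-in for self.decoder: nearest-neighbour upsampling by `sf`.
    return z.repeat(sf, axis=-2).repeat(sf, axis=-1)


def tiled_decode(z: np.ndarray, sf: int = 8, tile_size: int = 64, overlap: float = 0.125) -> np.ndarray:
    """Sketch of overlap-and-blend tiled decoding (modelled on AutoencoderTiny._tiled_decode)."""
    blend_size = int(tile_size * overlap)  # latent pixels shared by neighbouring tiles
    traverse_size = tile_size - blend_size  # stride between tile origins
    # Linear ramp from 0 to 1 over the first blend_size * sf output pixels, then constant 1.
    ramp = np.clip(np.arange(tile_size * sf) / (blend_size * sf - 1), 0, 1)

    out = np.zeros(z.shape[:-2] + (z.shape[-2] * sf, z.shape[-1] * sf))
    for i in range(0, z.shape[-2], traverse_size):
        for j in range(0, z.shape[-1], traverse_size):
            tile = toy_decoder(z[..., i : i + tile_size, j : j + tile_size], sf)
            h, w = tile.shape[-2:]
            # Tiles in the first row/column have no top/left neighbour, so they are not faded in.
            mask_i = np.ones(h) if i == 0 else ramp[:h]
            mask_j = np.ones(w) if j == 0 else ramp[:w]
            mask = mask_i[:, None] * mask_j[None, :]
            region = out[..., i * sf : i * sf + h, j * sf : j * sf + w]
            # Write the blend back into `out` with slice assignment.
            out[..., i * sf : i * sf + h, j * sf : j * sf + w] = mask * tile + (1 - mask) * region
    return out


z = np.random.default_rng(0).standard_normal((1, 4, 100, 120))
print(np.allclose(tiled_decode(z), toy_decoder(z, 8)))  # True for this toy decoder
```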

mindone.diffusers.AutoencoderTiny.scale_latents(x)

raw latents -> [0, 1]

Source code in mindone/diffusers/models/autoencoders/autoencoder_tiny.py
def scale_latents(self, x: ms.Tensor) -> ms.Tensor:
    """raw latents -> [0, 1]"""
    return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)

mindone.diffusers.AutoencoderTiny.unscale_latents(x)

[0, 1] -> raw latents

Source code in mindone/diffusers/models/autoencoders/autoencoder_tiny.py
def unscale_latents(self, x: ms.Tensor) -> ms.Tensor:
    """[0, 1] -> raw latents"""
    return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
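For intuition, the NumPy sketch below reimplements the arithmetic of `scale_latents` and `unscale_latents` on plain arrays; it is an illustration, with NumPy standing in for MindSpore, not the mindone code path. With the defaults `latent_magnitude=3` and `latent_shift=0.5`, raw values in [-3, 3] map linearly onto [0, 1]. Values outside that range are clipped, so the round trip is lossy only for out-of-range values.

```python
import numpy as np


def scale_latents(x: np.ndarray, latent_magnitude: float = 3.0, latent_shift: float = 0.5) -> np.ndarray:
    """raw latents -> [0, 1]; NumPy mirror of AutoencoderTiny.scale_latents."""
    return np.clip(x / (2 * latent_magnitude) + latent_shift, 0.0, 1.0)


def unscale_latents(x: np.ndarray, latent_magnitude: float = 3.0, latent_shift: float = 0.5) -> np.ndarray:
    """[0, 1] -> raw latents; NumPy mirror of AutoencoderTiny.unscale_latents."""
    return (x - latent_shift) * (2 * latent_magnitude)


raw = np.array([-4.0, -3.0, 0.0, 1.5, 3.0])
scaled = scale_latents(raw)        # values 0.0, 0.0, 0.5, 0.75, 1.0 (-4.0 is clipped)
restored = unscale_latents(scaled)  # values -3.0, -3.0, 0.0, 1.5, 3.0
```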

mindone.diffusers.models.autoencoders.autoencoder_tiny.AutoencoderTinyOutput dataclass

Bases: BaseOutput

Output of the AutoencoderTiny encoding method.

PARAMETER DESCRIPTION
latents

Encoded outputs of the Encoder.

TYPE: `ms.Tensor`

Source code in mindone/diffusers/models/autoencoders/autoencoder_tiny.py
@dataclass
class AutoencoderTinyOutput(BaseOutput):
    """
    Output of the `AutoencoderTiny` encoding method.

    Args:
        latents (`ms.Tensor`): Encoded outputs of the `Encoder`.

    """

    latents: ms.Tensor