Latent Diffusion

Latent Diffusion was proposed in High-Resolution Image Synthesis with Latent Diffusion Models by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer.

The abstract from the paper is:

By decomposing the image formation process into a sequential application of denoising autoencoders, diffusion models (DMs) achieve state-of-the-art synthesis results on image data and beyond. Additionally, their formulation allows for a guiding mechanism to control the image generation process without retraining. However, since these models typically operate directly in pixel space, optimization of powerful DMs often consumes hundreds of GPU days and inference is expensive due to sequential evaluations. To enable DM training on limited computational resources while retaining their quality and flexibility, we apply them in the latent space of powerful pretrained autoencoders. In contrast to previous work, training diffusion models on such a representation allows for the first time to reach a near-optimal point between complexity reduction and detail preservation, greatly boosting visual fidelity. By introducing cross-attention layers into the model architecture, we turn diffusion models into powerful and flexible generators for general conditioning inputs such as text or bounding boxes and high-resolution synthesis becomes possible in a convolutional manner. Our latent diffusion models (LDMs) achieve a new state of the art for image inpainting and highly competitive performance on various tasks, including unconditional image generation, semantic scene synthesis, and super-resolution, while significantly reducing computational requirements compared to pixel-based DMs.

The original codebase can be found at CompVis/latent-diffusion.

Tip

Make sure to check out the Schedulers guide to learn how to explore the tradeoff between scheduler speed and quality, and see the reuse components across pipelines section to learn how to efficiently load the same components into multiple pipelines.
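
For example, a minimal sketch of swapping in a different scheduler without reloading the model weights (assuming `DDIMScheduler` is exported by `mindone.diffusers` and follows the upstream `from_config` API):

```py
>>> from mindone.diffusers import DiffusionPipeline, DDIMScheduler

>>> ldm = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
>>> # replace the scheduler in place; the model weights stay loaded
>>> ldm.scheduler = DDIMScheduler.from_config(ldm.scheduler.config)
```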

mindone.diffusers.LDMTextToImagePipeline

Bases: DiffusionPipeline

Pipeline for text-to-image generation using latent diffusion.

This model inherits from [DiffusionPipeline]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).

PARAMETER DESCRIPTION
vqvae

Vector-quantized (VQ) model to encode and decode images to and from latent representations.

TYPE: [`VQModel`]

bert

Text-encoder model based on [~transformers.BERT].

TYPE: [`LDMBertModel`]

tokenizer

A BertTokenizer to tokenize text.

TYPE: [`~transformers.BertTokenizer`]

unet

A UNet2DConditionModel to denoise the encoded image latents.

TYPE: [`UNet2DConditionModel`]

scheduler

A scheduler to be used in combination with unet to denoise the encoded image latents. Can be one of [DDIMScheduler], [LMSDiscreteScheduler], or [PNDMScheduler].

TYPE: [`SchedulerMixin`]
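
A hedged illustration of how these parameters surface on a loaded pipeline (assuming the `components` property behaves as in the upstream `diffusers` `DiffusionPipeline`):

```py
>>> from mindone.diffusers import LDMTextToImagePipeline

>>> pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
>>> # the registered modules mirror the parameters documented above
>>> sorted(pipe.components)
['bert', 'scheduler', 'tokenizer', 'unet', 'vqvae']
```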

Source code in mindone/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
class LDMTextToImagePipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using latent diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        vqvae ([`VQModel`]):
            Vector-quantized (VQ) model to encode and decode images to and from latent representations.
        bert ([`LDMBertModel`]):
            Text-encoder model based on [`~transformers.BERT`].
        tokenizer ([`~transformers.BertTokenizer`]):
            A `BertTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
    """

    model_cpu_offload_seq = "bert->unet->vqvae"

    def __init__(
        self,
        vqvae: Union[VQModel, AutoencoderKL],
        bert: MSPreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        unet: Union[UNet2DModel, UNet2DConditionModel],
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
    ):
        super().__init__()
        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
        self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)

    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 1.0,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None,
        latents: Optional[ms.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = False,
        **kwargs,
    ) -> Union[Tuple, ImagePipelineOutput]:
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 1.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`np.random.Generator`, *optional*):
                A [`np.random.Generator`](https://numpy.org/doc/stable/reference/random/generator.html) to make
                generation deterministic.
            latents (`ms.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` and `np.array`.
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return an [`ImagePipelineOutput`] instead of a plain tuple.

        Example:

        ```py
        >>> from mindone.diffusers import DiffusionPipeline

        >>> # load model and scheduler
        >>> ldm = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")

        >>> # run pipeline in inference (sample random noise and denoise)
        >>> prompt = "A painting of a squirrel eating a burger"
        >>> images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6)[0]

        >>> # save images
        >>> for idx, image in enumerate(images):
        ...     image.save(f"squirrel-{idx}.png")
        ```

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get unconditional embeddings for classifier free guidance
        if guidance_scale != 1.0:
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="np"
            )
            uncond_input_ids = ms.Tensor(uncond_input.input_ids)
            negative_prompt_embeds = self.bert(uncond_input_ids)[0]

        # get prompt text embeddings
        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="np")
        text_input_ids = ms.Tensor(text_input.input_ids)
        prompt_embeds = self.bert(text_input_ids)[0]

        # get the initial random noise unless the user supplied it
        latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(latents_shape, generator=generator, dtype=prompt_embeds.dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        self.scheduler.set_timesteps(num_inference_steps)

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())

        extra_kwargs = {}
        if accepts_eta:
            extra_kwargs["eta"] = eta

        for t in self.progress_bar(self.scheduler.timesteps):
            if guidance_scale == 1.0:
                # guidance_scale of 1 means no guidance
                latents_input = latents
                context = prompt_embeds
            else:
                # For classifier free guidance, we need to do two forward passes.
                # Here we concatenate the unconditional and text embeddings into a single batch
                # to avoid doing two forward passes
                latents_input = ops.cat([latents] * 2)
                context = ops.cat([negative_prompt_embeds, prompt_embeds])

            # predict the noise residual
            noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)[0]
            # perform guidance
            if guidance_scale != 1.0:
                noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)[0]

        # scale and decode the image latents with vae
        latents = 1 / self.vqvae.config.scaling_factor * latents
        image = self.vqvae.decode(latents)[0]

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.permute(0, 2, 3, 1).asnumpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)

mindone.diffusers.LDMTextToImagePipeline.__call__(prompt, height=None, width=None, num_inference_steps=50, guidance_scale=1.0, eta=0.0, generator=None, latents=None, output_type='pil', return_dict=False, **kwargs)

The call function to the pipeline for generation.

PARAMETER DESCRIPTION
prompt

The prompt or prompts to guide the image generation.

TYPE: `str` or `List[str]`

height

The height in pixels of the generated image.

TYPE: `int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor` DEFAULT: None

width

The width in pixels of the generated image.

TYPE: `int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor` DEFAULT: None

num_inference_steps

The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.

TYPE: `int`, *optional*, defaults to 50 DEFAULT: 50

guidance_scale

A higher guidance scale value encourages the model to generate images closely linked to the text prompt at the expense of lower image quality. Guidance scale is enabled when guidance_scale > 1.

TYPE: `float`, *optional*, defaults to 1.0 DEFAULT: 1.0

eta

Corresponds to parameter eta (η) from the DDIM paper. Only applies to the [~schedulers.DDIMScheduler], and is ignored in other schedulers.

TYPE: `float`, *optional*, defaults to 0.0 DEFAULT: 0.0

generator

A np.random.Generator to make generation deterministic.

TYPE: `np.random.Generator`, *optional* DEFAULT: None

latents

Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random generator.

TYPE: `ms.Tensor`, *optional* DEFAULT: None

output_type

The output format of the generated image. Choose between PIL.Image and np.array.

TYPE: `str`, *optional*, defaults to `"pil"` DEFAULT: 'pil'

return_dict

Whether or not to return an [ImagePipelineOutput] instead of a plain tuple.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

Example:

>>> from mindone.diffusers import DiffusionPipeline

>>> # load model and scheduler
>>> ldm = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")

>>> # run pipeline in inference (sample random noise and denoise)
>>> prompt = "A painting of a squirrel eating a burger"
>>> images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6)[0]

>>> # save images
>>> for idx, image in enumerate(images):
...     image.save(f"squirrel-{idx}.png")
RETURNS DESCRIPTION
Union[Tuple, ImagePipelineOutput]

[~pipelines.ImagePipelineOutput] or tuple: If return_dict is True, [~pipelines.ImagePipelineOutput] is returned, otherwise a tuple is returned where the first element is a list with the generated images.
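
A short sketch of both return modes; note that in this port `return_dict` defaults to `False`, so the plain tuple is returned unless you opt in:

```py
>>> # tuple form (default): the first element is the list of generated images
>>> images = ldm([prompt], num_inference_steps=50, guidance_scale=6)[0]

>>> # dataclass form: the same list is exposed as `.images`
>>> output = ldm([prompt], num_inference_steps=50, guidance_scale=6, return_dict=True)
>>> images = output.images
```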

Source code in mindone/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
def __call__(
    self,
    prompt: Union[str, List[str]],
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: Optional[int] = 50,
    guidance_scale: Optional[float] = 1.0,
    eta: Optional[float] = 0.0,
    generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None,
    latents: Optional[ms.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = False,
    **kwargs,
) -> Union[Tuple, ImagePipelineOutput]:
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 1.0):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`np.random.Generator`, *optional*):
            A [`np.random.Generator`](https://numpy.org/doc/stable/reference/random/generator.html) to make
            generation deterministic.
        latents (`ms.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` and `np.array`.
        return_dict (`bool`, *optional*, defaults to `False`):
            Whether or not to return an [`ImagePipelineOutput`] instead of a plain tuple.

    Example:

    ```py
    >>> from mindone.diffusers import DiffusionPipeline

    >>> # load model and scheduler
    >>> ldm = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")

    >>> # run pipeline in inference (sample random noise and denoise)
    >>> prompt = "A painting of a squirrel eating a burger"
    >>> images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6)[0]

    >>> # save images
    >>> for idx, image in enumerate(images):
    ...     image.save(f"squirrel-{idx}.png")
    ```

    Returns:
        [`~pipelines.ImagePipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
            returned where the first element is a list with the generated images.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # get unconditional embeddings for classifier free guidance
    if guidance_scale != 1.0:
        uncond_input = self.tokenizer(
            [""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="np"
        )
        uncond_input_ids = ms.Tensor(uncond_input.input_ids)
        negative_prompt_embeds = self.bert(uncond_input_ids)[0]

    # get prompt text embeddings
    text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="np")
    text_input_ids = ms.Tensor(text_input.input_ids)
    prompt_embeds = self.bert(text_input_ids)[0]

    # get the initial random noise unless the user supplied it
    latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(latents_shape, generator=generator, dtype=prompt_embeds.dtype)
    else:
        if latents.shape != latents_shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

    self.scheduler.set_timesteps(num_inference_steps)

    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())

    extra_kwargs = {}
    if accepts_eta:
        extra_kwargs["eta"] = eta

    for t in self.progress_bar(self.scheduler.timesteps):
        if guidance_scale == 1.0:
            # guidance_scale of 1 means no guidance
            latents_input = latents
            context = prompt_embeds
        else:
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            latents_input = ops.cat([latents] * 2)
            context = ops.cat([negative_prompt_embeds, prompt_embeds])

        # predict the noise residual
        noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)[0]
        # perform guidance
        if guidance_scale != 1.0:
            noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)[0]

    # scale and decode the image latents with vae
    latents = 1 / self.vqvae.config.scaling_factor * latents
    image = self.vqvae.decode(latents)[0]

    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.permute(0, 2, 3, 1).asnumpy()
    if output_type == "pil":
        image = self.numpy_to_pil(image)

    if not return_dict:
        return (image,)

    return ImagePipelineOutput(images=image)

mindone.diffusers.LDMSuperResolutionPipeline

Bases: DiffusionPipeline

A pipeline for image super-resolution using latent diffusion.

This model inherits from [DiffusionPipeline]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).

PARAMETER DESCRIPTION
vqvae

Vector-quantized (VQ) model to encode and decode images to and from latent representations.

TYPE: [`VQModel`]

unet

A UNet2DModel to denoise the encoded image.

TYPE: [`UNet2DModel`]

scheduler

A scheduler to be used in combination with unet to denoise the encoded image latents. Can be one of [DDIMScheduler], [LMSDiscreteScheduler], [EulerDiscreteScheduler], [EulerAncestralDiscreteScheduler], [DPMSolverMultistepScheduler], or [PNDMScheduler].

TYPE: [`SchedulerMixin`]

Source code in mindone/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
class LDMSuperResolutionPipeline(DiffusionPipeline):
    r"""
    A pipeline for image super-resolution using latent diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        vqvae ([`VQModel`]):
            Vector-quantized (VQ) model to encode and decode images to and from latent representations.
        unet ([`UNet2DModel`]):
            A `UNet2DModel` to denoise the encoded image.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`],
            [`EulerAncestralDiscreteScheduler`], [`DPMSolverMultistepScheduler`], or [`PNDMScheduler`].
    """

    def __init__(
        self,
        vqvae: VQModel,
        unet: UNet2DModel,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
    ):
        super().__init__()
        self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)

    def __call__(
        self,
        image: Union[ms.Tensor, PIL.Image.Image] = None,
        batch_size: Optional[int] = 1,
        num_inference_steps: Optional[int] = 100,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = False,
    ) -> Union[Tuple, ImagePipelineOutput]:
        r"""
        The call function to the pipeline for generation.

        Args:
            image (`ms.Tensor` or `PIL.Image.Image`):
                `Image` or tensor representing an image batch to be used as the starting point for the process.
            batch_size (`int`, *optional*, defaults to 1):
                Number of images to generate.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*):
                A [`np.random.Generator`](https://numpy.org/doc/stable/reference/random/generator.html) to make
                generation deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` and `np.array`.
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return an [`ImagePipelineOutput`] instead of a plain tuple.

        Example:

        ```py
        >>> import requests
        >>> from PIL import Image
        >>> from io import BytesIO
        >>> from mindone.diffusers import LDMSuperResolutionPipeline
        >>> import mindspore as ms

        >>> # load model and scheduler
        >>> pipeline = LDMSuperResolutionPipeline.from_pretrained("CompVis/ldm-super-resolution-4x-openimages")

        >>> # let's download an image
        >>> url = (
        ...     "https://user-images.githubusercontent.com/38061659/199705896-b48e17b8-b231-47cd-a270-4ffa5a93fa3e.png"
        ... )
        >>> response = requests.get(url)
        >>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
        >>> low_res_img = low_res_img.resize((128, 128))

        >>> # run pipeline in inference (sample random noise and denoise)
        >>> upscaled_image = pipeline(low_res_img, num_inference_steps=100, eta=1)[0][0]
        >>> # save image
        >>> upscaled_image.save("ldm_generated_image.png")
        ```

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images.
        """
        if isinstance(image, PIL.Image.Image):
            batch_size = 1
        elif isinstance(image, ms.Tensor):
            batch_size = image.shape[0]
        else:
            raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `ms.Tensor` but is {type(image)}")

        if isinstance(image, PIL.Image.Image):
            image = preprocess(image)

        height, width = image.shape[-2:]

        # in_channels should be 6: 3 for latents, 3 for low resolution image
        latents_shape = (batch_size, self.unet.config.in_channels // 2, height, width)
        latents_dtype = next(self.unet.get_parameters()).dtype

        latents = randn_tensor(latents_shape, generator=generator, dtype=latents_dtype)

        image = image.to(dtype=latents_dtype)

        # set timesteps and move to the correct device
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps_tensor = self.scheduler.timesteps

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_kwargs = {}
        if accepts_eta:
            extra_kwargs["eta"] = eta

        for t in self.progress_bar(timesteps_tensor):
            # concat latents and low resolution image in the channel dimension.
            latents_input = ops.cat([latents, image], axis=1)
            latents_input = self.scheduler.scale_model_input(latents_input, t)
            # predict the noise residual
            noise_pred = self.unet(latents_input, t)[0]
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)[0]

        # decode the image latents with the VQVAE
        image = self.vqvae.decode(latents)[0]
        image = ops.clamp(image, -1.0, 1.0)
        image = image / 2 + 0.5
        image = image.permute(0, 2, 3, 1).asnumpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)

mindone.diffusers.LDMSuperResolutionPipeline.__call__(image=None, batch_size=1, num_inference_steps=100, eta=0.0, generator=None, output_type='pil', return_dict=False)

The call function to the pipeline for generation.

PARAMETER DESCRIPTION
image

Image or tensor representing an image batch to be used as the starting point for the process.

TYPE: `ms.Tensor` or `PIL.Image.Image` DEFAULT: None

batch_size

Number of images to generate.

TYPE: `int`, *optional*, defaults to 1 DEFAULT: 1

num_inference_steps

The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.

TYPE: `int`, *optional*, defaults to 100 DEFAULT: 100

eta

Corresponds to parameter eta (η) from the DDIM paper. Only applies to the [~schedulers.DDIMScheduler], and is ignored in other schedulers.

TYPE: `float`, *optional*, defaults to 0.0 DEFAULT: 0.0

generator

A np.random.Generator to make generation deterministic.

TYPE: `np.random.Generator` or `List[np.random.Generator]`, *optional* DEFAULT: None

output_type

The output format of the generated image. Choose between PIL.Image and np.array.

TYPE: `str`, *optional*, defaults to `"pil"` DEFAULT: 'pil'

return_dict

Whether or not to return an [ImagePipelineOutput] instead of a plain tuple.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

Example:

>>> import requests
>>> from PIL import Image
>>> from io import BytesIO
>>> from mindone.diffusers import LDMSuperResolutionPipeline
>>> import mindspore as ms

>>> # load model and scheduler
>>> pipeline = LDMSuperResolutionPipeline.from_pretrained("CompVis/ldm-super-resolution-4x-openimages")

>>> # let's download an image
>>> url = (
...     "https://user-images.githubusercontent.com/38061659/199705896-b48e17b8-b231-47cd-a270-4ffa5a93fa3e.png"
... )
>>> response = requests.get(url)
>>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
>>> low_res_img = low_res_img.resize((128, 128))

>>> # run pipeline in inference (sample random noise and denoise)
>>> upscaled_image = pipeline(low_res_img, num_inference_steps=100, eta=1)[0][0]
>>> # save image
>>> upscaled_image.save("ldm_generated_image.png")
RETURNS DESCRIPTION
Union[Tuple, ImagePipelineOutput]

[~pipelines.ImagePipelineOutput] or tuple: If return_dict is True, [~pipelines.ImagePipelineOutput] is returned, otherwise a tuple is returned where the first element is a list with the generated images.
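
For a tensor input, `image` should be a `(batch, 3, height, width)` batch scaled to `[-1, 1]`, matching what `preprocess` produces for PIL inputs (the decode step maps outputs back via `image / 2 + 0.5`). A minimal, hedged sketch:

```py
>>> import numpy as np
>>> import mindspore as ms

>>> # scale an (H, W, 3) uint8 image to [-1, 1] and reorder to (1, 3, H, W)
>>> arr = np.asarray(low_res_img, dtype=np.float32) / 127.5 - 1.0
>>> batch = ms.Tensor(arr.transpose(2, 0, 1)[None])
>>> upscaled = pipeline(batch, num_inference_steps=100)[0][0]
```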

Source code in mindone/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
def __call__(
    self,
    image: Union[ms.Tensor, PIL.Image.Image] = None,
    batch_size: Optional[int] = 1,
    num_inference_steps: Optional[int] = 100,
    eta: Optional[float] = 0.0,
    generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = False,
) -> Union[Tuple, ImagePipelineOutput]:
    r"""
    The call function to the pipeline for generation.

    Args:
        image (`ms.Tensor` or `PIL.Image.Image`):
            `Image` or tensor representing an image batch to be used as the starting point for the process.
        batch_size (`int`, *optional*, defaults to 1):
            Number of images to generate.
        num_inference_steps (`int`, *optional*, defaults to 100):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*):
            A [`np.random.Generator`](https://numpy.org/doc/stable/reference/random/generator.html) to make
            generation deterministic.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` and `np.array`.
        return_dict (`bool`, *optional*, defaults to `False`):
            Whether or not to return an [`ImagePipelineOutput`] instead of a plain tuple.

    Example:

    ```py
    >>> import requests
    >>> from PIL import Image
    >>> from io import BytesIO
    >>> from mindone.diffusers import LDMSuperResolutionPipeline
    >>> import mindspore as ms

    >>> # load model and scheduler
    >>> pipeline = LDMSuperResolutionPipeline.from_pretrained("CompVis/ldm-super-resolution-4x-openimages")

    >>> # let's download an image
    >>> url = (
    ...     "https://user-images.githubusercontent.com/38061659/199705896-b48e17b8-b231-47cd-a270-4ffa5a93fa3e.png"
    ... )
    >>> response = requests.get(url)
    >>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
    >>> low_res_img = low_res_img.resize((128, 128))

    >>> # run pipeline in inference (sample random noise and denoise)
    >>> upscaled_image = pipeline(low_res_img, num_inference_steps=100, eta=1)[0][0]
    >>> # save image
    >>> upscaled_image.save("ldm_generated_image.png")
    ```

    Returns:
        [`~pipelines.ImagePipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
            returned where the first element is a list with the generated images.
    """
    if isinstance(image, PIL.Image.Image):
        batch_size = 1
    elif isinstance(image, ms.Tensor):
        batch_size = image.shape[0]
    else:
        raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `ms.Tensor` but is {type(image)}")

    if isinstance(image, PIL.Image.Image):
        image = preprocess(image)

    height, width = image.shape[-2:]

    # in_channels should be 6: 3 for latents, 3 for low resolution image
    latents_shape = (batch_size, self.unet.config.in_channels // 2, height, width)
    latents_dtype = next(self.unet.get_parameters()).dtype

    latents = randn_tensor(latents_shape, generator=generator, dtype=latents_dtype)

    image = image.to(dtype=latents_dtype)

    # set timesteps and move to the correct device
    self.scheduler.set_timesteps(num_inference_steps)
    timesteps_tensor = self.scheduler.timesteps

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma

    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
    # and should be between [0, 1]
    accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
    extra_kwargs = {}
    if accepts_eta:
        extra_kwargs["eta"] = eta

    for t in self.progress_bar(timesteps_tensor):
        # concat latents and low resolution image in the channel dimension.
        latents_input = ops.cat([latents, image], axis=1)
        latents_input = self.scheduler.scale_model_input(latents_input, t)
        # predict the noise residual
        noise_pred = self.unet(latents_input, t)[0]
        # compute the previous noisy sample x_t -> x_t-1
        latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)[0]

    # decode the image latents with the VQVAE
    image = self.vqvae.decode(latents)[0]
    image = ops.clamp(image, -1.0, 1.0)
    image = image / 2 + 0.5
    image = image.permute(0, 2, 3, 1).asnumpy()

    if output_type == "pil":
        image = self.numpy_to_pil(image)

    if not return_dict:
        return (image,)

    return ImagePipelineOutput(images=image)

mindone.diffusers.pipelines.ImagePipelineOutput dataclass

Bases: BaseOutput

Output class for image pipelines.

Source code in mindone/diffusers/pipelines/pipeline_utils.py
@dataclass
class ImagePipelineOutput(BaseOutput):
    """
    Output class for image pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`):
            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
            num_channels)`.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
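
A brief usage sketch (assuming `BaseOutput` supports dict-style access, as in the upstream `diffusers`):

```py
>>> output = pipeline(low_res_img, return_dict=True)
>>> output.images     # List[PIL.Image.Image] for output_type="pil", otherwise np.ndarray
>>> output["images"]  # BaseOutput also allows dict-style access
```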