
DDIM

Denoising Diffusion Implicit Models (DDIM) by Jiaming Song, Chenlin Meng and Stefano Ermon.

The abstract from the paper is:

Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, yet they require simulating a Markov chain for many steps to produce a sample. To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process. We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from. We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.

The original codebase can be found at ermongroup/ddim.
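For orientation, the sampling update that the abstract alludes to can be written as below. This is a restatement of the update rule from the DDIM paper in the paper's own notation, where \alpha_t is the cumulative noise-schedule product and \epsilon_\theta is the trained noise predictor. It is not taken from the source code on this page.

```latex
x_{t-1} = \sqrt{\alpha_{t-1}}\,
          \underbrace{\frac{x_t - \sqrt{1-\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\alpha_t}}}_{\text{predicted } x_0}
        + \sqrt{1 - \alpha_{t-1} - \sigma_t^2}\;\epsilon_\theta(x_t, t)
        + \sigma_t\,\epsilon_t,
\qquad \epsilon_t \sim \mathcal{N}(0, I)

\sigma_t(\eta) = \eta\,\sqrt{\frac{1-\alpha_{t-1}}{1-\alpha_t}}\,
                 \sqrt{1 - \frac{\alpha_t}{\alpha_{t-1}}}
```

With η = 0 the noise term vanishes and sampling is deterministic given the initial noise. With η = 1 the variance equals that of the DDPM posterior.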

mindone.diffusers.DDIMPipeline

Bases: DiffusionPipeline

Pipeline for image generation.

This pipeline inherits from DiffusionPipeline. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).

PARAMETER DESCRIPTION
unet

A UNet2DModel to denoise the encoded image latents.

TYPE: `UNet2DModel`

scheduler

A scheduler to be used in combination with unet to denoise the encoded image. Can be either a DDPMScheduler or a DDIMScheduler. The constructor always rebuilds it as a DDIMScheduler from the passed scheduler's config.

TYPE: `SchedulerMixin`

Source code in mindone/diffusers/pipelines/ddim/pipeline_ddim.py
class DDIMPipeline(DiffusionPipeline):
    r"""
    Pipeline for image generation.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        unet ([`UNet2DModel`]):
            A `UNet2DModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
            [`DDPMScheduler`], or [`DDIMScheduler`].
    """

    model_cpu_offload_seq = "unet"

    def __init__(self, unet, scheduler):
        super().__init__()

        # make sure scheduler can always be converted to DDIM
        scheduler = DDIMScheduler.from_config(scheduler.config)

        self.register_modules(unet=unet, scheduler=scheduler)

    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None,
        eta: float = 0.0,
        num_inference_steps: int = 50,
        use_clipped_model_output: Optional[bool] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = False,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            generator (`np.random.Generator`, *optional*):
                A [`np.random.Generator`](https://numpy.org/doc/stable/reference/random/generator.html) to make
                generation deterministic.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. A value of `0` corresponds to
                DDIM and `1` corresponds to DDPM.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            use_clipped_model_output (`bool`, *optional*, defaults to `None`):
                If `True` or `False`, see documentation for [`DDIMScheduler.step`]. If `None`, nothing is passed
                downstream to the scheduler (use `None` for schedulers which don't support this argument).
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Example:

        ```py
        >>> from mindone.diffusers import DDIMPipeline

        >>> # load model and scheduler
        >>> pipe = DDIMPipeline.from_pretrained("fusing/ddim-lsun-bedroom")

        >>> # run pipeline in inference (sample random noise and denoise);
        >>> # with the defaults the call returns a tuple whose first element is a list of PIL images
        >>> image = pipe(eta=0.0, num_inference_steps=50)[0][0]

        >>> # save image
        >>> image.save("test.png")
        ```

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images
        """

        # Sample gaussian noise to begin loop
        if isinstance(self.unet.config.sample_size, int):
            image_shape = (
                batch_size,
                self.unet.config.in_channels,
                self.unet.config.sample_size,
                self.unet.config.sample_size,
            )
        else:
            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        image = randn_tensor(image_shape, generator=generator, dtype=self.unet.dtype)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # 1. predict noise model_output
            model_output = self.unet(image, t)[0]

            # 2. predict previous mean of image x_t-1 and add variance depending on eta
            # eta corresponds to η in paper and should be between [0, 1]
            # do x_t -> x_t-1
            image = self.scheduler.step(
                model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator
            )[0]

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
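The control flow of `__call__` above can be followed end to end in a small NumPy sketch. It is an illustration under stated assumptions, not the library's implementation: the noise predictor is a hypothetical stand-in for the UNet, the linear beta schedule and the evenly spaced timesteps are assumed defaults, and the function name `toy_ddim_sample` does not exist in mindone.

```python
import numpy as np


def toy_ddim_sample(batch_size=2, in_channels=3, sample_size=8,
                    num_inference_steps=10, num_train_timesteps=1000,
                    eta=0.0, seed=0):
    """Toy DDIM loop mirroring the pipeline's steps (hypothetical helper)."""
    rng = np.random.default_rng(seed)

    # Assumed linear beta schedule; real scheduler configs may differ.
    betas = np.linspace(1e-4, 0.02, num_train_timesteps)
    alphas_cumprod = np.cumprod(1.0 - betas)

    # sample_size may be an int or a (height, width) pair, as in the pipeline.
    if isinstance(sample_size, int):
        shape = (batch_size, in_channels, sample_size, sample_size)
    else:
        shape = (batch_size, in_channels, *sample_size)

    # Start from Gaussian noise.
    image = rng.standard_normal(shape)

    # Evenly spaced subsequence of training timesteps, visited high to low.
    step_ratio = num_train_timesteps // num_inference_steps
    timesteps = (np.arange(num_inference_steps) * step_ratio)[::-1]

    for t in timesteps:
        # Hypothetical stand-in for self.unet(image, t): a scaled copy of the input.
        eps = 0.1 * image

        a_t = alphas_cumprod[t]
        prev_t = t - step_ratio
        a_prev = alphas_cumprod[prev_t] if prev_t >= 0 else 1.0  # assumed final alpha of 1

        # Predicted clean sample, then the DDIM update x_t -> x_{t-1}.
        x0 = (image - np.sqrt(1.0 - a_t) * eps) / np.sqrt(a_t)
        sigma = eta * np.sqrt((1.0 - a_prev) / (1.0 - a_t)) * np.sqrt(1.0 - a_t / a_prev)
        direction = np.sqrt(1.0 - a_prev - sigma**2) * eps
        image = np.sqrt(a_prev) * x0 + direction + sigma * rng.standard_normal(shape)

    # Same post-processing as the pipeline: map [-1, 1] to [0, 1], then NCHW -> NHWC.
    image = np.clip(image / 2 + 0.5, 0.0, 1.0)
    return image.transpose(0, 2, 3, 1)


samples = toy_ddim_sample()
print(samples.shape)
```

The returned array has the layout `(batch_size, height, width, num_channels)`, which is the layout the pipeline hands to `numpy_to_pil` when `output_type` is `"pil"`.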

mindone.diffusers.DDIMPipeline.__call__(batch_size=1, generator=None, eta=0.0, num_inference_steps=50, use_clipped_model_output=None, output_type='pil', return_dict=False)

The call function to the pipeline for generation.

PARAMETER DESCRIPTION
batch_size

The number of images to generate.

TYPE: `int`, *optional* DEFAULT: 1

generator

A np.random.Generator, or a list with one generator per image in the batch, to make generation deterministic.

TYPE: `np.random.Generator` or `List[np.random.Generator]`, *optional* DEFAULT: None

eta

Corresponds to parameter eta (η) from the DDIM paper and should lie in [0, 1]. Only applies to the DDIMScheduler and is ignored by other schedulers. A value of 0 corresponds to DDIM and 1 corresponds to DDPM.

TYPE: `float`, *optional* DEFAULT: 0.0

num_inference_steps

The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.

TYPE: `int`, *optional* DEFAULT: 50

use_clipped_model_output

If True or False, see the documentation for DDIMScheduler.step. If None, nothing is passed downstream to the scheduler (use None for schedulers which don't support this argument).

TYPE: `bool`, *optional* DEFAULT: None

output_type

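The output format of the generated image. With "pil" the images are returned as a list of PIL.Image objects; with any other value they are returned as a NumPy array of shape (batch_size, height, width, num_channels).

TYPE: `str`, *optional* DEFAULT: 'pil'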

return_dict

Whether or not to return an ImagePipelineOutput instead of a plain tuple.

TYPE: `bool`, *optional* DEFAULT: False
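How `num_inference_steps` and `eta` shape a run can be made concrete with a short sketch. It assumes 1000 training timesteps and evenly spaced inference timesteps starting from 0, which is a common scheduler setting but not something this page confirms; `ddim_timesteps` and `ddim_sigma` are hypothetical helper names, and the sigma formula is the one from the DDIM paper.

```python
import numpy as np


def ddim_timesteps(num_inference_steps, num_train_timesteps=1000):
    """Evenly spaced training timesteps, visited from high to low (assumed spacing)."""
    step_ratio = num_train_timesteps // num_inference_steps
    return (np.arange(num_inference_steps) * step_ratio)[::-1].copy()


def ddim_sigma(alpha_t, alpha_prev, eta):
    """Standard deviation of the noise added at one step, as a function of eta."""
    return eta * np.sqrt((1 - alpha_prev) / (1 - alpha_t)) * np.sqrt(1 - alpha_t / alpha_prev)


timesteps = ddim_timesteps(50)
print(timesteps[:3], timesteps[-1])

# eta = 0 adds no noise (deterministic DDIM); eta = 1 recovers the DDPM posterior std.
print(ddim_sigma(0.5, 0.8, eta=0.0), ddim_sigma(0.5, 0.8, eta=1.0))
```

Fewer inference steps means larger jumps between consecutive timesteps, which is where the speed-up over running all training timesteps comes from.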

Example:

>>> from mindone.diffusers import DDIMPipeline

>>> # load model and scheduler
>>> pipe = DDIMPipeline.from_pretrained("fusing/ddim-lsun-bedroom")

>>> # run pipeline in inference (sample random noise and denoise);
>>> # with the defaults the call returns a tuple whose first element is a list of PIL images
>>> image = pipe(eta=0.0, num_inference_steps=50)[0][0]

>>> # save image
>>> image.save("test.png")
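To make a run like the one above reproducible, a seeded `np.random.Generator` can be passed as `generator`. The sketch below illustrates the idea with NumPy only. `sample_initial_noise` is a hypothetical stand-in for the pipeline's noise sampling, and drawing one sample per generator when a list is given is an assumption about how the list case is handled.

```python
import numpy as np


def sample_initial_noise(shape, generator=None):
    """Draw the starting noise for a batch (hypothetical helper, not the library's)."""
    batch_size = shape[0]
    if isinstance(generator, list):
        # Same check as in the pipeline: one generator per requested image.
        if len(generator) != batch_size:
            raise ValueError(
                f"Got {len(generator)} generators for a batch size of {batch_size}."
            )
        # Assumed behaviour: each generator produces the noise for one image.
        parts = [g.standard_normal((1, *shape[1:])) for g in generator]
        return np.concatenate(parts, axis=0)
    rng = generator if generator is not None else np.random.default_rng()
    return rng.standard_normal(shape)


shape = (2, 3, 4, 4)
first = sample_initial_noise(shape, np.random.default_rng(0))
second = sample_initial_noise(shape, np.random.default_rng(0))
print(np.array_equal(first, second))
```

With `eta=0.0` the denoising loop adds no further randomness, so fixing the initial noise fixes the generated images.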
RETURNS DESCRIPTION
Union[ImagePipelineOutput, Tuple]

If return_dict is True, an ImagePipelineOutput is returned; otherwise a tuple is returned whose first element is a list with the generated images.
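The two return forms differ only in how the images are wrapped. The following sketch shows both, using `ToyImagePipelineOutput` as a hypothetical stand-in for `ImagePipelineOutput` so that it runs without the library. The `to_uint8` helper is likewise an assumption about how a float array in [0, 1] would typically be turned into 8-bit pixels.

```python
from dataclasses import dataclass
from typing import List, Union

import numpy as np


@dataclass
class ToyImagePipelineOutput:
    """Hypothetical stand-in for ImagePipelineOutput (not the library class)."""
    images: Union[List, np.ndarray]


def package_output(images, return_dict=False):
    # Mirrors the end of __call__: a plain tuple by default, a dataclass on request.
    if not return_dict:
        return (images,)
    return ToyImagePipelineOutput(images=images)


def to_uint8(images):
    # Convert floats in [0, 1] with layout (batch, height, width, channels) to 8-bit pixels.
    return (images * 255).round().astype("uint8")


fake_images = np.full((1, 2, 2, 3), 0.5)
as_tuple = package_output(fake_images)
as_object = package_output(fake_images, return_dict=True)

# Both forms carry the same images: index 0 of the tuple, or the .images attribute.
pixels = to_uint8(as_tuple[0])
print(pixels.dtype, as_object.images.shape)
```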

Source code in mindone/diffusers/pipelines/ddim/pipeline_ddim.py
def __call__(
    self,
    batch_size: int = 1,
    generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None,
    eta: float = 0.0,
    num_inference_steps: int = 50,
    use_clipped_model_output: Optional[bool] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = False,
) -> Union[ImagePipelineOutput, Tuple]:
    r"""
    The call function to the pipeline for generation.

    Args:
        batch_size (`int`, *optional*, defaults to 1):
            The number of images to generate.
        generator (`np.random.Generator`, *optional*):
            A [`np.random.Generator`](https://numpy.org/doc/stable/reference/random/generator.html) to make
            generation deterministic.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. A value of `0` corresponds to
            DDIM and `1` corresponds to DDPM.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        use_clipped_model_output (`bool`, *optional*, defaults to `None`):
            If `True` or `False`, see documentation for [`DDIMScheduler.step`]. If `None`, nothing is passed
            downstream to the scheduler (use `None` for schedulers which don't support this argument).
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `False`):
            Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

    Example:

    ```py
    >>> from mindone.diffusers import DDIMPipeline

    >>> # load model and scheduler
    >>> pipe = DDIMPipeline.from_pretrained("fusing/ddim-lsun-bedroom")

    >>> # run pipeline in inference (sample random noise and denoise);
    >>> # with the defaults the call returns a tuple whose first element is a list of PIL images
    >>> image = pipe(eta=0.0, num_inference_steps=50)[0][0]

    >>> # save image
    >>> image.save("test.png")
    ```

    Returns:
        [`~pipelines.ImagePipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
            returned where the first element is a list with the generated images
    """

    # Sample gaussian noise to begin loop
    if isinstance(self.unet.config.sample_size, int):
        image_shape = (
            batch_size,
            self.unet.config.in_channels,
            self.unet.config.sample_size,
            self.unet.config.sample_size,
        )
    else:
        image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    image = randn_tensor(image_shape, generator=generator, dtype=self.unet.dtype)

    # set step values
    self.scheduler.set_timesteps(num_inference_steps)

    for t in self.progress_bar(self.scheduler.timesteps):
        # 1. predict noise model_output
        model_output = self.unet(image, t)[0]

        # 2. predict previous mean of image x_t-1 and add variance depending on eta
        # eta corresponds to η in paper and should be between [0, 1]
        # do x_t -> x_t-1
        image = self.scheduler.step(
            model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator
        )[0]

    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.permute(0, 2, 3, 1).numpy()
    if output_type == "pil":
        image = self.numpy_to_pil(image)

    if not return_dict:
        return (image,)

    return ImagePipelineOutput(images=image)
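The `use_clipped_model_output` flag that the loop above forwards to `scheduler.step` is easiest to understand from a single step. The sketch below is a NumPy illustration for `eta=0`, assuming the scheduler behaves as the DDIM scheduler in Hugging Face Diffusers does: when the predicted clean sample is clipped, the flag causes the noise estimate to be re-derived from the clipped sample. `ddim_step_eta0` is a hypothetical name, and whether the mindone scheduler matches this exactly is an assumption.

```python
import numpy as np


def ddim_step_eta0(sample, model_output, alpha_t, alpha_prev,
                   clip_sample=True, use_clipped_model_output=False):
    """One deterministic DDIM step x_t -> x_{t-1} (hypothetical helper)."""
    # Predicted clean sample from the current sample and the predicted noise.
    pred_x0 = (sample - np.sqrt(1 - alpha_t) * model_output) / np.sqrt(alpha_t)
    if clip_sample:
        pred_x0 = np.clip(pred_x0, -1.0, 1.0)

    eps = model_output
    if use_clipped_model_output:
        # Re-derive the noise so it is consistent with the (possibly clipped) x0.
        eps = (sample - np.sqrt(alpha_t) * pred_x0) / np.sqrt(1 - alpha_t)

    return np.sqrt(alpha_prev) * pred_x0 + np.sqrt(1 - alpha_prev) * eps


x_t = np.array([3.0])   # large value, so the predicted x0 gets clipped
eps = np.array([0.0])
plain = ddim_step_eta0(x_t, eps, alpha_t=0.5, alpha_prev=0.8)
rederived = ddim_step_eta0(x_t, eps, alpha_t=0.5, alpha_prev=0.8,
                           use_clipped_model_output=True)
print(plain, rederived)
```

When no clipping occurs, the re-derived noise equals the model output and the flag has no effect.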

mindone.diffusers.pipelines.pipeline_utils.ImagePipelineOutput dataclass

Bases: BaseOutput

Output class for image pipelines.

Source code in mindone/diffusers/pipelines/pipeline_utils.py
@dataclass
class ImagePipelineOutput(BaseOutput):
    """
    Output class for image pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
            num_channels)`.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]