
Consistency Models

Consistency models were proposed in the paper Consistency Models by Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever.

The abstract from the paper is:

Diffusion models have significantly advanced the fields of image, audio, and video generation, but they depend on an iterative sampling process that causes slow generation. To overcome this limitation, we propose consistency models, a new family of models that generate high quality samples by directly mapping noise to data. They support fast one-step generation by design, while still allowing multistep sampling to trade compute for sample quality. They also support zero-shot data editing, such as image inpainting, colorization, and super-resolution, without requiring explicit training on these tasks. Consistency models can be trained either by distilling pre-trained diffusion models, or as standalone generative models altogether. Through extensive experiments, we demonstrate that they outperform existing distillation techniques for diffusion models in one- and few-step sampling, achieving the new state-of-the-art FID of 3.55 on CIFAR-10 and 6.20 on ImageNet 64x64 for one-step generation. When trained in isolation, consistency models become a new family of generative models that can outperform existing one-step, non-adversarial generative models on standard benchmarks such as CIFAR-10, ImageNet 64x64 and LSUN 256x256.

The original codebase can be found at openai/consistency_models, and additional checkpoints are available at openai.

The pipeline was contributed by dg845 and ayushtues. ❤️

Tips

For an additional speed-up when generating multiple images, enable MindSpore Graph mode before loading the pipeline:

  import mindspore as ms
  from mindone.diffusers import ConsistencyModelPipeline

  # Enable Graph mode (mode=0 is ms.GRAPH_MODE)
  ms.set_context(mode=ms.GRAPH_MODE)

  # Load the class-conditional cd_imagenet64_l2 checkpoint.
  model_id_or_path = "openai/diffusers-cd_imagenet64_l2"
  pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, mindspore_dtype=ms.float16)

  # One-step sampling (return_dict defaults to False, so index into the returned tuple)
  image = pipe(num_inference_steps=1)[0][0]
  image.save("cd_imagenet64_l2_onestep_sample.png")

  # One-step sampling, class-conditional image generation
  # ImageNet-64 class label 145 corresponds to king penguins
  image = pipe(num_inference_steps=1, class_labels=145)[0][0]
  image.save("cd_imagenet64_l2_onestep_sample_penguin.png")

  # Multistep sampling, class-conditional image generation
  # Timesteps can be explicitly specified; the particular timesteps below are from the original GitHub repo:
  # https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L83
  image = pipe(num_inference_steps=None, timesteps=[22, 0], class_labels=145)[0][0]
  image.save("cd_imagenet64_l2_multistep_sample_penguin.png")
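In the multistep call above, `timesteps=[22, 0]` are positions on a discretized noise schedule rather than noise levels. The sketch below shows how such a position can be turned into a noise level σ, assuming the Karras et al. style schedule used by the original repository with 40 discretization steps, `sigma_min=0.002`, `sigma_max=80.0` and `rho=7.0`, and assuming position 0 is the noisiest end. These defaults and the helper name `karras_sigma` are assumptions for illustration; check the `CMStochasticIterativeScheduler` config of the checkpoint you load for the actual values.

```python
def karras_sigma(index: int, num_train_timesteps: int = 40,
                 sigma_min: float = 0.002, sigma_max: float = 80.0,
                 rho: float = 7.0) -> float:
    """Map a schedule position to a noise level on an assumed Karras-style schedule."""
    ramp = index / (num_train_timesteps - 1)  # 0.0 at the noisiest end, 1.0 at the cleanest
    max_inv_rho = sigma_max ** (1.0 / rho)
    min_inv_rho = sigma_min ** (1.0 / rho)
    # Interpolate linearly in sigma**(1/rho) space, then map back with **rho
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho


for position in (0, 22, 39):
    print(f"position {position:2d} -> sigma {karras_sigma(position):.4f}")
```

Under these assumptions, position 22 lands at a noise level of roughly 1.4, so the second sampling step re-noises the one-step estimate only mildly before denoising it again.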

mindone.diffusers.ConsistencyModelPipeline

Bases: DiffusionPipeline

Pipeline for unconditional or class-conditional image generation.

This pipeline inherits from DiffusionPipeline. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).
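Whether class labels are used depends on the UNet's `num_class_embeds` config value. The plain-Python sketch below mirrors the rules in the pipeline's `prepare_class_labels` method (listed in the source code further down): labels are dropped for unconditional models, a single `int` is only accepted for a batch of one, and missing labels are drawn at random. It uses lists and Python's `random` module instead of MindSpore tensors and `ops.randint`, and the function name `normalize_class_labels` is hypothetical.

```python
import random
from typing import List, Optional, Union


def normalize_class_labels(batch_size: int,
                           num_class_embeds: Optional[int],
                           class_labels: Union[List[int], int, None] = None,
                           rng: Optional[random.Random] = None) -> Optional[List[int]]:
    """Hypothetical plain-Python mirror of the pipeline's class-label rules."""
    if num_class_embeds is None:
        # Unconditional UNet: any labels passed in are ignored.
        return None
    if isinstance(class_labels, list):
        return list(class_labels)
    if isinstance(class_labels, int):
        # A single int is only valid for a batch of one image.
        assert batch_size == 1, "Batch size must be 1 if classes is an int"
        return [class_labels]
    if class_labels is None:
        # No labels given: draw one random class per image in [0, num_class_embeds).
        rng = rng or random.Random()
        return [rng.randrange(num_class_embeds) for _ in range(batch_size)]
    # Anything else (in the real pipeline, an ms.Tensor) is passed through unchanged.
    return class_labels


print(normalize_class_labels(1, 1000, 145))
print(normalize_class_labels(1, None, 145))
```

Note that passing `class_labels` to an unconditional checkpoint is silently ignored rather than rejected.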

PARAMETER DESCRIPTION
unet

A UNet2DModel to denoise the encoded image latents.

TYPE: [`UNet2DModel`]

scheduler

A scheduler to be used in combination with unet to denoise the encoded image latents. Currently only compatible with [CMStochasticIterativeScheduler].

TYPE: [`SchedulerMixin`]
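The scheduler and the UNet are tightly coupled because the paper builds the consistency function as f(x, σ) = c_skip(σ)·x + c_out(σ)·F(c_in(σ)·x, σ), where F is the network and the coefficients force f(x, σ_min) = x. The sketch below computes those coefficients in plain Python with the paper's formulas, assuming `sigma_min=0.002` and `sigma_data=0.5`; the c_in input scaling is taken from the EDM-style preconditioning the paper builds on. Whether `CMStochasticIterativeScheduler` applies exactly these expressions in `scale_model_input` and `step` is an assumption worth checking against its source.

```python
import math
from typing import Tuple


def consistency_coefficients(sigma: float, sigma_min: float = 0.002,
                             sigma_data: float = 0.5) -> Tuple[float, float, float]:
    """Return (c_skip, c_out, c_in) for noise level sigma (assumed defaults)."""
    c_skip = sigma_data ** 2 / ((sigma - sigma_min) ** 2 + sigma_data ** 2)
    c_out = (sigma - sigma_min) * sigma_data / math.sqrt(sigma ** 2 + sigma_data ** 2)
    c_in = 1.0 / math.sqrt(sigma ** 2 + sigma_data ** 2)  # scales the input before the network
    return c_skip, c_out, c_in


def consistency_fn(x: float, sigma: float, network_output: float) -> float:
    """f = c_skip * x + c_out * F(c_in * x, sigma); the network output F(...) is passed in."""
    c_skip, c_out, _ = consistency_coefficients(sigma)
    return c_skip * x + c_out * network_output


for sigma in (0.002, 1.0, 80.0):
    print(sigma, consistency_coefficients(sigma))
```

At σ = σ_min the skip coefficient is 1 and the network term vanishes, which is how the boundary condition holds regardless of what the UNet predicts.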

Source code in mindone/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
class ConsistencyModelPipeline(DiffusionPipeline):
    r"""
    Pipeline for unconditional or class-conditional image generation.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        unet ([`UNet2DModel`]):
            A `UNet2DModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Currently only
            compatible with [`CMStochasticIterativeScheduler`].
    """

    def __init__(self, unet: UNet2DModel, scheduler: CMStochasticIterativeScheduler) -> None:
        super().__init__()

        self.register_modules(
            unet=unet,
            scheduler=scheduler,
        )

        self.safety_checker = None

    def prepare_latents(self, batch_size, num_channels, height, width, dtype, generator, latents=None):
        shape = (batch_size, num_channels, height, width)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, dtype=dtype)
        else:
            latents = latents.to(dtype=dtype)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = (latents * self.scheduler.init_noise_sigma).to(dtype)
        return latents

    # Follows diffusers.VaeImageProcessor.postprocess
    def postprocess_image(self, sample: ms.Tensor, output_type: str = "pil"):
        if output_type not in ["ms", "np", "pil"]:
            raise ValueError(
                f"output_type={output_type} is not supported. Make sure to choose one of ['ms', 'np', or 'pil']"
            )

        # Equivalent to diffusers.VaeImageProcessor.denormalize
        sample = (sample / 2 + 0.5).clamp(0, 1)
        if output_type == "ms":
            return sample

        # Equivalent to diffusers.VaeImageProcessor.ms_to_numpy
        sample = sample.permute((0, 2, 3, 1)).numpy()
        if output_type == "np":
            return sample

        # Output_type must be 'pil'
        sample = self.numpy_to_pil(sample)
        return sample

    def prepare_class_labels(self, batch_size, class_labels=None):
        if self.unet.config.num_class_embeds is not None:
            if isinstance(class_labels, list):
                class_labels = ms.tensor(class_labels, dtype=ms.int64)
            elif isinstance(class_labels, int):
                assert batch_size == 1, "Batch size must be 1 if classes is an int"
                class_labels = ms.tensor([class_labels], dtype=ms.int64)
            elif class_labels is None:
                # Randomly generate batch_size class labels
                # TODO: should use generator here? int analogue of randn_tensor is not exposed in ...utils
                class_labels = ops.randint(0, self.unet.config.num_class_embeds, size=(batch_size,))
        else:
            class_labels = None
        return class_labels

    def check_inputs(self, num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps):
        if num_inference_steps is None and timesteps is None:
            raise ValueError("At least one of `num_inference_steps` or `timesteps` must be supplied.")

        if num_inference_steps is not None and timesteps is not None:
            logger.warning(
                f"Both `num_inference_steps`: {num_inference_steps} and `timesteps`: {timesteps} are supplied;"
                " `timesteps` will be used over `num_inference_steps`."
            )

        if latents is not None:
            expected_shape = (batch_size, 3, img_size, img_size)
            if latents.shape != expected_shape:
                raise ValueError(f"The shape of latents is {latents.shape} but is expected to be {expected_shape}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    def __call__(
        self,
        batch_size: int = 1,
        class_labels: Optional[Union[ms.Tensor, List[int], int]] = None,
        num_inference_steps: int = 1,
        timesteps: List[int] = None,
        generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None,
        latents: Optional[ms.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = False,
        callback: Optional[Callable[[int, int, ms.Tensor], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            class_labels (`ms.Tensor` or `List[int]` or `int`, *optional*):
                Optional class labels for conditioning class-conditional consistency models. Not used if the model is
                not class-conditional.
            num_inference_steps (`int`, *optional*, defaults to 1):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process. If not defined, equally spaced `num_inference_steps`
                timesteps are used. Must be in descending order.
            generator (`np.random.Generator`, *optional*):
                A [`np.random.Generator`](https://numpy.org/doc/stable/reference/random/generator.html) to make
                generation deterministic.
            latents (`ms.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to rerun the same generation with different class labels. If not provided, a
                latents tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `"pil"` (`PIL.Image`), `"np"` (`np.ndarray`),
                or `"ms"` (`ms.Tensor`).
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that is called every `callback_steps` steps during inference. The function is called with
                the following arguments: `callback(step: int, timestep: int, latents: ms.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images.
        """
        # 0. Prepare call parameters
        img_size = self.unet.config.sample_size

        # 1. Check inputs
        self.check_inputs(num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps)

        # 2. Prepare image latents
        # Sample image latents x_0 ~ N(0, sigma_0^2 * I)
        sample = self.prepare_latents(
            batch_size=batch_size,
            num_channels=self.unet.config.in_channels,
            height=img_size,
            width=img_size,
            dtype=self.unet.dtype,
            generator=generator,
            latents=latents,
        )

        # 3. Handle class_labels for class-conditional models
        class_labels = self.prepare_class_labels(batch_size, class_labels=class_labels)

        # 4. Prepare timesteps
        if timesteps is not None:
            self.scheduler.set_timesteps(timesteps=timesteps)
            timesteps = self.scheduler.timesteps
            num_inference_steps = len(timesteps)
        else:
            self.scheduler.set_timesteps(num_inference_steps)
            timesteps = self.scheduler.timesteps

        # 5. Denoising loop
        # Multistep sampling: implements Algorithm 1 in the paper
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                scaled_sample = self.scheduler.scale_model_input(sample, t)
                model_output = self.unet(scaled_sample, t, class_labels=class_labels, return_dict=False)[0]

                sample = self.scheduler.step(model_output, t, sample, generator=generator)[0]

                # call the callback, if provided
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, sample)

        # 6. Post-process image sample
        image = self.postprocess_image(sample, output_type=output_type)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
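The denoising loop in step 5 follows the paper's multistep sampler (Algorithm 1): denoise once from the maximum noise level, then, for each remaining noise level σ, re-inject fresh noise with standard deviation sqrt(σ² − σ_min²) and denoise again. The scalar sketch below illustrates that loop in plain Python. It replaces the UNet with the closed-form consistency function of a toy Gaussian data distribution N(0, s_data²), for which probability-flow ODE trajectories satisfy x(σ) ∝ sqrt(s_data² + σ²); the toy distribution, the noise levels `[80.0]` and `[80.0, 1.38]`, and the helper names are assumptions for illustration, not part of the pipeline.

```python
import math
import random
from typing import List


def gaussian_consistency_fn(x: float, sigma: float, s_data: float = 0.5,
                            sigma_min: float = 0.002) -> float:
    """Exact consistency function for toy data ~ N(0, s_data^2) under the VE/EDM ODE.

    Trajectories satisfy x(sigma) proportional to sqrt(s_data^2 + sigma^2), so mapping
    a point at noise level sigma back to sigma_min is just a rescaling.
    """
    return x * math.sqrt(s_data ** 2 + sigma_min ** 2) / math.sqrt(s_data ** 2 + sigma ** 2)


def multistep_sample(sigmas: List[float], rng: random.Random,
                     sigma_min: float = 0.002) -> float:
    """Algorithm 1 style sampling; sigmas are descending and sigmas[0] is sigma_max."""
    x_hat = sigmas[0] * rng.gauss(0.0, 1.0)        # initial noise x_T ~ N(0, sigma_max^2)
    x = gaussian_consistency_fn(x_hat, sigmas[0])  # one-step estimate
    for sigma in sigmas[1:]:
        z = rng.gauss(0.0, 1.0)
        x_hat = x + math.sqrt(sigma ** 2 - sigma_min ** 2) * z  # re-noise to level sigma
        x = gaussian_consistency_fn(x_hat, sigma)              # denoise again
    return x


rng = random.Random(0)
one_step = [multistep_sample([80.0], rng) for _ in range(20000)]
two_step = [multistep_sample([80.0, 1.38], rng) for _ in range(20000)]
for name, xs in (("one-step", one_step), ("two-step", two_step)):
    std = math.sqrt(sum(v * v for v in xs) / len(xs))
    print(f"{name}: empirical std {std:.3f} (toy data std is 0.5)")
```

With an exact consistency function both variants recover the data distribution; with a learned UNet the extra steps are what let multistep sampling trade compute for sample quality.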

mindone.diffusers.ConsistencyModelPipeline.__call__(batch_size=1, class_labels=None, num_inference_steps=1, timesteps=None, generator=None, latents=None, output_type='pil', return_dict=False, callback=None, callback_steps=1)

PARAMETER DESCRIPTION
batch_size

The number of images to generate.

TYPE: `int`, *optional*, defaults to 1 DEFAULT: 1

class_labels

Optional class labels for conditioning class-conditional consistency models. Not used if the model is not class-conditional.

TYPE: `ms.Tensor` or `List[int]` or `int`, *optional* DEFAULT: None

num_inference_steps

The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.

TYPE: `int`, *optional*, defaults to 1 DEFAULT: 1

timesteps

Custom timesteps to use for the denoising process. If not defined, equally spaced num_inference_steps timesteps are used. Must be in descending order.

TYPE: `List[int]`, *optional* DEFAULT: None

generator

A np.random.Generator to make generation deterministic.

TYPE: `np.random.Generator`, *optional* DEFAULT: None

latents

Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to rerun the same generation with different class labels. If not provided, a latents tensor is generated by sampling using the supplied random generator.

TYPE: `ms.Tensor`, *optional* DEFAULT: None

output_type

The output format of the generated image. Choose between "pil" (PIL.Image), "np" (np.ndarray), or "ms" (ms.Tensor).

TYPE: `str`, *optional*, defaults to `"pil"` DEFAULT: 'pil'

return_dict

Whether or not to return a [~pipelines.ImagePipelineOutput] instead of a plain tuple.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

callback

A function that is called every callback_steps steps during inference. The function is called with the following arguments: callback(step: int, timestep: int, latents: ms.Tensor).

TYPE: `Callable`, *optional* DEFAULT: None

callback_steps

The frequency at which the callback function is called. If not specified, the callback is called at every step.

TYPE: `int`, *optional*, defaults to 1 DEFAULT: 1

RETURNS DESCRIPTION

[~pipelines.ImagePipelineOutput] or tuple: If return_dict is True, [~pipelines.ImagePipelineOutput] is returned, otherwise a tuple is returned where the first element is a list with the generated images.
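The `callback` and `callback_steps` arguments interact through the check `i % callback_steps == 0` in the denoising loop, so the callback fires on step 0 and then every `callback_steps` steps, with a zero-based step index. The plain-Python sketch below reproduces only that firing rule and the positivity check from `check_inputs`; the `record` callback, the `simulate_callbacks` helper, and the timestep values are hypothetical stand-ins for the real pipeline.

```python
from typing import Callable, List, Tuple


def simulate_callbacks(timesteps: List[int], callback_steps: int,
                       callback: Callable[[int, int, object], None]) -> None:
    """Mimic the pipeline's rule: call back when i % callback_steps == 0 (i is zero-based)."""
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(f"`callback_steps` has to be a positive integer but is {callback_steps}")
    for i, t in enumerate(timesteps):
        sample = None  # placeholder for the latents tensor the real pipeline passes
        if i % callback_steps == 0:
            callback(i, t, sample)


calls: List[Tuple[int, int]] = []


def record(step: int, timestep: int, latents: object) -> None:
    calls.append((step, timestep))


simulate_callbacks([39, 30, 22, 10, 0], callback_steps=2, callback=record)
print(calls)  # [(0, 39), (2, 22), (4, 0)]
```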


mindone.diffusers.pipelines.ImagePipelineOutput dataclass

Bases: BaseOutput

Output class for image pipelines.

Source code in mindone/diffusers/pipelines/pipeline_utils.py
@dataclass
class ImagePipelineOutput(BaseOutput):
    """
    Output class for image pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`):
            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
            num_channels)`.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
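For `output_type="np"` and `"pil"`, the pipeline's `postprocess_image` maps the UNet output from [-1, 1] to [0, 1] with `(sample / 2 + 0.5).clamp(0, 1)` and moves channels last, which is why `images` has shape `(batch_size, height, width, num_channels)`. The plain-Python sketch below performs the same two steps on nested lists, followed by an 8-bit conversion; that last `round(v * 255)` step is an assumption about what `numpy_to_pil` does before building `PIL.Image` objects, and the helper names are hypothetical.

```python
from typing import List

Image4D = List[List[List[List[float]]]]  # [batch][dim1][dim2][dim3], layout depends on the step


def postprocess(sample_nchw: Image4D) -> Image4D:
    """Denormalize from [-1, 1] to [0, 1] and convert NCHW -> NHWC (lists stand in for tensors)."""
    out = []
    for image in sample_nchw:  # image[c][h][w]
        channels, height, width = len(image), len(image[0]), len(image[0][0])
        out.append([[[min(max(image[c][h][w] / 2 + 0.5, 0.0), 1.0)  # clamp(0, 1)
                      for c in range(channels)]
                     for w in range(width)]
                    for h in range(height)])
    return out


def to_uint8(image_hwc: List[List[List[float]]]) -> List[List[List[int]]]:
    """Scale [0, 1] floats to 0..255 integers, as assumed for the PIL conversion."""
    return [[[int(round(v * 255)) for v in pixel] for pixel in row] for row in image_hwc]


# One 3-channel, 1x2 image in NCHW layout with one out-of-range value (1.5).
batch = [[[[-1.0, 1.5]], [[0.0, 0.5]], [[1.0, -0.5]]]]
nhwc = postprocess(batch)
print(nhwc)  # [[[[0.0, 0.5, 1.0], [1.0, 0.75, 0.25]]]]
print(to_uint8(nhwc[0]))
```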