KDPM2DiscreteScheduler

The KDPM2DiscreteScheduler is inspired by the Elucidating the Design Space of Diffusion-Based Generative Models paper, and the scheduler is ported from and created by Katherine Crowson.

The original codebase can be found at crowsonkb/k-diffusion.

mindone.diffusers.KDPM2DiscreteScheduler

Bases: SchedulerMixin, ConfigMixin

KDPM2DiscreteScheduler is inspired by the DPMSolver2 and Algorithm 2 from the Elucidating the Design Space of Diffusion-Based Generative Models paper.

This model inherits from [SchedulerMixin] and [ConfigMixin]. Check the superclass documentation for the generic methods the library implements for all schedulers such as loading and saving.

PARAMETER DESCRIPTION
num_train_timesteps

The number of diffusion steps to train the model.

TYPE: `int`, defaults to 1000 DEFAULT: 1000

beta_start

The starting beta value of inference.

TYPE: `float`, defaults to 0.00085 DEFAULT: 0.00085

beta_end

The final beta value.

TYPE: `float`, defaults to 0.012 DEFAULT: 0.012

beta_schedule

The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from linear or scaled_linear; the constructor source also accepts squaredcos_cap_v2.

TYPE: `str`, defaults to `"linear"` DEFAULT: 'linear'

trained_betas

Pass an array of betas directly to the constructor to bypass beta_start and beta_end.

TYPE: `np.ndarray`, *optional* DEFAULT: None

use_karras_sigmas

Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If True, the sigmas are determined according to a sequence of noise levels {σi}.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

prediction_type

Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of the Imagen Video paper).

TYPE: `str`, defaults to `epsilon`, *optional* DEFAULT: 'epsilon'

timestep_spacing

The way the timesteps should be scaled. Choose from linspace, leading, or trailing. Refer to Table 2 of the Common Diffusion Noise Schedules and Sample Steps are Flawed paper for more information.

TYPE: `str`, defaults to `"linspace"` DEFAULT: 'linspace'

steps_offset

An offset added to the inference steps, as required by some model families.

TYPE: `int`, defaults to 0 DEFAULT: 0
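
The beta parameters above determine the noise levels (sigmas) the scheduler works with. The following is a minimal NumPy sketch of that relationship, assuming the same formulas as the constructor source on this page; `make_betas` and `betas_to_sigmas` are hypothetical helper names, not part of the library.

```python
import numpy as np

def make_betas(beta_start=0.00085, beta_end=0.012, num_train_timesteps=1000, beta_schedule="linear"):
    """Return the training beta sequence for the two documented schedules."""
    if beta_schedule == "linear":
        return np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float64)
    if beta_schedule == "scaled_linear":
        # Linear in sqrt(beta), then squared.
        return np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float64) ** 2
    raise NotImplementedError(f"{beta_schedule} is not covered by this sketch")

def betas_to_sigmas(betas):
    """sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), alpha_bar_t = cumulative product of (1 - beta)."""
    alphas_cumprod = np.cumprod(1.0 - betas)
    return np.sqrt((1.0 - alphas_cumprod) / alphas_cumprod)

sigmas_linear = betas_to_sigmas(make_betas(beta_schedule="linear"))
sigmas_scaled = betas_to_sigmas(make_betas(beta_schedule="scaled_linear"))
print(sigmas_linear[0], sigmas_linear[-1])
print(sigmas_scaled[0], sigmas_scaled[-1])
```

With identical `beta_start` and `beta_end`, the `scaled_linear` betas never exceed the `linear` ones, so the largest sigma is smaller under `scaled_linear`.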

Source code in mindone/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    KDPM2DiscreteScheduler is inspired by the DPMSolver2 and Algorithm 2 from the [Elucidating the Design Space of
    Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.00085):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.012):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,  # sensible defaults
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        use_karras_sigmas: Optional[bool] = False,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        if trained_betas is not None:
            self.betas = ms.tensor(trained_betas, dtype=ms.float32)
        elif beta_schedule == "linear":
            self.betas = ms.tensor(np.linspace(beta_start, beta_end, num_train_timesteps), dtype=ms.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                ms.tensor(np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps), dtype=ms.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = ops.cumprod(self.alphas, dim=0)

        #  set all values
        self.set_timesteps(num_train_timesteps, num_train_timesteps)

        self._step_index = None
        self._begin_index = None

    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return self.sigmas.max()

        return (self.sigmas.max() ** 2 + 1) ** 0.5

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It increases by 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_model_input(
        self,
        sample: ms.Tensor,
        timestep: Union[float, ms.Tensor],
    ) -> ms.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`ms.Tensor`):
                The input sample.
            timestep (`float` or `ms.Tensor`):
                The current timestep in the diffusion chain.

        Returns:
            `ms.Tensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        if self.state_in_first_order:
            sigma = self.sigmas[self.step_index]
        else:
            sigma = self.sigmas_interpol[self.step_index]

        sample = (sample / ((sigma**2 + 1) ** 0.5)).to(sample.dtype)
        return sample

    def set_timesteps(
        self,
        num_inference_steps: int,
        num_train_timesteps: Optional[int] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
        """
        self.num_inference_steps = num_inference_steps

        num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).asnumpy()
        log_sigmas = np.log(sigmas)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        if self.config.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round().astype(np.float32)

        self.log_sigmas = ms.Tensor(log_sigmas)
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        sigmas = ms.Tensor(sigmas)

        # interpolate sigmas
        sigmas_interpol = sigmas.log().lerp(ms.Tensor(np.roll(sigmas.asnumpy(), 1)).log(), 0.5).exp()

        self.sigmas = ops.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]])
        self.sigmas_interpol = ops.cat(
            [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]]
        )

        timesteps = ms.Tensor(timesteps)

        # interpolate timesteps
        log_sigmas = self.log_sigmas
        timesteps_interpol = np.array(
            [self._sigma_to_t(sigma_interpol, log_sigmas.asnumpy()) for sigma_interpol in sigmas_interpol.asnumpy()]
        )
        timesteps_interpol = ms.tensor(timesteps_interpol, dtype=timesteps.dtype)
        interleaved_timesteps = ops.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), axis=-1).flatten()

        self.timesteps = ops.cat([timesteps[:1], interleaved_timesteps])

        self.sample = None

        self._step_index = None
        self._begin_index = None

    @property
    def state_in_first_order(self):
        return self.sample is None

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        if (schedule_timesteps == timestep).sum() > 1:
            pos = 1
        else:
            pos = 0

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        indices = (schedule_timesteps == timestep).nonzero()

        return int(indices[pos])

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        if self.begin_index is None:
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: ms.Tensor, num_inference_steps) -> ms.Tensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    def step(
        self,
        model_output: Union[ms.Tensor, np.ndarray],
        timestep: Union[float, ms.Tensor],
        sample: Union[ms.Tensor, np.ndarray],
        return_dict: bool = False,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`ms.Tensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`ms.Tensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        if self.state_in_first_order:
            sigma = self.sigmas[self.step_index]
            sigma_interpol = self.sigmas_interpol[self.step_index + 1]
            sigma_next = self.sigmas[self.step_index + 1]
        else:
            # 2nd order / KDPM2's method
            sigma = self.sigmas[self.step_index - 1]
            sigma_interpol = self.sigmas_interpol[self.step_index]
            sigma_next = self.sigmas[self.step_index]

        # currently only gamma=0 is supported. This usually works best anyways.
        # We can support gamma in the future but then need to scale the timestep before
        # passing it to the model which requires a change in API
        gamma = 0
        sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
            pred_original_sample = sample - sigma_input.to(model_output.dtype) * model_output
        elif self.config.prediction_type == "v_prediction":
            sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
            pred_original_sample = (model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5)).to(
                model_output.dtype
            ) + (sample / (sigma_input**2 + 1)).to(sample.dtype)
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        if self.state_in_first_order:
            # 2. Convert to an ODE derivative for 1st order
            derivative = ((sample - pred_original_sample) / sigma_hat).to(sample.dtype)
            # 3. delta timestep
            dt = sigma_interpol - sigma_hat

            # store for 2nd order step
            self.sample = sample
        else:
            # DPM-Solver-2
            # 2. Convert to an ODE derivative for 2nd order
            derivative = ((sample - pred_original_sample) / sigma_interpol).to(sample.dtype)

            # 3. delta timestep
            dt = sigma_next - sigma_hat

            sample = self.sample
            self.sample = None

        # upon completion increase step index by one
        self._step_index += 1

        prev_sample = sample + (derivative * dt).to(derivative.dtype)

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: ms.Tensor,
        noise: ms.Tensor,
        timesteps: ms.Tensor,
    ) -> ms.Tensor:
        broadcast_shape = original_samples.shape
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(dtype=original_samples.dtype)
        schedule_timesteps = self.timesteps

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        # while len(sigma.shape) < len(original_samples.shape):
        #     sigma = sigma.unsqueeze(-1)
        sigma = ops.reshape(sigma, (timesteps.shape[0],) + (1,) * (len(broadcast_shape) - 1))

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps

mindone.diffusers.KDPM2DiscreteScheduler.begin_index property

The index for the first timestep. It should be set from the pipeline with the set_begin_index method.

mindone.diffusers.KDPM2DiscreteScheduler.step_index property

The index counter for the current timestep. It increases by 1 after each scheduler step.
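
Because the scheduler is second order, the step index walks over an interleaved schedule rather than one entry per inference step. The sketch below reproduces, in plain NumPy, the layout that the `set_timesteps` source builds for `self.sigmas`; `interleave_sigmas` and the three noise levels are hypothetical and only serve the illustration.

```python
import numpy as np

def interleave_sigmas(sigmas):
    """First value once, every later value twice, then the terminal value once more."""
    sigmas = np.append(np.asarray(sigmas, dtype=np.float64), 0.0)  # append terminal sigma = 0
    return np.concatenate([sigmas[:1], np.repeat(sigmas[1:], 2), sigmas[-1:]])

base = [10.0, 4.0, 1.0]               # three hypothetical noise levels
layout = interleave_sigmas(base)
num_model_calls = 2 * len(base) - 1   # length of the interleaved timestep list
print(layout)
print(num_model_calls)
```

For `n` inference steps the interleaved timestep list has `2n - 1` entries, so in a run that starts at index 0 the step index ends at `2n - 1`, with even indices handled in the first-order state and odd indices in the second-order state.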

mindone.diffusers.KDPM2DiscreteScheduler.scale_model_input(sample, timestep)

Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep.

PARAMETER DESCRIPTION
sample

The input sample.

TYPE: `ms.Tensor`

timestep

The current timestep in the diffusion chain.

TYPE: `float` or `ms.Tensor`

RETURNS DESCRIPTION
Tensor

ms.Tensor: A scaled input sample.
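
The scaling itself is a division by `sqrt(sigma**2 + 1)`. A small NumPy sketch, assuming clean data with roughly unit variance, shows why this is useful: a sample noised as `clean + sigma * noise` has variance close to `sigma**2 + 1`, and the division brings it back to about 1. `scale_input` is a hypothetical stand-in for the method, and the choice `sigma = 5.0` is arbitrary.

```python
import numpy as np

def scale_input(sample, sigma):
    """Divide by sqrt(sigma^2 + 1), the same factor used in scale_model_input."""
    return sample / np.sqrt(sigma**2 + 1.0)

rng = np.random.default_rng(0)
sigma = 5.0
clean = rng.standard_normal(100_000)                  # unit-variance "data"
noisy = clean + sigma * rng.standard_normal(100_000)  # sample at noise level sigma
scaled = scale_input(noisy, sigma)
print(noisy.std(), scaled.std())
```

In the scheduler, the sigma used for the scaling depends on the current state: `self.sigmas` in the first-order state and `self.sigmas_interpol` in the second-order state.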

Source code in mindone/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
def scale_model_input(
    self,
    sample: ms.Tensor,
    timestep: Union[float, ms.Tensor],
) -> ms.Tensor:
    """
    Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
    current timestep.

    Args:
        sample (`ms.Tensor`):
            The input sample.
        timestep (`float` or `ms.Tensor`):
            The current timestep in the diffusion chain.

    Returns:
        `ms.Tensor`:
            A scaled input sample.
    """
    if self.step_index is None:
        self._init_step_index(timestep)

    if self.state_in_first_order:
        sigma = self.sigmas[self.step_index]
    else:
        sigma = self.sigmas_interpol[self.step_index]

    sample = (sample / ((sigma**2 + 1) ** 0.5)).to(sample.dtype)
    return sample

mindone.diffusers.KDPM2DiscreteScheduler.set_begin_index(begin_index=0)

Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

PARAMETER DESCRIPTION
begin_index

The begin index for the scheduler.

TYPE: `int` DEFAULT: 0

Source code in mindone/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
def set_begin_index(self, begin_index: int = 0):
    """
    Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

    Args:
        begin_index (`int`):
            The begin index for the scheduler.
    """
    self._begin_index = begin_index

mindone.diffusers.KDPM2DiscreteScheduler.set_timesteps(num_inference_steps, num_train_timesteps=None)

Sets the discrete timesteps used for the diffusion chain (to be run before inference).
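
The first thing this method does is pick the base timesteps according to `timestep_spacing`. The following NumPy sketch mirrors the three branches of the source; `spaced_timesteps` is a hypothetical free function, whereas the real method reads the spacing and offset from the scheduler config.

```python
import numpy as np

def spaced_timesteps(num_inference_steps, num_train_timesteps=1000, spacing="linspace", steps_offset=0):
    """Descending float timesteps for the three spacing modes."""
    if spacing == "linspace":
        return np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
    if spacing == "leading":
        step_ratio = num_train_timesteps // num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.float32)
        return timesteps + steps_offset
    if spacing == "trailing":
        step_ratio = num_train_timesteps / num_inference_steps
        return np.arange(num_train_timesteps, 0, -step_ratio).round().astype(np.float32) - 1
    raise ValueError(f"unsupported spacing: {spacing}")

for mode in ("linspace", "leading", "trailing"):
    print(mode, spaced_timesteps(4, spacing=mode))
```

Note that `linspace` and `trailing` both start at the last training timestep, while `leading` starts lower and leaves room for `steps_offset`.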

PARAMETER DESCRIPTION
num_inference_steps

The number of diffusion steps used when generating samples with a pre-trained model.

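TYPE: `int`

num_train_timesteps

The number of training timesteps used to space the schedule. If omitted, the value from the scheduler config is used.

TYPE: `Optional[int]` DEFAULT: None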
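
When `use_karras_sigmas` is enabled, the method replaces the interpolated sigmas with the Karras et al. (2022) schedule. The sketch below restates the formula from the `_convert_to_karras` source as a standalone function; `karras_sigmas` is a hypothetical name, and the `sigma_min` and `sigma_max` values are illustrative assumptions (the scheduler takes them from the first and last interpolated sigma unless the config provides them).

```python
import numpy as np

def karras_sigmas(sigma_min, sigma_max, num_inference_steps, rho=7.0):
    """Interpolate linearly in sigma**(1/rho), from sigma_max down to sigma_min."""
    ramp = np.linspace(0, 1, num_inference_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

# Illustrative endpoints only; real values depend on the beta schedule.
sigmas = karras_sigmas(sigma_min=0.03, sigma_max=14.6, num_inference_steps=10)
print(sigmas)
```

With `rho = 7.0` the steps are large at high noise and progressively finer near `sigma_min`.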

Source code in mindone/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
def set_timesteps(
    self,
    num_inference_steps: int,
    num_train_timesteps: Optional[int] = None,
):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
    """
    self.num_inference_steps = num_inference_steps

    num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

    # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
    if self.config.timestep_spacing == "linspace":
        timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
    elif self.config.timestep_spacing == "leading":
        step_ratio = num_train_timesteps // self.num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
        timesteps += self.config.steps_offset
    elif self.config.timestep_spacing == "trailing":
        step_ratio = num_train_timesteps / self.num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
        timesteps -= 1
    else:
        raise ValueError(
            f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
        )

    sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).asnumpy()
    log_sigmas = np.log(sigmas)
    sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

    if self.config.use_karras_sigmas:
        sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round().astype(np.float32)

    self.log_sigmas = ms.Tensor(log_sigmas)
    sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
    sigmas = ms.Tensor(sigmas)

    # interpolate sigmas
    sigmas_interpol = sigmas.log().lerp(ms.Tensor(np.roll(sigmas.asnumpy(), 1)).log(), 0.5).exp()

    self.sigmas = ops.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]])
    self.sigmas_interpol = ops.cat(
        [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]]
    )

    timesteps = ms.Tensor(timesteps)

    # interpolate timesteps
    log_sigmas = self.log_sigmas
    timesteps_interpol = np.array(
        [self._sigma_to_t(sigma_interpol, log_sigmas.asnumpy()) for sigma_interpol in sigmas_interpol.asnumpy()]
    )
    timesteps_interpol = ms.tensor(timesteps_interpol, dtype=timesteps.dtype)
    interleaved_timesteps = ops.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), axis=-1).flatten()

    self.timesteps = ops.cat([timesteps[:1], interleaved_timesteps])

    self.sample = None

    self._step_index = None
    self._begin_index = None

mindone.diffusers.KDPM2DiscreteScheduler.step(model_output, timestep, sample, return_dict=False)

Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise).

PARAMETER DESCRIPTION
model_output

The direct output from learned diffusion model.

TYPE: `ms.Tensor`

timestep

The current discrete timestep in the diffusion chain.

TYPE: `float`

sample

A current instance of a sample created by the diffusion process.

TYPE: `ms.Tensor`

return_dict

Whether or not to return a `SchedulerOutput` or tuple.

TYPE: `bool` DEFAULT: False

RETURNS DESCRIPTION
Union[SchedulerOutput, Tuple]

`SchedulerOutput` or tuple: If return_dict is True, a `SchedulerOutput` is returned, otherwise a tuple is returned where the first element is the sample tensor.
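
The scheduler spreads one second-order update over two consecutive `step` calls, keeping the starting sample in `self.sample` between them. The following NumPy sketch folds both stages into one function to make the arithmetic easier to follow. It assumes `prediction_type="epsilon"`, `gamma = 0`, and a strictly positive `sigma_next`; `kdpm2_step_pair` and `eps_model` are hypothetical names, and the "perfect" noise predictor in the usage part is a toy stand-in for a trained model.

```python
import numpy as np

def kdpm2_step_pair(sample, sigma, sigma_next, eps_model):
    """One full update from sigma to sigma_next (> 0), epsilon prediction, gamma = 0."""
    # Midpoint noise level: halfway between sigma and sigma_next in log space.
    sigma_mid = np.exp(0.5 * (np.log(sigma) + np.log(sigma_next)))

    # Stage 1 (first-order state): Euler step from sigma to sigma_mid.
    pred_x0 = sample - sigma * eps_model(sample, sigma)
    derivative = (sample - pred_x0) / sigma
    sample_mid = sample + derivative * (sigma_mid - sigma)

    # Stage 2 (second-order state): derivative at the midpoint,
    # applied to the ORIGINAL sample over the full interval.
    pred_x0_mid = sample_mid - sigma_mid * eps_model(sample_mid, sigma_mid)
    derivative_mid = (sample_mid - pred_x0_mid) / sigma_mid
    return sample + derivative_mid * (sigma_next - sigma)

# Toy check: a predictor that knows the clean sample x0 exactly.
x0 = np.array([0.5, -1.0, 2.0])
direction = np.array([1.0, 0.0, -1.0])
perfect_eps = lambda x, s: (x - x0) / s

start = x0 + 10.0 * direction
out = kdpm2_step_pair(start, 10.0, 4.0, perfect_eps)
print(out)
```

With a perfect predictor the trajectory is a straight line in sigma, so the result lands exactly on `x0 + sigma_next * direction`. With a real model the midpoint evaluation is what gives the method its second-order accuracy, at the cost of two model calls per interval.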

Source code in mindone/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
def step(
    self,
    model_output: Union[ms.Tensor, np.ndarray],
    timestep: Union[float, ms.Tensor],
    sample: Union[ms.Tensor, np.ndarray],
    return_dict: bool = False,
) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        model_output (`ms.Tensor`):
            The direct output from learned diffusion model.
        timestep (`float`):
            The current discrete timestep in the diffusion chain.
        sample (`ms.Tensor`):
            A current instance of a sample created by the diffusion process.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.
    """
    if self.step_index is None:
        self._init_step_index(timestep)

    if self.state_in_first_order:
        sigma = self.sigmas[self.step_index]
        sigma_interpol = self.sigmas_interpol[self.step_index + 1]
        sigma_next = self.sigmas[self.step_index + 1]
    else:
        # 2nd order / KDPM2's method
        sigma = self.sigmas[self.step_index - 1]
        sigma_interpol = self.sigmas_interpol[self.step_index]
        sigma_next = self.sigmas[self.step_index]

    # currently only gamma=0 is supported. This usually works best anyways.
    # We can support gamma in the future but then need to scale the timestep before
    # passing it to the model which requires a change in API
    gamma = 0
    sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now

    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
    if self.config.prediction_type == "epsilon":
        sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
        pred_original_sample = sample - sigma_input.to(model_output.dtype) * model_output
    elif self.config.prediction_type == "v_prediction":
        sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
        pred_original_sample = (model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5)).to(
            model_output.dtype
        ) + (sample / (sigma_input**2 + 1)).to(sample.dtype)
    elif self.config.prediction_type == "sample":
        raise NotImplementedError("prediction_type not implemented yet: sample")
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    if self.state_in_first_order:
        # 2. Convert to an ODE derivative for 1st order
        derivative = ((sample - pred_original_sample) / sigma_hat).to(sample.dtype)
        # 3. delta timestep
        dt = sigma_interpol - sigma_hat

        # store for 2nd order step
        self.sample = sample
    else:
        # DPM-Solver-2
        # 2. Convert to an ODE derivative for 2nd order
        derivative = ((sample - pred_original_sample) / sigma_interpol).to(sample.dtype)

        # 3. delta timestep
        dt = sigma_next - sigma_hat

        sample = self.sample
        self.sample = None

    # upon completion increase step index by one
    self._step_index += 1

    prev_sample = sample + (derivative * dt).to(derivative.dtype)

    if not return_dict:
        return (prev_sample,)

    return SchedulerOutput(prev_sample=prev_sample)