Skip to content

ConsistencyDecoderScheduler

This scheduler is a part of the [ConsistencyDecoderPipeline] and was introduced in DALL-E 3.

The original codebase can be found at openai/consistency_models.

mindone.diffusers.schedulers.scheduling_consistency_decoder.ConsistencyDecoderScheduler

Bases: SchedulerMixin, ConfigMixin

Source code in mindone/diffusers/schedulers/scheduling_consistency_decoder.py
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
    """
    Fixed two-step scheduler used with the consistency decoder.

    On construction it derives per-timestep lookup tables (`sqrt_alphas_cumprod`,
    `sqrt_one_minus_alphas_cumprod`, `c_skip`, `c_out`, `c_in`) from the beta schedule returned by
    `betas_for_alpha_bar`. The inference schedule is fixed to the two timesteps `[1008, 512]`.

    Args:
        num_train_timesteps (`int`, defaults to 1024):
            Number of training timesteps, passed to `betas_for_alpha_bar`; it determines the length of
            every lookup table.
        sigma_data (`float`, defaults to 0.5):
            Constant used when computing the `c_skip`, `c_out` and `c_in` coefficients.
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1024,
        sigma_data: float = 0.5,
    ):
        betas = betas_for_alpha_bar(num_train_timesteps)

        alphas = 1.0 - betas
        alphas_cumprod = ops.cumprod(alphas, dim=0)

        # Coefficients used by `step` to re-noise a predicted clean sample.
        self.sqrt_alphas_cumprod = ops.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = ops.sqrt(1.0 - alphas_cumprod)

        sigmas = ops.sqrt(1.0 / alphas_cumprod - 1)

        sqrt_recip_alphas_cumprod = ops.sqrt(1.0 / alphas_cumprod)

        # Shared denominator of the three coefficient tables below.
        variance_sum = sigmas**2 + sigma_data**2

        self.c_skip = sqrt_recip_alphas_cumprod * sigma_data**2 / variance_sum
        self.c_out = sigmas * sigma_data / variance_sum**0.5
        self.c_in = sqrt_recip_alphas_cumprod / variance_sum**0.5

        # settable values
        self.timesteps = ms.tensor([1008, 512])

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
    ):
        """
        Reset the inference schedule to the fixed two timesteps `[1008, 512]`.

        Args:
            num_inference_steps (`int`, *optional*):
                Must be exactly 2; any other value (including `None`) is rejected.

        Raises:
            ValueError: If `num_inference_steps` is not 2.
        """
        if num_inference_steps != 2:
            # The check rejects every value other than 2 (fewer steps and `None` as well), so the
            # message states the actual constraint instead of only "more than 2".
            raise ValueError("Currently only 2 inference steps are supported.")

        self.timesteps = ms.tensor([1008, 512])

    @property
    def init_noise_sigma(self):
        """Value of `sqrt_one_minus_alphas_cumprod` at the first scheduled timestep."""
        return self.sqrt_one_minus_alphas_cumprod[self.timesteps[0]]

    def scale_model_input(self, sample: ms.Tensor, timestep: Optional[int] = None) -> ms.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`ms.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain. It is used to index the `c_in` table.

        Returns:
            `ms.Tensor`:
                The sample multiplied by `c_in[timestep]`, cast back to the dtype of `sample`.
        """
        return (sample * self.c_in[timestep]).to(sample.dtype)

    def step(
        self,
        model_output: ms.Tensor,
        timestep: Union[float, ms.Tensor],
        sample: ms.Tensor,
        generator: Optional[np.random.Generator] = None,
        return_dict: bool = False,
    ) -> Union[ConsistencyDecoderSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`ms.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`float`):
                The current timestep in the diffusion chain. It must be one of the values in `self.timesteps`.
            sample (`ms.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`np.random.Generator`, *optional*):
                A random number generator, forwarded to `randn_tensor` when noise is drawn.
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return a
                [`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] or `tuple`:
                If return_dict is `True`,
                [`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] is returned,
                otherwise a tuple is returned where the first element is the sample tensor.

        Raises:
            IndexError: If `timestep` is not one of the values in `self.timesteps`.
        """
        # Locate the current timestep in the schedule; fail with a clear message instead of the
        # opaque error produced by indexing an empty match result.
        matches = (self.timesteps == timestep).nonzero()
        if matches.shape[0] == 0:
            raise IndexError(
                f"timestep {timestep} is not in the scheduler's timesteps; "
                "pass one of the values in `scheduler.timesteps`."
            )
        timestep_idx = matches[0]

        # Combine the model output and the current sample into the clean-sample estimate.
        x_0 = (
            self.c_out[timestep].to(model_output.dtype) * model_output + self.c_skip[timestep].to(sample.dtype) * sample
        )

        if timestep_idx == len(self.timesteps) - 1:
            # Last scheduled step: the estimate is returned without adding noise.
            prev_sample = x_0
        else:
            # Re-noise the estimate to the level of the next scheduled timestep.
            next_timestep = self.timesteps[timestep_idx + 1]
            noise = randn_tensor(x_0.shape, generator=generator, dtype=x_0.dtype)
            prev_sample = (
                self.sqrt_alphas_cumprod[next_timestep].to(x_0.dtype) * x_0
                + self.sqrt_one_minus_alphas_cumprod[next_timestep].to(x_0.dtype) * noise
            )

        if not return_dict:
            return (prev_sample,)

        return ConsistencyDecoderSchedulerOutput(prev_sample=prev_sample)

mindone.diffusers.schedulers.scheduling_consistency_decoder.ConsistencyDecoderScheduler.scale_model_input(sample, timestep=None)

Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep.

PARAMETER DESCRIPTION
sample

The input sample.

TYPE: `ms.Tensor`

timestep

The current timestep in the diffusion chain.

TYPE: `int`, *optional* DEFAULT: None

RETURNS DESCRIPTION
Tensor

ms.Tensor: A scaled input sample.

Source code in mindone/diffusers/schedulers/scheduling_consistency_decoder.py
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
def scale_model_input(self, sample: ms.Tensor, timestep: Optional[int] = None) -> ms.Tensor:
    """
    Scale the denoising model input for the given timestep.

    Keeps this scheduler interchangeable with schedulers whose model input must be scaled
    depending on the current timestep.

    Args:
        sample (`ms.Tensor`):
            The input sample.
        timestep (`int`, *optional*):
            The current timestep in the diffusion chain; used to index the `c_in` table.

    Returns:
        `ms.Tensor`:
            The sample multiplied by `c_in[timestep]`, cast back to the dtype of `sample`.
    """
    input_scale = self.c_in[timestep]
    scaled_sample = sample * input_scale
    # The multiplication may promote the dtype, so restore the caller's dtype.
    return scaled_sample.to(sample.dtype)

mindone.diffusers.schedulers.scheduling_consistency_decoder.ConsistencyDecoderScheduler.step(model_output, timestep, sample, generator=None, return_dict=False)

Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise).

PARAMETER DESCRIPTION
model_output

The direct output from the learned diffusion model.

TYPE: `ms.Tensor`

timestep

The current timestep in the diffusion chain.

TYPE: `float`

sample

A current instance of a sample created by the diffusion process.

TYPE: `ms.Tensor`

generator

A random number generator.

TYPE: `np.random.Generator`, *optional* DEFAULT: None

return_dict

Whether or not to return a [~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput] or tuple.

TYPE: `bool`, *optional* DEFAULT: False

RETURNS DESCRIPTION
Union[ConsistencyDecoderSchedulerOutput, Tuple]

[~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput] or tuple: If return_dict is True, [~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput] is returned, otherwise a tuple is returned where the first element is the sample tensor.

Source code in mindone/diffusers/schedulers/scheduling_consistency_decoder.py
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
def step(
    self,
    model_output: ms.Tensor,
    timestep: Union[float, ms.Tensor],
    sample: ms.Tensor,
    generator: Optional[np.random.Generator] = None,
    return_dict: bool = False,
) -> Union[ConsistencyDecoderSchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        model_output (`ms.Tensor`):
            The direct output from the learned diffusion model.
        timestep (`float`):
            The current timestep in the diffusion chain.
        sample (`ms.Tensor`):
            A current instance of a sample created by the diffusion process.
        generator (`np.random.Generator`, *optional*):
            A random number generator.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a
            [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`.

    Returns:
        [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`:
            If return_dict is `True`,
            [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] is returned, otherwise
            a tuple is returned where the first element is the sample tensor.
    """
    x_0 = (
        self.c_out[timestep].to(model_output.dtype) * model_output + self.c_skip[timestep].to(sample.dtype) * sample
    )

    timestep_idx = (self.timesteps == timestep).nonzero()[0]

    if timestep_idx == len(self.timesteps) - 1:
        prev_sample = x_0
    else:
        noise = randn_tensor(x_0.shape, generator=generator, dtype=x_0.dtype)
        prev_sample = (
            self.sqrt_alphas_cumprod[self.timesteps[timestep_idx + 1]].to(x_0.dtype) * x_0
            + self.sqrt_one_minus_alphas_cumprod[self.timesteps[timestep_idx + 1]].to(x_0.dtype) * noise
        )

    if not return_dict:
        return (prev_sample,)

    return ConsistencyDecoderSchedulerOutput(prev_sample=prev_sample)