
DPMSolverMultistepScheduler

DPMSolverMultistepScheduler is a multistep scheduler from DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps and DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models by Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu.

DPMSolver (and the improved version DPMSolver++) is a fast, dedicated high-order solver for diffusion ODEs with a convergence order guarantee. Empirically, DPMSolver sampling with only 20 steps can generate high-quality samples, and it can generate quite good samples even in 10 steps.

Tips

It is recommended to set solver_order=2 for guided sampling and solver_order=3 for unconditional sampling.

Dynamic thresholding from Imagen is supported, and for pixel-space diffusion models, you can set both algorithm_type="dpmsolver++" and thresholding=True to use the dynamic thresholding. This thresholding method is unsuitable for latent-space diffusion models such as Stable Diffusion.

SDE variants of DPMSolver and DPM-Solver++ are also supported, but only for the first- and second-order solvers. They are fast solvers for the reverse diffusion SDE. It is recommended to use the second-order sde-dpmsolver++.
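
As a rough illustration of the dynamic thresholding tip, the sketch below applies the rule to a single flattened prediction in plain Python. The helper names `quantile` and `dynamic_threshold` are hypothetical and not part of mindone; the scheduler itself works on batched tensors, so treat this as a simplified sketch rather than the library's implementation.

```python
import math


def quantile(values, q):
    """Linear-interpolation quantile of a non-empty list, 0 <= q <= 1."""
    ordered = sorted(values)
    pos = q * (len(ordered) - 1)
    lo = math.floor(pos)
    hi = math.ceil(pos)
    frac = pos - lo
    return ordered[lo] + frac * (ordered[hi] - ordered[lo])


def dynamic_threshold(x0_pred, ratio=0.995, sample_max_value=1.0):
    """Clip one flattened x0 prediction to [-s, s], then divide by s."""
    # s is a high percentile of the absolute pixel values ...
    s = quantile([abs(v) for v in x0_pred], ratio)
    # ... clamped to the range [1, sample_max_value].
    s = min(max(s, 1.0), sample_max_value)
    return [min(max(v, -s), s) / s for v in x0_pred]


pixels = [-3.0, -0.5, 0.0, 0.25, 2.0]

# Default settings: s is clamped down to 1.0, i.e. plain clipping to [-1, 1].
default_out = dynamic_threshold(pixels)

# A larger sample_max_value lets s exceed 1 (here s = 2.0).
wide_out = dynamic_threshold(pixels, ratio=0.75, sample_max_value=4.0)

print(default_out)
print(wide_out)
```

Under these assumptions, the default sample_max_value=1.0 always clamps s to 1, so the operation reduces to ordinary clipping to [-1, 1]. The threshold only becomes dynamic when sample_max_value is raised above 1.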

mindone.diffusers.DPMSolverMultistepScheduler

Bases: SchedulerMixin, ConfigMixin

DPMSolverMultistepScheduler is a fast dedicated high-order solver for diffusion ODEs.

This model inherits from [SchedulerMixin] and [ConfigMixin]. Check the superclass documentation for the generic methods the library implements for all schedulers such as loading and saving.

PARAMETER DESCRIPTION
num_train_timesteps

The number of diffusion steps to train the model.

TYPE: `int`, defaults to 1000 DEFAULT: 1000

beta_start

The starting beta value of inference.

TYPE: `float`, defaults to 0.0001 DEFAULT: 0.0001

beta_end

The final beta value.

TYPE: `float`, defaults to 0.02 DEFAULT: 0.02

beta_schedule

The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from linear, scaled_linear, or squaredcos_cap_v2.

TYPE: `str`, defaults to `"linear"` DEFAULT: 'linear'

trained_betas

Pass an array of betas directly to the constructor to bypass beta_start and beta_end.

TYPE: `np.ndarray`, *optional* DEFAULT: None

solver_order

The DPMSolver order which can be 1 or 2 or 3. It is recommended to use solver_order=2 for guided sampling, and solver_order=3 for unconditional sampling.

TYPE: `int`, defaults to 2 DEFAULT: 2

prediction_type

Prediction type of the scheduler function; can be epsilon (predicts the noise of the diffusion process), sample (directly predicts the noisy sample) or v_prediction (see section 2.4 of the Imagen Video paper).

TYPE: `str`, defaults to `epsilon`, *optional* DEFAULT: 'epsilon'

thresholding

Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion.

TYPE: `bool`, defaults to `False` DEFAULT: False

dynamic_thresholding_ratio

The ratio for the dynamic thresholding method. Valid only when thresholding=True.

TYPE: `float`, defaults to 0.995 DEFAULT: 0.995

sample_max_value

The threshold value for dynamic thresholding. Valid only when thresholding=True and algorithm_type="dpmsolver++".

TYPE: `float`, defaults to 1.0 DEFAULT: 1.0

algorithm_type

Algorithm type for the solver; can be dpmsolver, dpmsolver++, sde-dpmsolver or sde-dpmsolver++. The dpmsolver type implements the algorithms in the DPMSolver paper, and the dpmsolver++ type implements the algorithms in the DPMSolver++ paper. It is recommended to use dpmsolver++ or sde-dpmsolver++ with solver_order=2 for guided sampling like in Stable Diffusion.

TYPE: `str`, defaults to `dpmsolver++` DEFAULT: 'dpmsolver++'

solver_type

Solver type for the second-order solver; can be midpoint or heun. The solver type slightly affects the sample quality, especially for a small number of steps. It is recommended to use midpoint solvers.

TYPE: `str`, defaults to `midpoint` DEFAULT: 'midpoint'

lower_order_final

Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.

TYPE: `bool`, defaults to `True` DEFAULT: True

euler_at_final

Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail richness. This can stabilize the sampling of the SDE variant of DPMSolver for a small number of inference steps, but may sometimes result in blurring.

TYPE: `bool`, defaults to `False` DEFAULT: False

use_karras_sigmas

Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If True, the sigmas are determined according to a sequence of noise levels {σi}.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

use_lu_lambdas

Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during the sampling process. If True, the sigmas and time steps are determined according to a sequence of lambda(t).

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

final_sigmas_type

The final sigma value for the noise schedule during the sampling process. If "sigma_min", the final sigma is the same as the last sigma in the training schedule. If "zero", the final sigma is set to 0.

TYPE: `str`, defaults to `"zero"` DEFAULT: 'zero'

lambda_min_clipped

Clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for the cosine (squaredcos_cap_v2) noise schedule.

TYPE: `float`, defaults to `-inf` DEFAULT: -float('inf')

variance_type

Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output contains the predicted Gaussian variance.

TYPE: `str`, *optional* DEFAULT: None

timestep_spacing

The way the timesteps should be scaled. Refer to Table 2 of Common Diffusion Noise Schedules and Sample Steps are Flawed for more information.

TYPE: `str`, defaults to `"linspace"` DEFAULT: 'linspace'

steps_offset

An offset added to the inference steps, as required by some model families.

TYPE: `int`, defaults to 0 DEFAULT: 0

rescale_betas_zero_snr

Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. Loosely related to --offset_noise.

TYPE: `bool`, defaults to `False` DEFAULT: False
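
Several of the parameters above (beta_start, beta_end, num_train_timesteps, use_karras_sigmas) act through a few derived quantities. The sketch below shows one way to compute them in plain Python, assuming the linear beta schedule and a Karras ramp with rho = 7. The helper names `linear_betas`, `vp_schedule` and `karras_sigmas` are hypothetical; this illustrates the arithmetic and is not the library's code path.

```python
import math


def linear_betas(beta_start=0.0001, beta_end=0.02, num_train_timesteps=1000):
    """Evenly spaced betas, as with beta_schedule="linear"."""
    step = (beta_end - beta_start) / (num_train_timesteps - 1)
    return [beta_start + i * step for i in range(num_train_timesteps)]


def vp_schedule(betas):
    """Return (alpha_t, sigma_t, lambda_t, sigmas) for a VP-type schedule."""
    alpha_t, sigma_t, lambda_t, sigmas = [], [], [], []
    alpha_cumprod = 1.0
    for beta in betas:
        alpha_cumprod *= 1.0 - beta
        a = math.sqrt(alpha_cumprod)
        s = math.sqrt(1.0 - alpha_cumprod)
        alpha_t.append(a)
        sigma_t.append(s)
        lambda_t.append(math.log(a) - math.log(s))  # half of the log-SNR
        sigmas.append(s / a)  # sqrt((1 - alpha_cumprod) / alpha_cumprod)
    return alpha_t, sigma_t, lambda_t, sigmas


def karras_sigmas(sigma_min, sigma_max, num_inference_steps, rho=7.0):
    """Noise levels spaced as in Karras et al. (2022), from sigma_max down to sigma_min."""
    min_inv_rho = sigma_min ** (1.0 / rho)
    max_inv_rho = sigma_max ** (1.0 / rho)
    out = []
    for i in range(num_inference_steps):
        # ramp runs from 0 to 1 inclusive, like an evenly spaced grid
        ramp = i / (num_inference_steps - 1) if num_inference_steps > 1 else 0.0
        out.append((max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho)
    return out


alpha_t, sigma_t, lambda_t, sigmas = vp_schedule(linear_betas())

# The smallest training sigma is at index 0, the largest at the end.
ks = karras_sigmas(sigma_min=sigmas[0], sigma_max=sigmas[-1], num_inference_steps=10)

print(f"lambda range: {lambda_t[0]:.3f} .. {lambda_t[-1]:.3f}")
print(f"karras sigmas: {ks[0]:.3f} .. {ks[-1]:.5f}")
```

In this sketch lambda_t decreases monotonically as noise is added, and the Karras ramp places more of the inference steps at low noise levels than a uniform spacing in sigma would.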

Source code in mindone/diffusers/schedulers/scheduling_dpmsolver_multistep.py
class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
    """
    `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        solver_order (`int`, defaults to 2):
            The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
            sampling, and `solver_order=3` for unconditional sampling.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
            `algorithm_type="dpmsolver++"`.
        algorithm_type (`str`, defaults to `dpmsolver++`):
            Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
            `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
            paper, and the `dpmsolver++` type implements the algorithms in the
            [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
            `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
        solver_type (`str`, defaults to `midpoint`):
            Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
            sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
        lower_order_final (`bool`, defaults to `True`):
            Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
            stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
        euler_at_final (`bool`, defaults to `False`):
            Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
            richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
            steps, but sometimes may result in blurring.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        use_lu_lambdas (`bool`, *optional*, defaults to `False`):
            Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
            the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
            `lambda(t)`.
        final_sigmas_type (`str`, defaults to `"zero"`):
            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
        lambda_min_clipped (`float`, defaults to `-inf`):
            Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
            cosine (`squaredcos_cap_v2`) noise schedule.
        variance_type (`str`, *optional*):
            Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
            contains the predicted Gaussian variance.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        solver_order: int = 2,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "midpoint",
        lower_order_final: bool = True,
        euler_at_final: bool = False,
        use_karras_sigmas: Optional[bool] = False,
        use_lu_lambdas: Optional[bool] = False,
        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
        lambda_min_clipped: float = -float("inf"),
        variance_type: Optional[str] = None,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
    ):
        if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
            deprecation_message = (
                f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. "
                f"Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
            )
            deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)

        if trained_betas is not None:
            self.betas = ms.tensor(trained_betas, dtype=ms.float32)
        elif beta_schedule == "linear":
            self.betas = ms.tensor(np.linspace(beta_start, beta_end, num_train_timesteps), dtype=ms.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                ms.tensor(np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps), dtype=ms.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = ops.cumprod(self.alphas, dim=0)

        if rescale_betas_zero_snr:
            # Close to 0 without being 0 so first sigma is not inf
            # FP16 smallest positive subnormal works well here
            self.alphas_cumprod[-1] = 2**-24

        # Currently we only support VP-type noise schedule
        self.alpha_t = ops.sqrt(self.alphas_cumprod)
        self.sigma_t = ops.sqrt(1 - self.alphas_cumprod)
        self.lambda_t = ops.log(self.alpha_t) - ops.log(self.sigma_t)
        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # settings for DPM-Solver
        if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
            if algorithm_type == "deis":
                self.register_to_config(algorithm_type="dpmsolver++")
            else:
                raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")

        if solver_type not in ["midpoint", "heun"]:
            if solver_type in ["logrho", "bh1", "bh2"]:
                self.register_to_config(solver_type="midpoint")
            else:
                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

        if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
            raise ValueError(
                f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
            )

        # settable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
        self.timesteps = ms.Tensor(timesteps)
        self.model_outputs = [None] * solver_order
        self.lower_order_nums = 0
        self._step_index = None
        self._begin_index = None

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def set_timesteps(self, num_inference_steps: int = None, timesteps: Optional[List[int]] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
                based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` must be
                `None`, and the `timestep_spacing` attribute is ignored.
        """
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
        if num_inference_steps is not None and timesteps is not None:
            raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
        if timesteps is not None and self.config.use_karras_sigmas:
            raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
        if timesteps is not None and self.config.use_lu_lambdas:
            raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")

        if timesteps is not None:
            timesteps = np.array(timesteps).astype(np.int64)
        else:
            # Clipping the minimum of all lambda(t) for numerical stability.
            # This is critical for cosine (squaredcos_cap_v2) noise schedule.
            clipped_idx = ms.tensor(
                np.searchsorted(ops.flip(self.lambda_t, [0]).asnumpy(), self.config.lambda_min_clipped), dtype=ms.int64
            )
            last_timestep = ((self.config.num_train_timesteps - clipped_idx).asnumpy()).item()

            # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
            if self.config.timestep_spacing == "linspace":
                timesteps = (
                    np.linspace(0, last_timestep - 1, num_inference_steps + 1)
                    .round()[::-1][:-1]
                    .copy()
                    .astype(np.int64)
                )
            elif self.config.timestep_spacing == "leading":
                step_ratio = last_timestep // (num_inference_steps + 1)
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_step is power of 3
                timesteps = (
                    (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
                )
                timesteps += self.config.steps_offset
            elif self.config.timestep_spacing == "trailing":
                step_ratio = self.config.num_train_timesteps / num_inference_steps
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_step is power of 3
                timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
                timesteps -= 1
            else:
                raise ValueError(
                    f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
                )

        sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).asnumpy()
        log_sigmas = np.log(sigmas)

        if self.config.use_karras_sigmas:
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
        elif self.config.use_lu_lambdas:
            lambdas = np.flip(log_sigmas.copy())
            lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
            sigmas = np.exp(lambdas)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
        else:
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        if self.config.final_sigmas_type == "sigma_min":
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
        elif self.config.final_sigmas_type == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
            )

        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = ms.Tensor(sigmas)
        self.timesteps = ms.tensor(timesteps, dtype=ms.int64)

        self.num_inference_steps = len(timesteps)

        self.model_outputs = [
            None,
        ] * self.config.solver_order
        self.lower_order_nums = 0

        # add an index counter for schedulers that allow duplicated timesteps
        self._step_index = None
        self._begin_index = None

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: ms.Tensor) -> ms.Tensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (ms.float32, ms.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims).item())

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = ms.Tensor.from_numpy(np.quantile(abs_sample.asnumpy(), self.config.dynamic_thresholding_ratio, axis=1))
        s = ops.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = ops.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    def _sigma_to_alpha_sigma_t(self, sigma):
        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
        sigma_t = sigma * alpha_t

        return alpha_t, sigma_t

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: ms.Tensor, num_inference_steps) -> ms.Tensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    def _convert_to_lu(self, in_lambdas: ms.Tensor, num_inference_steps) -> ms.Tensor:
        """Constructs the noise schedule of Lu et al. (2022)."""

        lambda_min: float = in_lambdas[-1].item()
        lambda_max: float = in_lambdas[0].item()

        rho = 1.0  # 1.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = lambda_min ** (1 / rho)
        max_inv_rho = lambda_max ** (1 / rho)
        lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return lambdas

    def convert_model_output(
        self,
        model_output: ms.Tensor,
        *args,
        sample: ms.Tensor = None,
        **kwargs,
    ) -> ms.Tensor:
        """
        Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
        designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
        integral of the data prediction model.

        <Tip>

        The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
        prediction and data prediction models.

        </Tip>

        Args:
            model_output (`ms.Tensor`):
                The direct output from the learned diffusion model.
            sample (`ms.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `ms.Tensor`:
                The converted model output.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        # Cast sigmas only after `sample` is resolved; reading `sample.dtype` earlier fails when it is passed positionally.
        sigmas = self.sigmas.to(dtype=sample.dtype)
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # DPM-Solver++ needs to solve an integral of the data prediction model.
        if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
            if self.config.prediction_type == "epsilon":
                # DPM-Solver and DPM-Solver++ only need the "mean" output.
                if self.config.variance_type in ["learned", "learned_range"]:
                    model_output = model_output[:, :3]
                sigma = sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = (sample - sigma_t * model_output) / alpha_t
            elif self.config.prediction_type == "sample":
                x0_pred = model_output
            elif self.config.prediction_type == "v_prediction":
                sigma = sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = alpha_t * sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the DPMSolverMultistepScheduler."
                )

            if self.config.thresholding:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred

        # DPM-Solver needs to solve an integral of the noise prediction model.
        elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
            if self.config.prediction_type == "epsilon":
                # DPM-Solver and DPM-Solver++ only need the "mean" output.
                if self.config.variance_type in ["learned", "learned_range"]:
                    epsilon = model_output[:, :3]
                else:
                    epsilon = model_output
            elif self.config.prediction_type == "sample":
                sigma = sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                epsilon = (sample - alpha_t * model_output) / sigma_t
            elif self.config.prediction_type == "v_prediction":
                sigma = sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                epsilon = alpha_t * model_output + sigma_t * sample
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the DPMSolverMultistepScheduler."
                )

            if self.config.thresholding:
                sigma = sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = (sample - sigma_t * epsilon) / alpha_t
                x0_pred = self._threshold_sample(x0_pred)
                epsilon = (sample - alpha_t * x0_pred) / sigma_t

            return epsilon

    def dpm_solver_first_order_update(
        self,
        model_output: ms.Tensor,
        *args,
        sample: ms.Tensor = None,
        noise: Optional[ms.Tensor] = None,
        **kwargs,
    ) -> ms.Tensor:
        """
        One step for the first-order DPMSolver (equivalent to DDIM).

        Args:
            model_output (`ms.Tensor`):
                The direct output from the learned diffusion model.
            sample (`ms.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `ms.Tensor`:
                The sample tensor at the previous timestep.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
        lambda_t = ops.log(alpha_t) - ops.log(sigma_t)
        lambda_s = ops.log(alpha_s) - ops.log(sigma_s)

        h = lambda_t - lambda_s
        if self.config.algorithm_type == "dpmsolver++":
            x_t = (sigma_t / sigma_s) * sample - (alpha_t * (ops.exp(-h) - 1.0)) * model_output
        elif self.config.algorithm_type == "dpmsolver":
            x_t = (alpha_t / alpha_s) * sample - (sigma_t * (ops.exp(h) - 1.0)) * model_output
        elif self.config.algorithm_type == "sde-dpmsolver++":
            assert noise is not None
            x_t = (
                (sigma_t / sigma_s * ops.exp(-h)) * sample
                + (alpha_t * (1 - ops.exp(-2.0 * h))) * model_output
                + sigma_t * ops.sqrt(1.0 - ops.exp(-2 * h)) * noise
            )
        elif self.config.algorithm_type == "sde-dpmsolver":
            assert noise is not None
            x_t = (
                (alpha_t / alpha_s) * sample
                - 2.0 * (sigma_t * (ops.exp(h) - 1.0)) * model_output
                + sigma_t * ops.sqrt(ops.exp(2 * h) - 1.0) * noise
            )
        return x_t

    def multistep_dpm_solver_second_order_update(
        self,
        model_output_list: List[ms.Tensor],
        *args,
        sample: ms.Tensor = None,
        noise: Optional[ms.Tensor] = None,
        **kwargs,
    ) -> ms.Tensor:
        """
        One step for the second-order multistep DPMSolver.

        Args:
            model_output_list (`List[ms.Tensor]`):
                The direct outputs from the learned diffusion model at the current and previous solver steps.
            sample (`ms.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `ms.Tensor`:
                The sample tensor at the previous timestep.
        """
        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma_t, sigma_s0, sigma_s1 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)

        lambda_t = ops.log(alpha_t) - ops.log(sigma_t)
        lambda_s0 = ops.log(alpha_s0) - ops.log(sigma_s0)
        lambda_s1 = ops.log(alpha_s1) - ops.log(sigma_s1)

        m0, m1 = model_output_list[-1], model_output_list[-2]

        h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
        r0 = h_0 / h
        D0, D1 = m0, (1.0 / r0) * (m0 - m1)
        if self.config.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2211.01095 for detailed derivations
            if self.config.solver_type == "midpoint":
                x_t = (
                    (sigma_t / sigma_s0) * sample
                    - (alpha_t * (ops.exp(-h) - 1.0)) * D0
                    - 0.5 * (alpha_t * (ops.exp(-h) - 1.0)) * D1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (sigma_t / sigma_s0) * sample
                    - (alpha_t * (ops.exp(-h) - 1.0)) * D0
                    + (alpha_t * ((ops.exp(-h) - 1.0) / h + 1.0)) * D1
                )
        elif self.config.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            if self.config.solver_type == "midpoint":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - (sigma_t * (ops.exp(h) - 1.0)) * D0
                    - 0.5 * (sigma_t * (ops.exp(h) - 1.0)) * D1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - (sigma_t * (ops.exp(h) - 1.0)) * D0
                    - (sigma_t * ((ops.exp(h) - 1.0) / h - 1.0)) * D1
                )
        elif self.config.algorithm_type == "sde-dpmsolver++":
            assert noise is not None
            if self.config.solver_type == "midpoint":
                x_t = (
                    (sigma_t / sigma_s0 * ops.exp(-h)) * sample
                    + (alpha_t * (1 - ops.exp(-2.0 * h))) * D0
                    + 0.5 * (alpha_t * (1 - ops.exp(-2.0 * h))) * D1
                    + sigma_t * ops.sqrt(1.0 - ops.exp(-2 * h)) * noise
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (sigma_t / sigma_s0 * ops.exp(-h)) * sample
                    + (alpha_t * (1 - ops.exp(-2.0 * h))) * D0
                    + (alpha_t * ((1.0 - ops.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
                    + sigma_t * ops.sqrt(1.0 - ops.exp(-2 * h)) * noise
                )
        elif self.config.algorithm_type == "sde-dpmsolver":
            assert noise is not None
            if self.config.solver_type == "midpoint":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - 2.0 * (sigma_t * (ops.exp(h) - 1.0)) * D0
                    - (sigma_t * (ops.exp(h) - 1.0)) * D1
                    + sigma_t * ops.sqrt(ops.exp(2 * h) - 1.0) * noise
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - 2.0 * (sigma_t * (ops.exp(h) - 1.0)) * D0
                    - 2.0 * (sigma_t * ((ops.exp(h) - 1.0) / h - 1.0)) * D1
                    + sigma_t * ops.sqrt(ops.exp(2 * h) - 1.0) * noise
                )
        return x_t

    def multistep_dpm_solver_third_order_update(
        self,
        model_output_list: List[ms.Tensor],
        *args,
        sample: ms.Tensor = None,
        **kwargs,
    ) -> ms.Tensor:
        """
        One step for the third-order multistep DPMSolver.

        Args:
            model_output_list (`List[ms.Tensor]`):
                The direct outputs from the learned diffusion model at the current and previous solver steps.
            sample (`ms.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `ms.Tensor`:
                The sample tensor at the previous timestep.
        """

        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
            self.sigmas[self.step_index - 2],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
        alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

        lambda_t = ops.log(alpha_t) - ops.log(sigma_t)
        lambda_s0 = ops.log(alpha_s0) - ops.log(sigma_s0)
        lambda_s1 = ops.log(alpha_s1) - ops.log(sigma_s1)
        lambda_s2 = ops.log(alpha_s2) - ops.log(sigma_s2)

        m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

        h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
        r0, r1 = h_0 / h, h_1 / h
        D0 = m0
        D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
        D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
        D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
        if self.config.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2211.01095 for detailed derivations
            x_t = (
                (sigma_t / sigma_s0) * sample
                - (alpha_t * (ops.exp(-h) - 1.0)) * D0
                + (alpha_t * ((ops.exp(-h) - 1.0) / h + 1.0)) * D1
                - (alpha_t * ((ops.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
            )
        elif self.config.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            x_t = (
                (alpha_t / alpha_s0) * sample
                - (sigma_t * (ops.exp(h) - 1.0)) * D0
                - (sigma_t * ((ops.exp(h) - 1.0) / h - 1.0)) * D1
                - (sigma_t * ((ops.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
            )
        return x_t

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        index_candidates_num = (schedule_timesteps == timestep).sum()

        if index_candidates_num == 0:
            step_index = len(self.timesteps) - 1
        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        else:
            if index_candidates_num > 1:
                pos = 1
            else:
                pos = 0
            step_index = int((schedule_timesteps == timestep).nonzero()[pos])

        return step_index

    def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.
        """

        if self.begin_index is None:
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: ms.Tensor,
        timestep: Union[int, ms.Tensor],
        sample: ms.Tensor,
        generator=None,
        variance_noise: Optional[ms.Tensor] = None,
        return_dict: bool = False,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the multistep DPMSolver.

        Args:
            model_output (`ms.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`ms.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`np.random.Generator`, *optional*):
                A random number generator.
            variance_noise (`ms.Tensor`):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`LEdits++`].
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Improve numerical stability for small number of steps
        lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
            self.config.euler_at_final
            or (self.config.lower_order_final and len(self.timesteps) < 15)
            or self.config.final_sigmas_type == "zero"
        )
        lower_order_second = (
            (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
        )

        model_output = self.convert_model_output(model_output, sample=sample)
        for i in range(self.config.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
        self.model_outputs[-1] = model_output

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(ms.float32)
        if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
            noise = randn_tensor(model_output.shape, generator=generator, dtype=ms.float32)
        elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
            noise = variance_noise.to(dtype=ms.float32)
        else:
            noise = None

        if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
            prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
        elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
            prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
        else:
            prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)

        if self.lower_order_nums < self.config.solver_order:
            self.lower_order_nums += 1

        # Cast sample back to expected dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def scale_model_input(self, sample: ms.Tensor, *args, **kwargs) -> ms.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`ms.Tensor`):
                The input sample.

        Returns:
            `ms.Tensor`:
                A scaled input sample.
        """
        return sample

    def add_noise(
        self,
        original_samples: ms.Tensor,
        noise: ms.Tensor,
        timesteps: ms.Tensor,
    ) -> ms.Tensor:
        broadcast_shape = original_samples.shape
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(dtype=original_samples.dtype)
        schedule_timesteps = self.timesteps

        # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        # while len(sigma.shape) < len(original_samples.shape):
        #     sigma = sigma.unsqueeze(-1)
        sigma = ops.reshape(sigma, (timesteps.shape[0],) + (1,) * (len(broadcast_shape) - 1))

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
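
As a rough, framework-free illustration of the `_convert_to_karras` helper in the listing above, the sketch below rebuilds the same interpolation in plain Python: it interpolates linearly in `sigma ** (1 / rho)` space and raises the result back to the power `rho`. The function name `karras_sigmas` and the values `sigma_min=0.03` and `sigma_max=14.6` are assumptions chosen for illustration, not names or values taken from the library.

```python
def karras_sigmas(sigma_min: float, sigma_max: float, num_steps: int, rho: float = 7.0) -> list:
    """Hypothetical helper: Karras-style sigma schedule from sigma_max down to sigma_min."""
    min_inv_rho = sigma_min ** (1.0 / rho)
    max_inv_rho = sigma_max ** (1.0 / rho)
    sigmas = []
    for i in range(num_steps):
        # ramp runs from 0 to 1 inclusive, mirroring np.linspace(0, 1, num_steps)
        ramp = i / (num_steps - 1) if num_steps > 1 else 0.0
        sigmas.append((max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho)
    return sigmas


# Illustrative values only; real schedules derive them from the config or the trained betas.
example_sigmas = karras_sigmas(sigma_min=0.03, sigma_max=14.6, num_steps=5)
```

With `rho = 7.0` the steps cluster near `sigma_min`, so more of the sampling budget is spent at low noise levels.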

mindone.diffusers.DPMSolverMultistepScheduler.begin_index property

The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.

mindone.diffusers.DPMSolverMultistepScheduler.step_index property

The index counter for the current timestep. It increases by 1 after each scheduler step.
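
The step index also decides which solver order `step` applies on a given call: the solver warms up with lower orders while the history of model outputs is short, and it may drop back to lower orders on the last steps. The sketch below replays that selection logic from the `step` source in plain Python. The function name `orders_used` is hypothetical, and the argument values in the example are illustrative assumptions rather than documented defaults.

```python
def orders_used(num_steps, solver_order, lower_order_final, euler_at_final, final_sigmas_type):
    """Hypothetical helper: list the solver order applied at each step index."""
    orders = []
    lower_order_nums = 0  # how many lower-order warm-up steps have run so far
    for step_index in range(num_steps):
        is_last = step_index == num_steps - 1
        is_second_last = step_index == num_steps - 2
        first_order_at_end = is_last and (
            euler_at_final or (lower_order_final and num_steps < 15) or final_sigmas_type == "zero"
        )
        second_order_near_end = is_second_last and lower_order_final and num_steps < 15
        if solver_order == 1 or lower_order_nums < 1 or first_order_at_end:
            orders.append(1)
        elif solver_order == 2 or lower_order_nums < 2 or second_order_near_end:
            orders.append(2)
        else:
            orders.append(3)
        if lower_order_nums < solver_order:
            lower_order_nums += 1
    return orders


third_order_run = orders_used(10, 3, True, False, "zero")
# -> [1, 2, 3, 3, 3, 3, 3, 3, 2, 1]
second_order_run = orders_used(20, 2, True, False, "sigma_min")
```

In the first run the order ramps up over the first two steps and ramps down over the last two, because fewer than 15 steps are used. In the second run only the first step is first order.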

mindone.diffusers.DPMSolverMultistepScheduler.convert_model_output(model_output, *args, sample=None, **kwargs)

Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an integral of the data prediction model.

The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise prediction and data prediction models.

PARAMETER DESCRIPTION
model_output

The direct output from the learned diffusion model.

TYPE: `ms.Tensor`

sample

A current instance of a sample created by the diffusion process.

TYPE: `ms.Tensor` DEFAULT: None

RETURNS DESCRIPTION
Tensor

ms.Tensor: The converted model output.
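
To make the conversion concrete, here is a minimal scalar sketch of the data-prediction branch (the one used by `dpmsolver++`), written in plain Python instead of MindSpore tensors. It assumes the noisy sample is `alpha_t * x0 + sigma_t * eps`, as in `add_noise`, and the usual v-prediction target `v = alpha_t * eps - sigma_t * x0`. The helper names `alpha_sigma_from_sigma` and `to_x0` are hypothetical.

```python
import math


def alpha_sigma_from_sigma(sigma):
    """Same arithmetic as `_sigma_to_alpha_sigma_t`, on Python floats."""
    alpha_t = 1.0 / math.sqrt(sigma**2 + 1.0)
    return alpha_t, sigma * alpha_t


def to_x0(prediction_type, model_output, sample, sigma):
    """Hypothetical helper: recover the data prediction x0 from any supported model output."""
    alpha_t, sigma_t = alpha_sigma_from_sigma(sigma)
    if prediction_type == "epsilon":
        return (sample - sigma_t * model_output) / alpha_t
    if prediction_type == "sample":
        return model_output
    if prediction_type == "v_prediction":
        return alpha_t * sample - sigma_t * model_output
    raise ValueError(f"unsupported prediction_type: {prediction_type}")


# Build a consistent toy example from a known x0 and noise value.
x0, eps, sigma = 0.5, -1.2, 2.0
alpha_t, sigma_t = alpha_sigma_from_sigma(sigma)
sample = alpha_t * x0 + sigma_t * eps
v = alpha_t * eps - sigma_t * x0

recovered = {
    "epsilon": to_x0("epsilon", eps, sample, sigma),
    "sample": to_x0("sample", x0, sample, sigma),
    "v_prediction": to_x0("v_prediction", v, sample, sigma),
}
```

All three entries of `recovered` agree with `x0` up to floating-point error, which is why the solver can be paired with any of the three prediction types.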

Source code in mindone/diffusers/schedulers/scheduling_dpmsolver_multistep.py
def convert_model_output(
    self,
    model_output: ms.Tensor,
    *args,
    sample: ms.Tensor = None,
    **kwargs,
) -> ms.Tensor:
    """
    Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
    designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
    integral of the data prediction model.

    <Tip>

    The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
    prediction and data prediction models.

    </Tip>

    Args:
        model_output (`ms.Tensor`):
            The direct output from the learned diffusion model.
        sample (`ms.Tensor`):
            A current instance of a sample created by the diffusion process.

    Returns:
        `ms.Tensor`:
            The converted model output.
    """
    timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
    if sample is None:
        if len(args) > 1:
            sample = args[1]
        else:
            raise ValueError("missing `sample` as a required keyword argument")
    # Cast sigmas only after `sample` is resolved; reading `sample.dtype` earlier fails when it is passed positionally.
    sigmas = self.sigmas.to(dtype=sample.dtype)
    if timestep is not None:
        deprecate(
            "timesteps",
            "1.0.0",
            "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    # DPM-Solver++ needs to solve an integral of the data prediction model.
    if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
        if self.config.prediction_type == "epsilon":
            # DPM-Solver and DPM-Solver++ only need the "mean" output.
            if self.config.variance_type in ["learned", "learned_range"]:
                model_output = model_output[:, :3]
            sigma = sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            x0_pred = (sample - sigma_t * model_output) / alpha_t
        elif self.config.prediction_type == "sample":
            x0_pred = model_output
        elif self.config.prediction_type == "v_prediction":
            sigma = sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            x0_pred = alpha_t * sample - sigma_t * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction` for the DPMSolverMultistepScheduler."
            )

        if self.config.thresholding:
            x0_pred = self._threshold_sample(x0_pred)

        return x0_pred

    # DPM-Solver needs to solve an integral of the noise prediction model.
    elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
        if self.config.prediction_type == "epsilon":
            # DPM-Solver and DPM-Solver++ only need the "mean" output.
            if self.config.variance_type in ["learned", "learned_range"]:
                epsilon = model_output[:, :3]
            else:
                epsilon = model_output
        elif self.config.prediction_type == "sample":
            sigma = sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            epsilon = (sample - alpha_t * model_output) / sigma_t
        elif self.config.prediction_type == "v_prediction":
            sigma = sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            epsilon = alpha_t * model_output + sigma_t * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction` for the DPMSolverMultistepScheduler."
            )

        if self.config.thresholding:
            sigma = sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            x0_pred = (sample - sigma_t * epsilon) / alpha_t
            x0_pred = self._threshold_sample(x0_pred)
            epsilon = (sample - alpha_t * x0_pred) / sigma_t

        return epsilon

mindone.diffusers.DPMSolverMultistepScheduler.dpm_solver_first_order_update(model_output, *args, sample=None, noise=None, **kwargs)

One step for the first-order DPMSolver (equivalent to DDIM).

PARAMETER DESCRIPTION
model_output

The direct output from the learned diffusion model.

TYPE: `ms.Tensor`

sample

A current instance of a sample created by the diffusion process.

TYPE: `ms.Tensor` DEFAULT: None

RETURNS DESCRIPTION
Tensor

ms.Tensor: The sample tensor at the previous timestep.
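
The claim that the first-order step is equivalent to DDIM can be checked numerically. The scalar sketch below reproduces the deterministic `dpmsolver++` and `dpmsolver` branches in plain Python and compares both with the DDIM form `alpha_t * x0 + sigma_t * eps`. It assumes both sigmas are strictly positive (so the logarithms are defined) and that the model outputs are exact; the helper names are hypothetical.

```python
import math


def alpha_sigma_from_sigma(sigma):
    """Same arithmetic as `_sigma_to_alpha_sigma_t`, on Python floats."""
    alpha_t = 1.0 / math.sqrt(sigma**2 + 1.0)
    return alpha_t, sigma * alpha_t


def first_order_update(algorithm_type, model_output, sample, sigma_next, sigma_cur):
    """Hypothetical scalar version of the deterministic first-order branches."""
    alpha_t, sigma_t = alpha_sigma_from_sigma(sigma_next)
    alpha_s, sigma_s = alpha_sigma_from_sigma(sigma_cur)
    lambda_t = math.log(alpha_t) - math.log(sigma_t)
    lambda_s = math.log(alpha_s) - math.log(sigma_s)
    h = lambda_t - lambda_s  # step size in half log-SNR
    if algorithm_type == "dpmsolver++":  # model_output is the data prediction x0
        return (sigma_t / sigma_s) * sample - alpha_t * (math.exp(-h) - 1.0) * model_output
    if algorithm_type == "dpmsolver":  # model_output is the noise prediction eps
        return (alpha_t / alpha_s) * sample - sigma_t * (math.exp(h) - 1.0) * model_output
    raise ValueError(f"unsupported algorithm_type: {algorithm_type}")


x0, eps = 0.5, -1.2
sigma_cur, sigma_next = 2.0, 1.0
alpha_s, sigma_s = alpha_sigma_from_sigma(sigma_cur)
alpha_t, sigma_t = alpha_sigma_from_sigma(sigma_next)
sample = alpha_s * x0 + sigma_s * eps

ddim_step = alpha_t * x0 + sigma_t * eps
pp_step = first_order_update("dpmsolver++", x0, sample, sigma_next, sigma_cur)
eps_step = first_order_update("dpmsolver", eps, sample, sigma_next, sigma_cur)
```

Both `pp_step` and `eps_step` match `ddim_step` up to rounding, since `exp(h)` equals `(alpha_t / sigma_t) * (sigma_s / alpha_s)` and the cross terms cancel.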

Source code in mindone/diffusers/schedulers/scheduling_dpmsolver_multistep.py
def dpm_solver_first_order_update(
    self,
    model_output: ms.Tensor,
    *args,
    sample: ms.Tensor = None,
    noise: Optional[ms.Tensor] = None,
    **kwargs,
) -> ms.Tensor:
    """
    One step for the first-order DPMSolver (equivalent to DDIM).

    Args:
        model_output (`ms.Tensor`):
            The direct output from the learned diffusion model.
        sample (`ms.Tensor`):
            A current instance of a sample created by the diffusion process.

    Returns:
        `ms.Tensor`:
            The sample tensor at the previous timestep.
    """
    timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
    prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 2:
            sample = args[2]
        else:
            raise ValueError("missing `sample` as a required keyword argument")
    if timestep is not None:
        deprecate(
            "timesteps",
            "1.0.0",
            "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
    lambda_t = ops.log(alpha_t) - ops.log(sigma_t)
    lambda_s = ops.log(alpha_s) - ops.log(sigma_s)

    h = lambda_t - lambda_s
    if self.config.algorithm_type == "dpmsolver++":
        x_t = (sigma_t / sigma_s) * sample - (alpha_t * (ops.exp(-h) - 1.0)) * model_output
    elif self.config.algorithm_type == "dpmsolver":
        x_t = (alpha_t / alpha_s) * sample - (sigma_t * (ops.exp(h) - 1.0)) * model_output
    elif self.config.algorithm_type == "sde-dpmsolver++":
        assert noise is not None
        x_t = (
            (sigma_t / sigma_s * ops.exp(-h)) * sample
            + (alpha_t * (1 - ops.exp(-2.0 * h))) * model_output
            + sigma_t * ops.sqrt(1.0 - ops.exp(-2 * h)) * noise
        )
    elif self.config.algorithm_type == "sde-dpmsolver":
        assert noise is not None
        x_t = (
            (alpha_t / alpha_s) * sample
            - 2.0 * (sigma_t * (ops.exp(h) - 1.0)) * model_output
            + sigma_t * ops.sqrt(ops.exp(2 * h) - 1.0) * noise
        )
    return x_t
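The docstring's claim that the first-order update is equivalent to DDIM can be checked numerically. The following is a minimal NumPy sketch, not the library code. `sigma_to_alpha_sigma` is a hypothetical stand-in for `_sigma_to_alpha_sigma_t` and assumes the variance-preserving relation `alpha_t**2 + sigma_t**2 == 1`; `x0_pred` assumes the model output has already been converted to a data prediction, which is what the `dpmsolver++` branch appears to expect.

```python
import numpy as np


def sigma_to_alpha_sigma(sigma):
    # Assumption: sigma = sigma_t / alpha_t with alpha_t^2 + sigma_t^2 = 1.
    alpha_t = 1.0 / np.sqrt(sigma**2 + 1.0)
    return alpha_t, sigma * alpha_t


def first_order_dpmsolverpp(sample, x0_pred, sigma_next, sigma_cur):
    # Mirrors the `dpmsolver++` branch of the first-order update.
    alpha_t, sigma_t = sigma_to_alpha_sigma(sigma_next)
    alpha_s, sigma_s = sigma_to_alpha_sigma(sigma_cur)
    lambda_t = np.log(alpha_t) - np.log(sigma_t)
    lambda_s = np.log(alpha_s) - np.log(sigma_s)
    h = lambda_t - lambda_s
    return (sigma_t / sigma_s) * sample - alpha_t * (np.exp(-h) - 1.0) * x0_pred


def ddim_step(sample, x0_pred, sigma_next, sigma_cur):
    # Deterministic DDIM: recover the implied noise, then re-noise at the next level.
    alpha_t, sigma_t = sigma_to_alpha_sigma(sigma_next)
    alpha_s, sigma_s = sigma_to_alpha_sigma(sigma_cur)
    eps = (sample - alpha_s * x0_pred) / sigma_s
    return alpha_t * x0_pred + sigma_t * eps


rng = np.random.default_rng(0)
sample = rng.standard_normal((2, 3))
x0_pred = rng.standard_normal((2, 3))
a = first_order_dpmsolverpp(sample, x0_pred, sigma_next=1.0, sigma_cur=2.0)
b = ddim_step(sample, x0_pred, sigma_next=1.0, sigma_cur=2.0)
print(np.allclose(a, b))
```

The two results agree because `alpha_t * (exp(-h) - 1)` simplifies to `alpha_s * sigma_t / sigma_s - alpha_t`, which is the coefficient of the data prediction in the DDIM step.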

mindone.diffusers.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update(model_output_list, *args, sample=None, noise=None, **kwargs)

One step for the second-order multistep DPMSolver.

PARAMETER DESCRIPTION
model_output_list

The direct outputs from the learned diffusion model at the current and later timesteps.

TYPE: `List[ms.Tensor]`

sample

A current instance of a sample created by the diffusion process.

TYPE: `ms.Tensor` DEFAULT: None

RETURNS DESCRIPTION
Tensor

ms.Tensor: The sample tensor at the previous timestep.

Source code in mindone/diffusers/schedulers/scheduling_dpmsolver_multistep.py
def multistep_dpm_solver_second_order_update(
    self,
    model_output_list: List[ms.Tensor],
    *args,
    sample: ms.Tensor = None,
    noise: Optional[ms.Tensor] = None,
    **kwargs,
) -> ms.Tensor:
    """
    One step for the second-order multistep DPMSolver.

    Args:
        model_output_list (`List[ms.Tensor]`):
            The direct outputs from the learned diffusion model at the current and later timesteps.
        sample (`ms.Tensor`):
            A current instance of a sample created by the diffusion process.

    Returns:
        `ms.Tensor`:
            The sample tensor at the previous timestep.
    """
    timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
    prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 2:
            sample = args[2]
        else:
            raise ValueError("missing `sample` as a required keyword argument")
    if timestep_list is not None:
        deprecate(
            "timestep_list",
            "1.0.0",
            "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    sigma_t, sigma_s0, sigma_s1 = (
        self.sigmas[self.step_index + 1],
        self.sigmas[self.step_index],
        self.sigmas[self.step_index - 1],
    )

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)

    lambda_t = ops.log(alpha_t) - ops.log(sigma_t)
    lambda_s0 = ops.log(alpha_s0) - ops.log(sigma_s0)
    lambda_s1 = ops.log(alpha_s1) - ops.log(sigma_s1)

    m0, m1 = model_output_list[-1], model_output_list[-2]

    h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
    r0 = h_0 / h
    D0, D1 = m0, (1.0 / r0) * (m0 - m1)
    if self.config.algorithm_type == "dpmsolver++":
        # See https://arxiv.org/abs/2211.01095 for detailed derivations
        if self.config.solver_type == "midpoint":
            x_t = (
                (sigma_t / sigma_s0) * sample
                - (alpha_t * (ops.exp(-h) - 1.0)) * D0
                - 0.5 * (alpha_t * (ops.exp(-h) - 1.0)) * D1
            )
        elif self.config.solver_type == "heun":
            x_t = (
                (sigma_t / sigma_s0) * sample
                - (alpha_t * (ops.exp(-h) - 1.0)) * D0
                + (alpha_t * ((ops.exp(-h) - 1.0) / h + 1.0)) * D1
            )
    elif self.config.algorithm_type == "dpmsolver":
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        if self.config.solver_type == "midpoint":
            x_t = (
                (alpha_t / alpha_s0) * sample
                - (sigma_t * (ops.exp(h) - 1.0)) * D0
                - 0.5 * (sigma_t * (ops.exp(h) - 1.0)) * D1
            )
        elif self.config.solver_type == "heun":
            x_t = (
                (alpha_t / alpha_s0) * sample
                - (sigma_t * (ops.exp(h) - 1.0)) * D0
                - (sigma_t * ((ops.exp(h) - 1.0) / h - 1.0)) * D1
            )
    elif self.config.algorithm_type == "sde-dpmsolver++":
        assert noise is not None
        if self.config.solver_type == "midpoint":
            x_t = (
                (sigma_t / sigma_s0 * ops.exp(-h)) * sample
                + (alpha_t * (1 - ops.exp(-2.0 * h))) * D0
                + 0.5 * (alpha_t * (1 - ops.exp(-2.0 * h))) * D1
                + sigma_t * ops.sqrt(1.0 - ops.exp(-2 * h)) * noise
            )
        elif self.config.solver_type == "heun":
            x_t = (
                (sigma_t / sigma_s0 * ops.exp(-h)) * sample
                + (alpha_t * (1 - ops.exp(-2.0 * h))) * D0
                + (alpha_t * ((1.0 - ops.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
                + sigma_t * ops.sqrt(1.0 - ops.exp(-2 * h)) * noise
            )
    elif self.config.algorithm_type == "sde-dpmsolver":
        assert noise is not None
        if self.config.solver_type == "midpoint":
            x_t = (
                (alpha_t / alpha_s0) * sample
                - 2.0 * (sigma_t * (ops.exp(h) - 1.0)) * D0
                - (sigma_t * (ops.exp(h) - 1.0)) * D1
                + sigma_t * ops.sqrt(ops.exp(2 * h) - 1.0) * noise
            )
        elif self.config.solver_type == "heun":
            x_t = (
                (alpha_t / alpha_s0) * sample
                - 2.0 * (sigma_t * (ops.exp(h) - 1.0)) * D0
                - 2.0 * (sigma_t * ((ops.exp(h) - 1.0) / h - 1.0)) * D1
                + sigma_t * ops.sqrt(ops.exp(2 * h) - 1.0) * noise
            )
    return x_t
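To make the role of `D1` concrete, here is a small NumPy sketch of the `dpmsolver++` midpoint branch only. It is an illustration under assumptions rather than the library implementation: `to_alpha_sigma_lambda` is a hypothetical helper that assumes `alpha_t**2 + sigma_t**2 == 1`, and the sigma values are arbitrary. When the two most recent model outputs are identical, `D1` vanishes and the step reduces to the first-order update.

```python
import numpy as np


def to_alpha_sigma_lambda(sigma):
    # Assumption: variance-preserving parameterization.
    alpha_t = 1.0 / np.sqrt(sigma**2 + 1.0)
    sigma_t = sigma * alpha_t
    return alpha_t, sigma_t, np.log(alpha_t) - np.log(sigma_t)


def second_order_midpoint(sample, m0, m1, sigma_next, sigma_cur, sigma_prev):
    alpha_t, sigma_t, lambda_t = to_alpha_sigma_lambda(sigma_next)
    _, sigma_s0, lambda_s0 = to_alpha_sigma_lambda(sigma_cur)
    _, _, lambda_s1 = to_alpha_sigma_lambda(sigma_prev)

    h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
    r0 = h_0 / h
    # D1 is a finite-difference estimate of how the model output changes with lambda.
    D0, D1 = m0, (1.0 / r0) * (m0 - m1)
    coeff = alpha_t * (np.exp(-h) - 1.0)
    return (sigma_t / sigma_s0) * sample - coeff * D0 - 0.5 * coeff * D1


rng = np.random.default_rng(0)
sample = rng.standard_normal(4)
m0 = rng.standard_normal(4)

# Identical consecutive outputs: the correction term drops out.
second = second_order_midpoint(sample, m0, m0, sigma_next=1.0, sigma_cur=2.0, sigma_prev=4.0)

alpha_t, sigma_t, lambda_t = to_alpha_sigma_lambda(1.0)
_, sigma_s0, lambda_s0 = to_alpha_sigma_lambda(2.0)
first = (sigma_t / sigma_s0) * sample - alpha_t * (np.exp(-(lambda_t - lambda_s0)) - 1.0) * m0
print(np.allclose(first, second))
```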

mindone.diffusers.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update(model_output_list, *args, sample=None, **kwargs)

One step for the third-order multistep DPMSolver.

PARAMETER DESCRIPTION
model_output_list

The direct outputs from the learned diffusion model at the current and later timesteps.

TYPE: `List[ms.Tensor]`

sample

A current instance of a sample created by the diffusion process.

TYPE: `ms.Tensor` DEFAULT: None

RETURNS DESCRIPTION
Tensor

ms.Tensor: The sample tensor at the previous timestep.

Source code in mindone/diffusers/schedulers/scheduling_dpmsolver_multistep.py
def multistep_dpm_solver_third_order_update(
    self,
    model_output_list: List[ms.Tensor],
    *args,
    sample: ms.Tensor = None,
    **kwargs,
) -> ms.Tensor:
    """
    One step for the third-order multistep DPMSolver.

    Args:
        model_output_list (`List[ms.Tensor]`):
            The direct outputs from the learned diffusion model at the current and later timesteps.
        sample (`ms.Tensor`):
            A current instance of a sample created by the diffusion process.

    Returns:
        `ms.Tensor`:
            The sample tensor at the previous timestep.
    """

    timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
    prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 2:
            sample = args[2]
        else:
            raise ValueError("missing `sample` as a required keyword argument")
    if timestep_list is not None:
        deprecate(
            "timestep_list",
            "1.0.0",
            "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
        self.sigmas[self.step_index + 1],
        self.sigmas[self.step_index],
        self.sigmas[self.step_index - 1],
        self.sigmas[self.step_index - 2],
    )

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
    alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

    lambda_t = ops.log(alpha_t) - ops.log(sigma_t)
    lambda_s0 = ops.log(alpha_s0) - ops.log(sigma_s0)
    lambda_s1 = ops.log(alpha_s1) - ops.log(sigma_s1)
    lambda_s2 = ops.log(alpha_s2) - ops.log(sigma_s2)

    m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

    h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
    r0, r1 = h_0 / h, h_1 / h
    D0 = m0
    D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
    D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
    D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
    if self.config.algorithm_type == "dpmsolver++":
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        x_t = (
            (sigma_t / sigma_s0) * sample
            - (alpha_t * (ops.exp(-h) - 1.0)) * D0
            + (alpha_t * ((ops.exp(-h) - 1.0) / h + 1.0)) * D1
            - (alpha_t * ((ops.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
        )
    elif self.config.algorithm_type == "dpmsolver":
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        x_t = (
            (alpha_t / alpha_s0) * sample
            - (sigma_t * (ops.exp(h) - 1.0)) * D0
            - (sigma_t * ((ops.exp(h) - 1.0) / h - 1.0)) * D1
            - (sigma_t * ((ops.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
        )
    return x_t
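The `D1` and `D2` terms in the listing above are divided differences of the last three model outputs, scaled by the step size `h` in log-SNR. The sketch below reproduces only that part in plain NumPy. The function name and the quadratic test function are illustrative assumptions; for a model output that is exactly quadratic in lambda, `D1` equals `h * m'(lambda_s0)` and `D2` equals `h**2 * m''(lambda_s0) / 2`.

```python
import numpy as np


def third_order_differences(m0, m1, m2, lambda_t, lambda_s0, lambda_s1, lambda_s2):
    # Same arithmetic as the third-order update, without the final combination.
    h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
    r0, r1 = h_0 / h, h_1 / h
    D0 = m0
    D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
    D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
    D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
    return D0, D1, D2


def m(lam):
    # Hypothetical "model output" that is quadratic in lambda: m' = 2 + 6*lam, m'' = 6.
    return 1.0 + 2.0 * lam + 3.0 * lam**2


lambda_s2, lambda_s1, lambda_s0, lambda_t = -1.0, -0.4, 0.1, 0.5
h = lambda_t - lambda_s0

D0, D1, D2 = third_order_differences(
    m(lambda_s0), m(lambda_s1), m(lambda_s2), lambda_t, lambda_s0, lambda_s1, lambda_s2
)
print(np.isclose(D1, h * (2.0 + 6.0 * lambda_s0)), np.isclose(D2, h**2 * 3.0))
```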

mindone.diffusers.DPMSolverMultistepScheduler.scale_model_input(sample, *args, **kwargs)

Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep.

PARAMETER DESCRIPTION
sample

The input sample.

TYPE: `ms.Tensor`

RETURNS DESCRIPTION
Tensor

ms.Tensor: A scaled input sample.

Source code in mindone/diffusers/schedulers/scheduling_dpmsolver_multistep.py
def scale_model_input(self, sample: ms.Tensor, *args, **kwargs) -> ms.Tensor:
    """
    Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
    current timestep.

    Args:
        sample (`ms.Tensor`):
            The input sample.

    Returns:
        `ms.Tensor`:
            A scaled input sample.
    """
    return sample

mindone.diffusers.DPMSolverMultistepScheduler.set_begin_index(begin_index=0)

Sets the begin index for the scheduler. This function should be run from the pipeline before inference.

PARAMETER DESCRIPTION
begin_index

The begin index for the scheduler.

TYPE: `int` DEFAULT: 0

Source code in mindone/diffusers/schedulers/scheduling_dpmsolver_multistep.py
def set_begin_index(self, begin_index: int = 0):
    """
    Sets the begin index for the scheduler. This function should be run from the pipeline before inference.

    Args:
        begin_index (`int`):
            The begin index for the scheduler.
    """
    self._begin_index = begin_index

mindone.diffusers.DPMSolverMultistepScheduler.set_timesteps(num_inference_steps=None, timesteps=None)

Sets the discrete timesteps used for the diffusion chain (to be run before inference).

PARAMETER DESCRIPTION
num_inference_steps

The number of diffusion steps used when generating samples with a pre-trained model.

TYPE: `int` DEFAULT: None

timesteps

Custom timesteps used to support an arbitrary timestep schedule. If None, timesteps are generated based on the timestep_spacing attribute. If timesteps is passed, num_inference_steps and sigmas must be None, and the timestep_spacing attribute is ignored.

TYPE: `List[int]`, *optional* DEFAULT: None
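When custom timesteps are supplied and neither Karras sigmas nor Lu lambdas are enabled, the sigmas are obtained by interpolating the training-time noise levels at those timesteps and appending a final sigma. A rough NumPy sketch of that path follows; the linear beta schedule with the default `beta_start` and `beta_end`, the example timesteps, and `final_sigmas_type="zero"` are all assumptions made for illustration.

```python
import numpy as np

# Assumption: linear beta schedule with the documented defaults.
betas = np.linspace(1e-4, 0.02, 1000)
alphas_cumprod = np.cumprod(1.0 - betas)

# Noise level sigma = sqrt((1 - alpha_bar) / alpha_bar) for every training timestep.
sigmas_all = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5

# Hypothetical custom schedule, highest timestep first.
timesteps = np.array([999, 749, 499, 249], dtype=np.int64)

# Interpolate the noise levels at the requested timesteps, then append sigma_last = 0.
sigmas = np.interp(timesteps, np.arange(0, len(sigmas_all)), sigmas_all)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)

print(len(timesteps), len(sigmas))
```

The sigma array is one element longer than the timestep array, which is why the update functions can safely index `self.sigmas[self.step_index + 1]` on the last step.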

Source code in mindone/diffusers/schedulers/scheduling_dpmsolver_multistep.py
def set_timesteps(self, num_inference_steps: int = None, timesteps: Optional[List[int]] = None):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support an arbitrary timestep schedule. If `None`, timesteps are generated
            based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
            must be `None`, and the `timestep_spacing` attribute is ignored.
    """
    if num_inference_steps is None and timesteps is None:
        raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
    if num_inference_steps is not None and timesteps is not None:
        raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")
    if timesteps is not None and self.config.use_karras_sigmas:
        raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
    if timesteps is not None and self.config.use_lu_lambdas:
        raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")

    if timesteps is not None:
        timesteps = np.array(timesteps).astype(np.int64)
    else:
        # Clipping the minimum of all lambda(t) for numerical stability.
        # This is critical for cosine (squaredcos_cap_v2) noise schedule.
        clipped_idx = ms.tensor(
            np.searchsorted(ops.flip(self.lambda_t, [0]).asnumpy(), self.config.lambda_min_clipped), dtype=ms.int64
        )
        last_timestep = ((self.config.num_train_timesteps - clipped_idx).asnumpy()).item()

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, last_timestep - 1, num_inference_steps + 1)
                .round()[::-1][:-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = last_timestep // (num_inference_steps + 1)
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (
                (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
            )
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

    sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).asnumpy()
    log_sigmas = np.log(sigmas)

    if self.config.use_karras_sigmas:
        sigmas = np.flip(sigmas).copy()
        sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
    elif self.config.use_lu_lambdas:
        lambdas = np.flip(log_sigmas.copy())
        lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
        sigmas = np.exp(lambdas)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
    else:
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

    if self.config.final_sigmas_type == "sigma_min":
        sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
    elif self.config.final_sigmas_type == "zero":
        sigma_last = 0
    else:
        raise ValueError(
            f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
        )

    sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

    self.sigmas = ms.Tensor(sigmas)
    self.timesteps = ms.tensor(timesteps, dtype=ms.int64)

    self.num_inference_steps = len(timesteps)

    self.model_outputs = [
        None,
    ] * self.config.solver_order
    self.lower_order_nums = 0

    # add an index counter for schedulers that allow duplicated timesteps
    self._step_index = None
    self._begin_index = None
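The three `timestep_spacing` modes in the listing above differ only in how they place integer timesteps. The following NumPy sketch isolates that arithmetic so the schedules can be compared side by side. `spaced_timesteps` is a hypothetical helper, and it assumes no lambda clipping takes effect, so that `last_timestep` equals `num_train_timesteps`.

```python
import numpy as np


def spaced_timesteps(spacing, num_inference_steps, last_timestep=1000,
                     num_train_timesteps=1000, steps_offset=0):
    # Assumption: last_timestep == num_train_timesteps (no lambda clipping).
    if spacing == "linspace":
        ts = np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1]
    elif spacing == "leading":
        step_ratio = last_timestep // (num_inference_steps + 1)
        ts = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1]
        ts = ts + steps_offset
    elif spacing == "trailing":
        step_ratio = num_train_timesteps / num_inference_steps
        ts = np.arange(last_timestep, 0, -step_ratio).round() - 1
    else:
        raise ValueError(f"{spacing} is not supported.")
    return ts.copy().astype(np.int64)


linspace = spaced_timesteps("linspace", 10)
leading = spaced_timesteps("leading", 10)
trailing = spaced_timesteps("trailing", 10)
print(leading.tolist())
print(trailing.tolist())
```

With 10 inference steps, `leading` starts at 900 and never visits the noisiest training timestep, while `trailing` and `linspace` both start at 999.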

mindone.diffusers.DPMSolverMultistepScheduler.step(model_output, timestep, sample, generator=None, variance_noise=None, return_dict=False)

Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with the multistep DPMSolver.

PARAMETER DESCRIPTION
model_output

The direct output from the learned diffusion model.

TYPE: `ms.Tensor`

timestep

The current discrete timestep in the diffusion chain.

TYPE: `int`

sample

A current instance of a sample created by the diffusion process.

TYPE: `ms.Tensor`

generator

A random number generator.

TYPE: `np.random.Generator`, *optional* DEFAULT: None

variance_noise

Alternative to generating noise with generator by directly providing the noise for the variance itself. Useful for methods such as [LEdits++].

TYPE: `ms.Tensor` DEFAULT: None

return_dict

Whether or not to return a [~schedulers.scheduling_utils.SchedulerOutput] or tuple.

TYPE: `bool` DEFAULT: False

RETURNS DESCRIPTION
Union[SchedulerOutput, Tuple]

[~schedulers.scheduling_utils.SchedulerOutput] or tuple: If return_dict is True, [~schedulers.scheduling_utils.SchedulerOutput] is returned, otherwise a tuple is returned where the first element is the sample tensor.

Source code in mindone/diffusers/schedulers/scheduling_dpmsolver_multistep.py
def step(
    self,
    model_output: ms.Tensor,
    timestep: Union[int, ms.Tensor],
    sample: ms.Tensor,
    generator=None,
    variance_noise: Optional[ms.Tensor] = None,
    return_dict: bool = False,
) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
    the multistep DPMSolver.

    Args:
        model_output (`ms.Tensor`):
            The direct output from the learned diffusion model.
        timestep (`int`):
            The current discrete timestep in the diffusion chain.
        sample (`ms.Tensor`):
            A current instance of a sample created by the diffusion process.
        generator (`np.random.Generator`, *optional*):
            A random number generator.
        variance_noise (`ms.Tensor`):
            Alternative to generating noise with `generator` by directly providing the noise for the variance
            itself. Useful for methods such as [`LEdits++`].
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    if self.step_index is None:
        self._init_step_index(timestep)

    # Improve numerical stability for small number of steps
    lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
        self.config.euler_at_final
        or (self.config.lower_order_final and len(self.timesteps) < 15)
        or self.config.final_sigmas_type == "zero"
    )
    lower_order_second = (
        (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
    )

    model_output = self.convert_model_output(model_output, sample=sample)
    for i in range(self.config.solver_order - 1):
        self.model_outputs[i] = self.model_outputs[i + 1]
    self.model_outputs[-1] = model_output

    # Upcast to avoid precision issues when computing prev_sample
    sample = sample.to(ms.float32)
    if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
        noise = randn_tensor(model_output.shape, generator=generator, dtype=ms.float32)
    elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
        noise = variance_noise.to(dtype=ms.float32)
    else:
        noise = None

    if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
        prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
    elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
        prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
    else:
        prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)

    if self.lower_order_nums < self.config.solver_order:
        self.lower_order_nums += 1

    # Cast sample back to expected dtype
    prev_sample = prev_sample.to(model_output.dtype)

    # upon completion increase step index by one
    self._step_index += 1

    if not return_dict:
        return (prev_sample,)

    return SchedulerOutput(prev_sample=prev_sample)
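The branch in `step` that picks between the first-, second-, and third-order updates depends on a warm-up counter and on the lower-order flags for the final steps. The pure-Python sketch below replays only that decision logic to show which order would be used at each step. `solver_order_schedule` is a hypothetical helper, and its default argument values are assumptions chosen for illustration, not a statement of the scheduler's defaults.

```python
def solver_order_schedule(num_steps, solver_order=2, lower_order_final=True,
                          euler_at_final=False, final_sigmas_type="zero"):
    orders = []
    lower_order_nums = 0  # counts how many model outputs have been collected so far
    for step_index in range(num_steps):
        final_first = (step_index == num_steps - 1) and (
            euler_at_final
            or (lower_order_final and num_steps < 15)
            or final_sigmas_type == "zero"
        )
        final_second = (step_index == num_steps - 2) and lower_order_final and num_steps < 15

        if solver_order == 1 or lower_order_nums < 1 or final_first:
            orders.append(1)
        elif solver_order == 2 or lower_order_nums < 2 or final_second:
            orders.append(2)
        else:
            orders.append(3)

        if lower_order_nums < solver_order:
            lower_order_nums += 1
    return orders


print(solver_order_schedule(6, solver_order=3))
print(solver_order_schedule(5, solver_order=2))
```

For short schedules the order ramps up at the start, because too few past model outputs exist, and ramps down at the end, which the source comment motivates as improving numerical stability for a small number of steps.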