
Optimizer

Optimizer Factory

mindcv.optim.optim_factory.create_optimizer(model_or_params, opt='adam', lr=0.001, weight_decay=0, momentum=0.9, nesterov=False, weight_decay_filter='disable', layer_decay=None, loss_scale=1.0, schedule_decay=0.004, checkpoint_path='', eps=1e-10, **kwargs)

Creates optimizer by name.

PARAMETER DESCRIPTION
model_or_params

network or network parameters. Union[list[Parameter], list[dict], nn.Cell]: a list of parameters, a list of dicts, or an nn.Cell. When a list element is a dictionary, its keys can be "params", "lr", "weight_decay", "grad_centralization" and "order_params" (see the group-parameter sketch after the source code below).

opt

the optimizer to wrap, one of 'sgd', 'nesterov', 'momentum', 'adam', 'adamw', 'lion', 'rmsprop', 'adagrad', 'lamb'. 'adam' is the default choice for convolution-based networks, while 'adamw' is recommended for ViT-based networks. Default: 'adam'.

TYPE: str DEFAULT: 'adam'

lr

learning rate, a float or an lr scheduler. Fixed and dynamic learning rates are supported. Default: 1e-3.

TYPE: Optional[float] DEFAULT: 0.001

weight_decay

weight decay factor. Note that weight decay can be a constant value or a Cell; it is a Cell only when dynamic weight decay is applied. Dynamic weight decay is similar to dynamic learning rate: users customize a weight decay schedule that takes only the global step as input, and during training the optimizer calls this WeightDecaySchedule instance to get the weight decay value of the current step (a sketch follows this parameter list). Default: 0.

TYPE: float DEFAULT: 0

momentum

momentum, if the optimizer supports it. Default: 0.9.

TYPE: float DEFAULT: 0.9

nesterov

Whether to use the Nesterov Accelerated Gradient (NAG) algorithm to update the gradients. Default: False.

TYPE: bool DEFAULT: False

weight_decay_filter

filters used to exclude parameters from weight_decay.
- "disable": No parameters are filtered.
- "auto": No weight decay filtering is applied explicitly. However, MindSpore currently filters the parameters of Norm layers from weight decay automatically.
- "norm_and_bias": Filter the parameters of Norm layers and bias parameters from weight decay.

TYPE: str DEFAULT: 'disable'

layer_decay

coefficient for applying layer-wise learning rate decay.

TYPE: Optional[float] DEFAULT: None

loss_scale

A floating point value for the loss scale, which must be larger than 0.0. Default: 1.0.

TYPE: float DEFAULT: 1.0
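
A minimal sketch of a dynamic weight decay schedule under the description above: a Cell whose construct takes the global step and returns the decay value for that step. The class name, decay shape, and hyper-parameters below are illustrative and not part of mindcv.

import mindspore as ms
from mindspore import nn, ops

class LinearWeightDecay(nn.Cell):
    """Illustrative schedule: decays weight decay linearly from `base` to `final` over `total_steps`."""
    def __init__(self, base=0.05, final=0.0, total_steps=10000):
        super().__init__()
        self.base = base
        self.final = final
        self.total_steps = total_steps

    def construct(self, global_step):
        # global_step is a Tensor holding the current training step
        progress = ops.clip_by_value(global_step.astype(ms.float32) / self.total_steps, 0.0, 1.0)
        return self.base + (self.final - self.base) * progress

# passed as the `weight_decay` argument, e.g.
# create_optimizer(net, opt="adamw", weight_decay=LinearWeightDecay())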

RETURNS DESCRIPTION

Optimizer object
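
A minimal usage sketch following the guidance above; the toy network and hyper-parameter values are illustrative. Since both fixed and dynamic learning rates are supported, lr could also be a learning-rate schedule instead of a float.

from mindspore import nn
from mindcv.optim.optim_factory import create_optimizer

network = nn.Dense(784, 10)  # toy network; any nn.Cell or parameter list works

# 'adam' is the default choice for convolution-based networks
optimizer = create_optimizer(network, opt="adam", lr=1e-3, weight_decay=1e-4)

# 'adamw' is recommended for ViT-based networks; here Norm and bias parameters
# are also excluded from weight decay
optimizer = create_optimizer(
    network, opt="adamw", lr=1e-3, weight_decay=0.05, weight_decay_filter="norm_and_bias"
)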

Source code in mindcv/optim/optim_factory.py
def create_optimizer(
    model_or_params,
    opt: str = "adam",
    lr: Optional[float] = 1e-3,
    weight_decay: float = 0,
    momentum: float = 0.9,
    nesterov: bool = False,
    weight_decay_filter: str = "disable",
    layer_decay: Optional[float] = None,
    loss_scale: float = 1.0,
    schedule_decay: float = 4e-3,
    checkpoint_path: str = "",
    eps: float = 1e-10,
    **kwargs,
):
    r"""Creates optimizer by name.

    Args:
        model_or_params: network or network parameters. Union[list[Parameter], list[dict], nn.Cell], which must be
            a list of parameters, a list of dicts, or an nn.Cell. When a list element is a dictionary, its keys can
            be "params", "lr", "weight_decay", "grad_centralization" and "order_params".
        opt: the optimizer to wrap, one of 'sgd', 'nesterov', 'momentum', 'adam', 'adamw', 'lion',
            'rmsprop', 'adagrad', 'lamb'. 'adam' is the default choice for convolution-based networks,
            while 'adamw' is recommended for ViT-based networks. Default: 'adam'.
        lr: learning rate, a float or an lr scheduler. Fixed and dynamic learning rates are supported. Default: 1e-3.
        weight_decay: weight decay factor. Note that weight decay can be a constant value or a Cell;
            it is a Cell only when dynamic weight decay is applied. Dynamic weight decay is similar to
            dynamic learning rate: users customize a weight decay schedule that takes only the global step as input,
            and during training the optimizer calls this WeightDecaySchedule instance to get the weight decay value
            of the current step. Default: 0.
        momentum: momentum, if the optimizer supports it. Default: 0.9.
        nesterov: Whether to use the Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
            Default: False.
        weight_decay_filter: filters used to exclude parameters from weight_decay.
            - "disable": No parameters are filtered.
            - "auto": No weight decay filtering is applied explicitly. However, MindSpore currently
                    filters the parameters of Norm layers from weight decay automatically.
            - "norm_and_bias": Filter the parameters of Norm layers and bias parameters from weight decay.
        layer_decay: coefficient for applying layer-wise learning rate decay.
        loss_scale: A floating point value for the loss scale, which must be larger than 0.0. Default: 1.0.

    Returns:
        Optimizer object
    """

    no_weight_decay = {}
    if isinstance(model_or_params, nn.Cell):
        # a model was passed in, extract parameters and add weight decays to appropriate layers
        if hasattr(model_or_params, "no_weight_decay"):
            no_weight_decay = model_or_params.no_weight_decay()
        params = model_or_params.trainable_params()

    else:
        params = model_or_params

    if weight_decay_filter == "auto":
        _logger.warning(
            "You are using AUTO weight decay filter, which means the weight decay filter isn't explicitly pass in "
            "when creating an mindspore.nn.Optimizer instance. "
            "NOTE: mindspore.nn.Optimizer will filter Norm parmas from weight decay. "
        )
    elif layer_decay is not None and isinstance(model_or_params, nn.Cell):
        params = param_groups_layer_decay(
            model_or_params,
            lr=lr,
            weight_decay=weight_decay,
            layer_decay=layer_decay,
            no_weight_decay_list=no_weight_decay,
        )
        weight_decay = 0.0
    elif weight_decay_filter == "disable" or "norm_and_bias":
        params = init_group_params(params, weight_decay, weight_decay_filter, no_weight_decay)
        weight_decay = 0.0
    else:
        raise ValueError(
            f"weight decay filter only support ['disable', 'auto', 'norm_and_bias'], but got{weight_decay_filter}."
        )

    opt = opt.lower()
    opt_args = dict(**kwargs)
    # if lr is not None:
    #    opt_args.setdefault('lr', lr)

    # non-adaptive: SGD, momentum, and nesterov
    if opt == "sgd":
        # note: nn.Momentum may perform better if momentum > 0.
        optimizer = nn.SGD(
            params=params,
            learning_rate=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=nesterov,
            loss_scale=loss_scale,
            **opt_args,
        )
    elif opt in ["momentum", "nesterov"]:
        optimizer = nn.Momentum(
            params=params,
            learning_rate=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            use_nesterov=nesterov,
            loss_scale=loss_scale,
        )
    # adaptive
    elif opt == "adam":
        optimizer = nn.Adam(
            params=params,
            learning_rate=lr,
            weight_decay=weight_decay,
            loss_scale=loss_scale,
            use_nesterov=nesterov,
            **opt_args,
        )
    elif opt == "adamw":
        optimizer = AdamW(
            params=params,
            learning_rate=lr,
            weight_decay=weight_decay,
            loss_scale=loss_scale,
            **opt_args,
        )
    elif opt == "lion":
        optimizer = Lion(
            params=params,
            learning_rate=lr,
            weight_decay=weight_decay,
            loss_scale=loss_scale,
            **opt_args,
        )
    elif opt == "nadam":
        optimizer = NAdam(
            params=params,
            learning_rate=lr,
            weight_decay=weight_decay,
            loss_scale=loss_scale,
            schedule_decay=schedule_decay,
            **opt_args,
        )
    elif opt == "adan":
        optimizer = Adan(
            params=params,
            learning_rate=lr,
            weight_decay=weight_decay,
            loss_scale=loss_scale,
            **opt_args,
        )
    elif opt == "rmsprop":
        optimizer = nn.RMSProp(
            params=params,
            learning_rate=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            loss_scale=loss_scale,
            epsilon=eps,
            **opt_args,
        )
    elif opt == "adagrad":
        optimizer = nn.Adagrad(
            params=params,
            learning_rate=lr,
            weight_decay=weight_decay,
            loss_scale=loss_scale,
            **opt_args,
        )
    elif opt == "lamb":
        assert loss_scale == 1.0, "Loss scaler is not supported by Lamb optimizer"
        optimizer = nn.Lamb(
            params=params,
            learning_rate=lr,
            weight_decay=weight_decay,
            **opt_args,
        )
    else:
        raise ValueError(f"Invalid optimizer: {opt}")

    if os.path.exists(checkpoint_path):
        param_dict = load_checkpoint(checkpoint_path)
        load_param_into_net(optimizer, param_dict)

    return optimizer
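
A hedged sketch of the list-of-dicts form of model_or_params described above, following MindSpore's group-parameter convention; the bias/weight split and the hyper-parameter values are illustrative.

from mindspore import nn
from mindcv.optim.optim_factory import create_optimizer

network = nn.Dense(784, 10)  # toy network
decay_params = [p for p in network.trainable_params() if "bias" not in p.name]
no_decay_params = [p for p in network.trainable_params() if "bias" in p.name]

param_groups = [
    {"params": decay_params, "weight_decay": 1e-4},   # weights get weight decay
    {"params": no_decay_params, "weight_decay": 0.0},  # bias parameters do not
    {"order_params": network.trainable_params()},      # keep the original parameter order
]
optimizer = create_optimizer(param_groups, opt="momentum", lr=0.01, momentum=0.9)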

AdamW

mindcv.optim.adamw.AdamW

Bases: Optimizer

Implements gradient clipping by norm for an AdamWeightDecay optimizer.
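
A minimal instantiation sketch, assuming a toy network; the constructor arguments mirror the __init__ shown below, and clip=True enables the global-norm gradient clipping (norm 5.0) applied in construct.

from mindspore import nn
from mindcv.optim.adamw import AdamW

network = nn.Dense(16, 4)  # toy network
optimizer = AdamW(network.trainable_params(), learning_rate=1e-3, weight_decay=0.05, clip=True)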

Source code in mindcv/optim/adamw.py
class AdamW(Optimizer):
    """
    Implements gradient clipping by norm for an AdamWeightDecay optimizer.
    """

    @opt_init_args_register
    def __init__(
        self,
        params,
        learning_rate=1e-3,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8,
        weight_decay=0.0,
        loss_scale=1.0,
        clip=False,
    ):
        super().__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.moments1 = self.parameters.clone(prefix="adam_m", init="zeros")
        self.moments2 = self.parameters.clone(prefix="adam_v", init="zeros")
        self.hyper_map = ops.HyperMap()
        self.beta1_power = Parameter(initializer(1, [1], ms.float32), name="beta1_power")
        self.beta2_power = Parameter(initializer(1, [1], ms.float32), name="beta2_power")

        self.reciprocal_scale = Tensor(1.0 / loss_scale, ms.float32)
        self.clip = clip

    def get_lr(self):
        """
        The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers based
        on :class:`mindspore.nn.Optimizer` can also call this interface before updating the parameters.

        Returns:
            float, the learning rate of current step.
        """
        lr = self.learning_rate
        if self.dynamic_lr:
            if self.is_group_lr:
                lr = ()
                for learning_rate in self.learning_rate:
                    current_dynamic_lr = learning_rate(self.global_step).reshape(())
                    lr += (current_dynamic_lr,)
            else:
                lr = self.learning_rate(self.global_step).reshape(())
        if self._is_dynamic_lr_or_weight_decay():
            self.assignadd(self.global_step, self.global_step_increase_tensor)
        return lr

    @jit
    def construct(self, gradients):
        lr = self.get_lr()
        gradients = scale_grad(gradients, self.reciprocal_scale)
        if self.clip:
            gradients = ops.clip_by_global_norm(gradients, 5.0, None)

        beta1_power = self.beta1_power * self.beta1
        self.beta1_power = beta1_power
        beta2_power = self.beta2_power * self.beta2
        self.beta2_power = beta2_power

        if self.is_group:
            if self.is_group_lr:
                optim_result = self.hyper_map(
                    ops.partial(_adam_opt, beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
                    lr,
                    self.weight_decay,
                    self.parameters,
                    self.moments1,
                    self.moments2,
                    gradients,
                    self.decay_flags,
                    self.optim_filter,
                )
            else:
                optim_result = self.hyper_map(
                    ops.partial(_adam_opt, beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr),
                    self.weight_decay,
                    self.parameters,
                    self.moments1,
                    self.moments2,
                    gradients,
                    self.decay_flags,
                    self.optim_filter,
                )
        else:
            optim_result = self.hyper_map(
                ops.partial(
                    _adam_opt, beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr, self.weight_decay
                ),
                self.parameters,
                self.moments1,
                self.moments2,
                gradients,
                self.decay_flags,
                self.optim_filter,
            )
        if self.use_parallel:
            self.broadcast_params(optim_result)
        return optim_result

mindcv.optim.adamw.AdamW.get_lr()

The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers based on mindspore.nn.Optimizer can also call this interface before updating the parameters.

RETURNS DESCRIPTION

float, the learning rate of current step.

Source code in mindcv/optim/adamw.py
def get_lr(self):
    """
    The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers based
    on :class:`mindspore.nn.Optimizer` can also call this interface before updating the parameters.

    Returns:
        float, the learning rate of current step.
    """
    lr = self.learning_rate
    if self.dynamic_lr:
        if self.is_group_lr:
            lr = ()
            for learning_rate in self.learning_rate:
                current_dynamic_lr = learning_rate(self.global_step).reshape(())
                lr += (current_dynamic_lr,)
        else:
            lr = self.learning_rate(self.global_step).reshape(())
    if self._is_dynamic_lr_or_weight_decay():
        self.assignadd(self.global_step, self.global_step_increase_tensor)
    return lr

Adan

mindcv.optim.adan.Adan

Bases: Optimizer

The Adan (ADAptive Nesterov momentum algorithm) Optimizer from https://arxiv.org/abs/2208.06677

Note: it is an experimental version.
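
A minimal instantiation sketch, assuming a toy network; the constructor arguments mirror the __init__ shown below. The optimizer factory above also accepts opt='adan'.

from mindspore import nn
from mindcv.optim.adan import Adan

network = nn.Dense(16, 4)  # toy network
optimizer = Adan(network.trainable_params(), learning_rate=1e-3, weight_decay=0.02)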

Source code in mindcv/optim/adan.py
class Adan(Optimizer):
    """
    The Adan (ADAptive Nesterov momentum algorithm) Optimizer from https://arxiv.org/abs/2208.06677

    Note: it is an experimental version.
    """

    @opt_init_args_register
    def __init__(
        self,
        params,
        learning_rate=1e-3,
        beta1=0.98,
        beta2=0.92,
        beta3=0.99,
        eps=1e-8,
        use_locking=False,
        weight_decay=0.0,
        loss_scale=1.0,
    ):
        super().__init__(
            learning_rate, params, weight_decay=weight_decay, loss_scale=loss_scale
        )  # The Optimizer's inherited weight decay is not applied here; weight decay is computed in this file.

        _check_param_value(beta1, beta2, eps, self.cls_name)
        assert isinstance(use_locking, bool), f"For {self.cls_name}, use_locking should be bool"

        self.beta1 = Tensor(beta1, mstype.float32)
        self.beta2 = Tensor(beta2, mstype.float32)
        self.beta3 = Tensor(beta3, mstype.float32)

        self.eps = Tensor(eps, mstype.float32)
        self.use_locking = use_locking
        self.moment1 = self._parameters.clone(prefix="moment1", init="zeros")  # m
        self.moment2 = self._parameters.clone(prefix="moment2", init="zeros")  # v
        self.moment3 = self._parameters.clone(prefix="moment3", init="zeros")  # n
        self.prev_gradient = self._parameters.clone(prefix="prev_gradient", init="zeros")

        self.weight_decay = Tensor(weight_decay, mstype.float32)

    def get_lr(self):
        """
        The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers based
        on :class:`mindspore.nn.Optimizer` can also call this interface before updating the parameters.

        Returns:
            float, the learning rate of current step.
        """
        lr = self.learning_rate
        if self.dynamic_lr:
            if self.is_group_lr:
                lr = ()
                for learning_rate in self.learning_rate:
                    current_dynamic_lr = learning_rate(self.global_step).reshape(())
                    lr += (current_dynamic_lr,)
            else:
                lr = self.learning_rate(self.global_step).reshape(())
        if self._is_dynamic_lr_or_weight_decay():
            self.assignadd(self.global_step, self.global_step_increase_tensor)
        return lr

    @jit
    def construct(self, gradients):
        params = self._parameters
        moment1 = self.moment1
        moment2 = self.moment2
        moment3 = self.moment3

        gradients = self.flatten_gradients(gradients)
        gradients = self.gradients_centralization(gradients)
        gradients = self.scale_grad(gradients)
        gradients = self._grad_sparse_indices_deduplicate(gradients)
        lr = self.get_lr()

        # TODO: currently not support dist
        success = self.map_(
            ops.partial(_adan_opt, self.beta1, self.beta2, self.beta3, self.eps, lr, self.weight_decay),
            params,
            moment1,
            moment2,
            moment3,
            gradients,
            self.prev_gradient,
        )

        return success

    @Optimizer.target.setter
    def target(self, value):
        """
        If the input value is set to "CPU", the parameters will be updated on the host using the Fused
        optimizer operation.
        """
        self._set_base_target(value)

mindcv.optim.adan.Adan.get_lr()

The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers based on mindspore.nn.Optimizer can also call this interface before updating the parameters.

RETURNS DESCRIPTION

float, the learning rate of current step.

Source code in mindcv/optim/adan.py
def get_lr(self):
    """
    The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers based
    on :class:`mindspore.nn.Optimizer` can also call this interface before updating the parameters.

    Returns:
        float, the learning rate of current step.
    """
    lr = self.learning_rate
    if self.dynamic_lr:
        if self.is_group_lr:
            lr = ()
            for learning_rate in self.learning_rate:
                current_dynamic_lr = learning_rate(self.global_step).reshape(())
                lr += (current_dynamic_lr,)
        else:
            lr = self.learning_rate(self.global_step).reshape(())
    if self._is_dynamic_lr_or_weight_decay():
        self.assignadd(self.global_step, self.global_step_increase_tensor)
    return lr

mindcv.optim.adan.Adan.target(value)

If the input value is set to "CPU", the parameters will be updated on the host using the Fused optimizer operation.

Source code in mindcv/optim/adan.py
@Optimizer.target.setter
def target(self, value):
    """
    If the input value is set to "CPU", the parameters will be updated on the host using the Fused
    optimizer operation.
    """
    self._set_base_target(value)

Lion

mindcv.optim.lion.Lion

Bases: Optimizer

Implementation of the Lion optimizer from the paper 'https://arxiv.org/abs/2302.06675'. Additionally, this implementation supports gradient clipping.

Notes: lr is usually 3-10x smaller than that of AdamW; weight decay is usually 3-10x larger than that of AdamW.
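
A hedged sketch of the tuning note above: starting from a hypothetical AdamW baseline of lr=1e-3 and weight_decay=0.05, a Lion configuration scales lr down and weight decay up by roughly the same factor; the concrete values are illustrative.

from mindspore import nn
from mindcv.optim.optim_factory import create_optimizer

network = nn.Dense(16, 4)  # toy network
# lr ~5x smaller and weight decay ~5x larger than the assumed AdamW baseline,
# within the 3-10x range suggested above
optimizer = create_optimizer(network, opt="lion", lr=2e-4, weight_decay=0.25)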

Source code in mindcv/optim/lion.py
class Lion(Optimizer):
    """
    Implementation of the Lion optimizer from the paper 'https://arxiv.org/abs/2302.06675'.
    Additionally, this implementation supports gradient clipping.

    Notes:
    lr is usually 3-10x smaller than that of AdamW.
    weight decay is usually 3-10x larger than that of AdamW.
    """

    @opt_init_args_register
    def __init__(
        self,
        params,
        learning_rate=2e-4,
        beta1=0.9,
        beta2=0.99,
        weight_decay=0.0,
        loss_scale=1.0,
        clip=False,
    ):
        super().__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.moments1 = self.parameters.clone(prefix="lion_m", init="zeros")
        self.hyper_map = ops.HyperMap()
        self.beta1_power = Parameter(initializer(1, [1], ms.float32), name="beta1_power")
        self.beta2_power = Parameter(initializer(1, [1], ms.float32), name="beta2_power")

        self.reciprocal_scale = Tensor(1.0 / loss_scale, ms.float32)
        self.clip = clip

    def get_lr(self):
        """
        The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers based
        on :class:`mindspore.nn.Optimizer` can also call this interface before updating the parameters.

        Returns:
            float, the learning rate of current step.
        """
        lr = self.learning_rate
        if self.dynamic_lr:
            if self.is_group_lr:
                lr = ()
                for learning_rate in self.learning_rate:
                    current_dynamic_lr = learning_rate(self.global_step).reshape(())
                    lr += (current_dynamic_lr,)
            else:
                lr = self.learning_rate(self.global_step).reshape(())
        if self._is_dynamic_lr_or_weight_decay():
            self.assignadd(self.global_step, self.global_step_increase_tensor)
        return lr

    @jit
    def construct(self, gradients):
        lr = self.get_lr()
        gradients = scale_grad(gradients, self.reciprocal_scale)
        if self.clip:
            gradients = ops.clip_by_global_norm(gradients, 5.0, None)

        beta1_power = self.beta1_power * self.beta1
        self.beta1_power = beta1_power
        beta2_power = self.beta2_power * self.beta2
        self.beta2_power = beta2_power

        if self.is_group:
            if self.is_group_lr:
                optim_result = self.hyper_map(
                    ops.partial(_lion_opt, beta1_power, beta2_power, self.beta1, self.beta2),
                    lr,
                    self.weight_decay,
                    self.parameters,
                    self.moments1,
                    gradients,
                    self.decay_flags,
                    self.optim_filter,
                )
            else:
                optim_result = self.hyper_map(
                    ops.partial(_lion_opt, beta1_power, beta2_power, self.beta1, self.beta2, lr),
                    self.weight_decay,
                    self.parameters,
                    self.moments1,
                    gradients,
                    self.decay_flags,
                    self.optim_filter,
                )
        else:
            optim_result = self.hyper_map(
                ops.partial(_lion_opt, beta1_power, beta2_power, self.beta1, self.beta2, lr, self.weight_decay),
                self.parameters,
                self.moments1,
                gradients,
                self.decay_flags,
                self.optim_filter,
            )
        if self.use_parallel:
            self.broadcast_params(optim_result)
        return optim_result

mindcv.optim.lion.Lion.get_lr()

The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers based on mindspore.nn.Optimizer can also call this interface before updating the parameters.

RETURNS DESCRIPTION

float, the learning rate of current step.

Source code in mindcv/optim/lion.py
def get_lr(self):
    """
    The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers based
    on :class:`mindspore.nn.Optimizer` can also call this interface before updating the parameters.

    Returns:
        float, the learning rate of current step.
    """
    lr = self.learning_rate
    if self.dynamic_lr:
        if self.is_group_lr:
            lr = ()
            for learning_rate in self.learning_rate:
                current_dynamic_lr = learning_rate(self.global_step).reshape(())
                lr += (current_dynamic_lr,)
        else:
            lr = self.learning_rate(self.global_step).reshape(())
    if self._is_dynamic_lr_or_weight_decay():
        self.assignadd(self.global_step, self.global_step_increase_tensor)
    return lr

NAdam

mindcv.optim.nadam.NAdam

Bases: Optimizer

Implements the NAdam algorithm (a variant of Adam based on Nesterov momentum).
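
For reference, the per-parameter update implemented in construct below can be written as follows; this is a transcription of the code, with beta1 and beta2 the momentum coefficients, psi the schedule_decay, eta the learning rate, eps the epsilon, and t the 1-based step.

\mu_t = \beta_1\Big(1 - \tfrac{1}{2}\,0.96^{\,t\psi}\Big), \qquad
m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2

\hat m_t = \frac{\mu_{t+1}\,m_t}{1-\prod_{i=1}^{t+1}\mu_i} + \frac{(1-\mu_t)\,g_t}{1-\prod_{i=1}^{t}\mu_i}, \qquad
\hat v_t = \frac{v_t}{1-\beta_2^{\,t}}, \qquad
\theta_t = \theta_{t-1} - \eta\,\frac{\hat m_t}{\varepsilon + \sqrt{\hat v_t}}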

Source code in mindcv/optim/nadam.py
class NAdam(Optimizer):
    """
    Implements the NAdam algorithm (a variant of Adam based on Nesterov momentum).
    """

    @opt_init_args_register
    def __init__(
        self,
        params,
        learning_rate=2e-3,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8,
        weight_decay=0.0,
        loss_scale=1.0,
        schedule_decay=4e-3,
    ):
        super().__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.moments1 = self.parameters.clone(prefix="nadam_m", init="zeros")
        self.moments2 = self.parameters.clone(prefix="nadam_v", init="zeros")
        self.schedule_decay = Tensor(np.array([schedule_decay]).astype(np.float32))
        self.mu_schedule = Parameter(initializer(1, [1], ms.float32), name="mu_schedule")
        self.beta2_power = Parameter(initializer(1, [1], ms.float32), name="beta2_power")

    def get_lr(self):
        """
        The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers based
        on :class:`mindspore.nn.Optimizer` can also call this interface before updating the parameters.

        Returns:
            float, the learning rate of current step.
        """
        lr = self.learning_rate
        if self.dynamic_lr:
            if self.is_group_lr:
                lr = ()
                for learning_rate in self.learning_rate:
                    current_dynamic_lr = learning_rate(self.global_step).reshape(())
                    lr += (current_dynamic_lr,)
            else:
                lr = self.learning_rate(self.global_step).reshape(())
        if self._is_dynamic_lr_or_weight_decay():
            self.assignadd(self.global_step, self.global_step_increase_tensor)
        return lr

    @jit
    def construct(self, gradients):
        lr = self.get_lr()
        params = self.parameters
        step = self.global_step + _scaler_one
        gradients = self.decay_weight(gradients)
        mu = self.beta1 * (
            _scaler_one - Tensor(0.5, ms.float32) * ops.pow(Tensor(0.96, ms.float32), step * self.schedule_decay)
        )
        mu_next = self.beta1 * (
            _scaler_one
            - Tensor(0.5, ms.float32) * ops.pow(Tensor(0.96, ms.float32), (step + _scaler_one) * self.schedule_decay)
        )
        mu_schedule = self.mu_schedule * mu
        mu_schedule_next = self.mu_schedule * mu * mu_next
        self.mu_schedule = mu_schedule
        beta2_power = self.beta2_power * self.beta2
        self.beta2_power = beta2_power

        num_params = len(params)
        for i in range(num_params):
            ops.assign(self.moments1[i], self.beta1 * self.moments1[i] + (_scaler_one - self.beta1) * gradients[i])
            ops.assign(
                self.moments2[i], self.beta2 * self.moments2[i] + (_scaler_one - self.beta2) * ops.square(gradients[i])
            )

            regulate_m = mu_next * self.moments1[i] / (_scaler_one - mu_schedule_next) + (_scaler_one - mu) * gradients[
                i
            ] / (_scaler_one - mu_schedule)
            regulate_v = self.moments2[i] / (_scaler_one - beta2_power)

            update = params[i] - lr * regulate_m / (self.eps + ops.sqrt(regulate_v))
            ops.assign(params[i], update)

        return params

mindcv.optim.nadam.NAdam.get_lr()

The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers based on mindspore.nn.Optimizer can also call this interface before updating the parameters.

RETURNS DESCRIPTION

float, the learning rate of current step.

Source code in mindcv/optim/nadam.py
def get_lr(self):
    """
    The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers based
    on :class:`mindspore.nn.Optimizer` can also call this interface before updating the parameters.

    Returns:
        float, the learning rate of current step.
    """
    lr = self.learning_rate
    if self.dynamic_lr:
        if self.is_group_lr:
            lr = ()
            for learning_rate in self.learning_rate:
                current_dynamic_lr = learning_rate(self.global_step).reshape(())
                lr += (current_dynamic_lr,)
        else:
            lr = self.learning_rate(self.global_step).reshape(())
    if self._is_dynamic_lr_or_weight_decay():
        self.assignadd(self.global_step, self.global_step_increase_tensor)
    return lr