
Learning Rate Scheduler

Scheduler Factory

mindcv.scheduler.scheduler_factory.create_scheduler(steps_per_epoch, scheduler='constant', lr=0.01, min_lr=1e-06, warmup_epochs=3, warmup_factor=0.0, decay_epochs=10, decay_rate=0.9, milestones=None, num_epochs=200, num_cycles=1, cycle_decay=1.0, lr_epoch_stair=False)

Creates learning rate scheduler by name.

PARAMETER DESCRIPTION
steps_per_epoch

number of steps per epoch.

TYPE: int

scheduler

scheduler name like 'constant', 'cosine_decay', 'step_decay', 'exponential_decay', 'polynomial_decay', 'multi_step_decay'. Default: 'constant'.

TYPE: str DEFAULT: 'constant'

lr

learning rate value. Default: 0.01.

TYPE: float DEFAULT: 0.01

min_lr

lower lr bound for 'cosine_decay' schedulers. Default: 1e-6.

TYPE: float DEFAULT: 1e-06

warmup_epochs

number of epochs for LR warmup, if the scheduler supports it. Default: 3.

TYPE: int DEFAULT: 3

warmup_factor

the warmup phase increases the LR linearly: the starting factor is warmup_factor, i.e., the LR of the first step/epoch is lr * warmup_factor, and the LR at the end of the warmup phase is lr. Default: 0.0.

TYPE: float DEFAULT: 0.0

decay_epochs

for 'cosine_decay' schedulers, decay LR to min_lr in decay_epochs. For 'step_decay' scheduler, decay LR by a factor of decay_rate every decay_epochs. Default: 10.

TYPE: int DEFAULT: 10

decay_rate

LR decay rate. Default: 0.9.

TYPE: float DEFAULT: 0.9

milestones

list of epoch milestones for 'multi_step_decay' scheduler. Must be increasing. Default: None

TYPE: list DEFAULT: None

num_epochs

Number of total epochs. Default: 200.

TYPE: int DEFAULT: 200

num_cycles

Number of cycles for the 'cosine_decay' and 'cyclic' schedulers. Default: 1.

TYPE: int DEFAULT: 1

cycle_decay

Decay rate of lr max in each cosine cycle. Default: 1.0.

TYPE: float DEFAULT: 1.0

lr_epoch_stair

If True, the LR will be updated at the beginning of each new epoch and kept constant for every batch within that epoch. Otherwise, the learning rate is updated dynamically at each step. Default: False.

TYPE: bool DEFAULT: False

Source code in mindcv/scheduler/scheduler_factory.py
def create_scheduler(
    steps_per_epoch: int,
    scheduler: str = "constant",
    lr: float = 0.01,
    min_lr: float = 1e-6,
    warmup_epochs: int = 3,
    warmup_factor: float = 0.0,
    decay_epochs: int = 10,
    decay_rate: float = 0.9,
    milestones: list = None,
    num_epochs: int = 200,
    num_cycles: int = 1,
    cycle_decay: float = 1.0,
    lr_epoch_stair: bool = False,
):
    r"""Creates learning rate scheduler by name.

    Args:
        steps_per_epoch: number of steps per epoch.
        scheduler: scheduler name like 'constant', 'cosine_decay', 'step_decay',
            'exponential_decay', 'polynomial_decay', 'multi_step_decay'. Default: 'constant'.
        lr: learning rate value. Default: 0.01.
        min_lr: lower lr bound for 'cosine_decay' schedulers. Default: 1e-6.
        warmup_epochs: epochs to warmup LR, if scheduler supports. Default: 3.
        warmup_factor: the warmup phase of scheduler is a linearly increasing lr,
            the beginning factor is `warmup_factor`, i.e., the lr of the first step/epoch is lr*warmup_factor,
            and the ending lr in the warmup phase is lr. Default: 0.0
        decay_epochs: for 'cosine_decay' schedulers, decay LR to min_lr in `decay_epochs`.
            For 'step_decay' scheduler, decay LR by a factor of `decay_rate` every `decay_epochs`. Default: 10.
        decay_rate: LR decay rate. Default: 0.9.
        milestones: list of epoch milestones for 'multi_step_decay' scheduler. Must be increasing. Default: None
        num_epochs: Number of total epochs. Default: 200.
        num_cycles: Number of cycles for cosine decay and cyclic. Default: 1.
        cycle_decay: Decay rate of lr max in each cosine cycle. Default: 1.0.
        lr_epoch_stair: If True, LR will be updated in the beginning of each new epoch
            and the LR will be consistent for each batch in one epoch.
            Otherwise, learning rate will be updated dynamically in each step. Default: False.
    Returns:
        Cell object for computing LR with input of current global steps
    """
    # check params
    if milestones is None:
        milestones = []

    if warmup_epochs + decay_epochs > num_epochs:
        _logger.warning("warmup_epochs + decay_epochs > num_epochs. Please check and reduce decay_epochs!")

    # lr warmup phase
    warmup_lr_scheduler = []
    if warmup_epochs > 0:
        if warmup_factor == 0 and lr_epoch_stair:
            _logger.warning(
                "The warmup factor is set to 0, lr of 0-th epoch is always zero! " "Recommend value is 0.01."
            )
        warmup_func = linear_lr if lr_epoch_stair else linear_refined_lr
        warmup_lr_scheduler = warmup_func(
            start_factor=warmup_factor,
            end_factor=1.0,
            total_iters=warmup_epochs,
            lr=lr,
            steps_per_epoch=steps_per_epoch,
            epochs=warmup_epochs,
        )

    # lr decay phase
    main_epochs = num_epochs - warmup_epochs
    if scheduler in ["cosine_decay", "warmup_cosine_decay"]:
        cosine_func = cosine_decay_lr if lr_epoch_stair else cosine_decay_refined_lr
        main_lr_scheduler = cosine_func(
            decay_epochs=decay_epochs,
            eta_min=min_lr,
            eta_max=lr,
            steps_per_epoch=steps_per_epoch,
            epochs=main_epochs,
            num_cycles=num_cycles,
            cycle_decay=cycle_decay,
        )
    elif scheduler == "one_cycle":
        if lr_epoch_stair or warmup_epochs > 0:
            raise ValueError(
                "OneCycle scheduler doesn't support learning rate varies with epoch and warmup_epochs > 0."
            )
        div_factor = 25.0
        initial_lr = lr / div_factor
        final_div_factor = initial_lr / min_lr
        main_lr_scheduler = one_cycle_lr(
            max_lr=lr,
            final_div_factor=final_div_factor,
            steps_per_epoch=steps_per_epoch,
            epochs=main_epochs,
        )
    elif scheduler == "cyclic":
        if lr_epoch_stair or warmup_epochs > 0:
            raise ValueError("Cyclic scheduler doesn't support learning rate varies with epoch and warmup_epochs > 0.")
        num_steps = steps_per_epoch * main_epochs
        step_size_up = int(num_steps / num_cycles / 2)
        main_lr_scheduler = cyclic_lr(
            base_lr=min_lr,
            max_lr=lr,
            step_size_up=step_size_up,
            steps_per_epoch=steps_per_epoch,
            epochs=main_epochs,
        )
    elif scheduler == "exponential_decay":
        exponential_func = exponential_lr if lr_epoch_stair else exponential_refined_lr
        main_lr_scheduler = exponential_func(
            gamma=decay_rate, lr=lr, steps_per_epoch=steps_per_epoch, epochs=main_epochs
        )
    elif scheduler == "polynomial_decay":
        polynomial_func = polynomial_lr if lr_epoch_stair else polynomial_refined_lr
        main_lr_scheduler = polynomial_func(
            total_iters=main_epochs, power=decay_rate, lr=lr, steps_per_epoch=steps_per_epoch, epochs=main_epochs
        )
    elif scheduler == "step_decay":
        main_lr_scheduler = step_lr(
            step_size=decay_epochs, gamma=decay_rate, lr=lr, steps_per_epoch=steps_per_epoch, epochs=main_epochs
        )
    elif scheduler == "multi_step_decay":
        main_lr_scheduler = multi_step_lr(
            milestones=milestones, gamma=decay_rate, lr=lr, steps_per_epoch=steps_per_epoch, epochs=main_epochs
        )
    elif scheduler == "constant":
        main_lr_scheduler = [lr for _ in range(steps_per_epoch * main_epochs)]
    else:
        raise ValueError(f"Invalid scheduler: {scheduler}")

    # combine
    lr_scheduler = warmup_lr_scheduler + main_lr_scheduler

    return lr_scheduler
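
For reference, here is a minimal usage sketch. It assumes a typical MindSpore training setup; the dataset size, hyperparameter values, and the commented-out Momentum optimizer wiring are illustrative assumptions, not part of the factory itself.

from mindcv.scheduler.scheduler_factory import create_scheduler

steps_per_epoch = 100  # assumed, e.g. number_of_samples // batch_size
num_epochs = 100

# Build a per-step LR list: 3 warmup epochs, then cosine decay to min_lr.
lr_list = create_scheduler(
    steps_per_epoch,
    scheduler="cosine_decay",
    lr=0.1,
    min_lr=1e-6,
    warmup_epochs=3,
    decay_epochs=num_epochs - 3,
    num_epochs=num_epochs,
)
print(len(lr_list))  # steps_per_epoch * num_epochs per-step LR values

# The list can then be passed to a MindSpore optimizer as a dynamic learning rate, e.g.:
# optimizer = mindspore.nn.Momentum(net.trainable_params(), learning_rate=lr_list, momentum=0.9)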

mindcv.scheduler.dynamic_lr

Meta learning rate scheduler.

This module implements the same learning rate schedulers as native PyTorch; see torch.optim.lr_scheduler (https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate). At present, only constant_lr, linear_lr, polynomial_lr, exponential_lr, step_lr, multi_step_lr, cosine_annealing_lr, cosine_annealing_warm_restarts_lr, one_cycle_lr, and cyclic_lr are implemented. The number, names, and usage of the positional arguments are exactly the same as those of native PyTorch.

However, because each scheduler here must explicitly return the learning rate at every step, three additional keyword arguments are introduced: lr, steps_per_epoch, and epochs. lr is the base learning rate (the value you would pass when creating the optimizer in torch); steps_per_epoch is the number of steps (iterations) per epoch; epochs is the number of epochs, which together with steps_per_epoch determines the length of the returned learning rate list.

Among these schedulers, one_cycle_lr and cyclic_lr need only two of the keyword arguments (steps_per_epoch and epochs) and do not take lr, since the lr argument given to the optimizer in torch has no effect when these two schedulers are used.

Most schedulers in PyTorch are coarse-grained, i.e., the learning rate is constant within a single epoch. For the non-stepwise schedulers we therefore also provide fine-grained variants in which the learning rate changes within a single epoch as well. The function names of these variants contain the keyword refined, e.g., linear_refined_lr, polynomial_refined_lr, cosine_decay_refined_lr; a small comparison is sketched below.
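
As a rough illustration of the difference, using toy values chosen only for this sketch:

from mindcv.scheduler.dynamic_lr import cosine_decay_lr, cosine_decay_refined_lr

stair = cosine_decay_lr(decay_epochs=2, eta_min=0.0, eta_max=1.0, steps_per_epoch=4, epochs=2)
refined = cosine_decay_refined_lr(decay_epochs=2, eta_min=0.0, eta_max=1.0, steps_per_epoch=4, epochs=2)

# Within one epoch the coarse-grained (stair) version repeats the same value,
# while the refined version keeps decaying at every step.
print(stair[:4])    # four identical values for epoch 0
print(refined[:4])  # four distinct, decreasing values within epoch 0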

mindcv.scheduler.dynamic_lr.cosine_decay_lr(decay_epochs, eta_min, *, eta_max, steps_per_epoch, epochs, num_cycles=1, cycle_decay=1.0)

update every epoch

Source code in mindcv/scheduler/dynamic_lr.py
def cosine_decay_lr(decay_epochs, eta_min, *, eta_max, steps_per_epoch, epochs, num_cycles=1, cycle_decay=1.0):
    """update every epoch"""
    tot_steps = steps_per_epoch * epochs
    lrs = []

    for c in range(num_cycles):
        lr_max = eta_max * (cycle_decay**c)
        delta = 0.5 * (lr_max - eta_min)
        for i in range(steps_per_epoch * decay_epochs):
            t_cur = math.floor(i / steps_per_epoch)
            t_cur = min(t_cur, decay_epochs)
            lr_cur = eta_min + delta * (1.0 + math.cos(math.pi * t_cur / decay_epochs))
            if len(lrs) < tot_steps:
                lrs.append(lr_cur)
            else:
                break

    if epochs > num_cycles * decay_epochs:
        for i in range((epochs - (num_cycles * decay_epochs)) * steps_per_epoch):
            lrs.append(eta_min)

    return lrs
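
A tiny worked example of the num_cycles and cycle_decay arguments, with toy values (one step per epoch so every returned value is visible):

from mindcv.scheduler.dynamic_lr import cosine_decay_lr

# Two cosine cycles over 4 epochs; the second cycle restarts from eta_max * cycle_decay.
lrs = cosine_decay_lr(decay_epochs=2, eta_min=0.0, eta_max=1.0,
                      steps_per_epoch=1, epochs=4, num_cycles=2, cycle_decay=0.5)
print(lrs)  # [1.0, 0.5, 0.5, 0.25]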

mindcv.scheduler.dynamic_lr.cosine_decay_refined_lr(decay_epochs, eta_min, *, eta_max, steps_per_epoch, epochs, num_cycles=1, cycle_decay=1.0)

update every step

Source code in mindcv/scheduler/dynamic_lr.py
def cosine_decay_refined_lr(decay_epochs, eta_min, *, eta_max, steps_per_epoch, epochs, num_cycles=1, cycle_decay=1.0):
    """update every step"""
    tot_steps = steps_per_epoch * epochs
    lrs = []

    for c in range(num_cycles):
        lr_max = eta_max * (cycle_decay**c)
        delta = 0.5 * (lr_max - eta_min)
        for i in range(steps_per_epoch * decay_epochs):
            t_cur = i / steps_per_epoch
            t_cur = min(t_cur, decay_epochs)
            lr_cur = eta_min + delta * (1.0 + math.cos(math.pi * t_cur / decay_epochs))
            if len(lrs) < tot_steps:
                lrs.append(lr_cur)
            else:
                break

    if epochs > num_cycles * decay_epochs:
        for i in range((epochs - (num_cycles * decay_epochs)) * steps_per_epoch):
            lrs.append(eta_min)

    return lrs

mindcv.scheduler.dynamic_lr.cyclic_lr(base_lr, max_lr, step_size_up=2000, step_size_down=None, mode='triangular', gamma=1.0, scale_fn=None, scale_mode='cycle', *, steps_per_epoch, epochs)

Cyclic learning rate scheduler based on "Cyclical Learning Rates for Training Neural Networks" (https://arxiv.org/abs/1506.01186).

PARAMETER DESCRIPTION
base_lr

Lower learning rate boundaries in each cycle.

TYPE: float

max_lr

Upper learning rate boundaries in each cycle.

TYPE: float

step_size_up

Number of steps in the increasing half in each cycle. Default: 2000.

TYPE: int DEFAULT: 2000

step_size_down

Number of steps in the decreasing half of each cycle. If step_size_down is None, it's set to step_size_up. Default: None.

DEFAULT: None

mode

One of {'triangular', 'triangular2', 'exp_range'}. If scale_fn is not None, this argument is ignored. Default: 'triangular'.

TYPE: str DEFAULT: 'triangular'

gamma

Constant used in the 'exp_range' scaling function: gamma**(cycle_iterations). Default: 1.0.

DEFAULT: 1.0

scale_fn

Custom scaling policy defined by a single argument lambda function. If it's not None, 'mode' is ignored. Default: None

DEFAULT: None

scale_mode

One of {'cycle', 'iterations'}. Determines whether scale_fn is evaluated on the cycle number or on cycle iterations. Default: 'cycle'.

DEFAULT: 'cycle'

steps_per_epoch

Number of steps per epoch.

TYPE: int

epochs

Number of total epochs.

TYPE: int

Source code in mindcv/scheduler/dynamic_lr.py
def cyclic_lr(
    base_lr: float,
    max_lr: float,
    step_size_up: int = 2000,
    step_size_down=None,
    mode: str = "triangular",
    gamma=1.0,
    scale_fn=None,
    scale_mode="cycle",
    *,
    steps_per_epoch: int,
    epochs: int,
):
    """
    Cyclic learning rate scheduler based on
    '"Cyclical Learning Rates for Training Neural Networks" <https://arxiv.org/abs/1708.07120>'

    Args:
        base_lr: Lower learning rate boundaries in each cycle.
        max_lr: Upper learning rate boundaries in each cycle.
        step_size_up: Number of steps in the increasing half of each cycle. Default: 2000.
        step_size_down: Number of steps in the decreasing half of each cycle. If step_size_down
            is None, it's set to step_size_up. Default: None.
        mode: One of {'triangular', 'triangular2', 'exp_range'}. If scale_fn is not None,
            this argument is ignored. Default: 'triangular'.
        gamma: Constant used in the 'exp_range' scaling function: gamma**(cycle_iterations).
            Default: 1.0.
        scale_fn: Custom scaling policy defined by a single-argument lambda function. If it's
            not None, 'mode' is ignored. Default: None.
        scale_mode: One of {'cycle', 'iterations'}. Determines whether scale_fn is evaluated on
            the cycle number or on cycle iterations. Default: 'cycle'.
        steps_per_epoch: Number of steps per epoch.
        epochs: Number of total epochs.
    """

    def _triangular_scale_fn(x):
        return 1.0

    def _triangular2_scale_fn(x):
        return 1 / (2.0**(x - 1))

    def _exp_range_scale_fn(x):
        return gamma**x

    steps = steps_per_epoch * epochs
    step_size_up = float(step_size_up)
    step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
    total_size = step_size_up + step_size_down
    step_ratio = step_size_up / total_size
    if scale_fn is None:
        if mode == "triangular":
            scale_fn = _triangular_scale_fn
            scale_mode = "cycle"
        elif mode == "triangular2":
            scale_fn = _triangular2_scale_fn
            scale_mode = "cycle"
        elif mode == "exp_range":
            scale_fn = _exp_range_scale_fn
            scale_mode = "iterations"
    lrs = []
    for i in range(steps):
        cycle = math.floor(1 + i / total_size)
        x = 1.0 + i / total_size - cycle
        if x <= step_ratio:
            scale_factor = x / step_ratio
        else:
            scale_factor = (x - 1) / (step_ratio - 1)
        base_height = (max_lr - base_lr) * scale_factor
        if scale_mode == "cycle":
            lrs.append(base_lr + base_height * scale_fn(cycle))
        else:
            lrs.append(base_lr + base_height * scale_fn(i))
    return lrs
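
A minimal sketch of the default 'triangular' policy; the numbers are chosen only to keep the printed list short.

from mindcv.scheduler.dynamic_lr import cyclic_lr

lrs = cyclic_lr(base_lr=0.01, max_lr=0.1, step_size_up=4,
                steps_per_epoch=4, epochs=2)
# The LR climbs linearly from base_lr to max_lr over step_size_up steps,
# then descends back toward base_lr, and the triangle repeats.
print([round(v, 4) for v in lrs])  # roughly [0.01, 0.0325, 0.055, 0.0775, 0.1, 0.0775, 0.055, 0.0325]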

mindcv.scheduler.dynamic_lr.one_cycle_lr(max_lr, pct_start=0.3, anneal_strategy='cos', div_factor=25.0, final_div_factor=10000.0, three_phase=False, *, steps_per_epoch, epochs)

OneCycle learning rate scheduler based on "Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates" (https://arxiv.org/abs/1708.07120).

PARAMETER DESCRIPTION
max_lr

Upper learning rate boundaries in the cycle.

TYPE: float

pct_start

The fraction of the total number of steps spent increasing the learning rate in the cycle. Default: 0.3.

TYPE: float DEFAULT: 0.3

anneal_strategy

Define the annealing strategy: "cos" for cosine annealing, "linear" for linear annealing. Default: "cos".

TYPE: str DEFAULT: 'cos'

div_factor

Initial learning rate via initial_lr = max_lr / div_factor. Default: 25.0.

TYPE: float DEFAULT: 25.0

final_div_factor

Minimum learning rate at the end via min_lr = initial_lr / final_div_factor. Default: 10000.0.

TYPE: float DEFAULT: 10000.0

three_phase

If True, a three-phase schedule is used: the LR rises to max_lr, falls back to the initial LR, then anneals down to min_lr = initial_lr / final_div_factor. Otherwise a two-phase schedule is used: the LR rises to max_lr and then anneals directly to min_lr. Default: False.

TYPE: bool DEFAULT: False

steps_per_epoch

Number of steps per epoch.

TYPE: int

epochs

Number of total epochs.

TYPE: int

Source code in mindcv/scheduler/dynamic_lr.py
def one_cycle_lr(
    max_lr: float,
    pct_start: float = 0.3,
    anneal_strategy: str = "cos",
    div_factor: float = 25.0,
    final_div_factor: float = 10000.0,
    three_phase: bool = False,
    *,
    steps_per_epoch: int,
    epochs: int,
):
    """
    OneCycle learning rate scheduler based on
    '"Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates"
    <https://arxiv.org/abs/1708.07120>'

    Args:
        max_lr: Upper learning rate boundaries in the cycle.
        pct_start: The percentage of the number of steps of increasing learning rate
            in the cycle. Default: 0.3.
        anneal_strategy: Define the annealing strategy: "cos" for cosine annealing,
            "linear" for linear annealing. Default: "cos".
        div_factor: Initial learning rate via initial_lr = max_lr / div_factor.
            Default: 25.0.
        final_div_factor: Minimum learning rate at the end via
            min_lr = initial_lr / final_div_factor. Default: 10000.0.
        three_phase: If True, learning rate will be updated by three-phase according to
            "final_div_factor". Otherwise, learning rate will be updated by two-phase.
            Default: False.
        steps_per_epoch: Number of steps per epoch.
        epochs: Number of total epochs.
    """

    def _annealing_cos(start, end, pct):
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    def _annealing_linear(start, end, pct):
        return (end - start) * pct + start

    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    steps = steps_per_epoch * epochs
    step_size_up = float(pct_start * steps) - 1
    step_size_down = float(2 * pct_start * steps) - 2
    step_size_end = float(steps) - 1
    if anneal_strategy == "cos":
        anneal_func = _annealing_cos
    elif anneal_strategy == "linear":
        anneal_func = _annealing_linear
    else:
        raise ValueError(f"anneal_strategy must be one of 'cos' or 'linear', but got {anneal_strategy}")
    lrs = []
    for i in range(steps):
        if three_phase:
            if i <= step_size_up:
                lrs.append(anneal_func(initial_lr, max_lr, i / step_size_up))
            elif step_size_up < i <= step_size_down:
                lrs.append(anneal_func(max_lr, initial_lr, (i - step_size_up) / (step_size_down - step_size_up)))
            else:
                lrs.append(anneal_func(initial_lr, min_lr, (i - step_size_down) / (step_size_end - step_size_down)))
        else:
            if i <= step_size_up:
                lrs.append(anneal_func(initial_lr, max_lr, i / step_size_up))
            else:
                lrs.append(anneal_func(max_lr, min_lr, (i - step_size_up) / (step_size_end - step_size_up)))
    return lrs
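
A short sketch of the default two-phase policy, with illustrative values only:

from mindcv.scheduler.dynamic_lr import one_cycle_lr

lrs = one_cycle_lr(max_lr=0.1, steps_per_epoch=10, epochs=3)
# Phase 1 (~30% of steps): cosine-anneal from max_lr / div_factor (0.004) up to max_lr.
# Phase 2: cosine-anneal from max_lr down to initial_lr / final_div_factor (4e-07).
print(lrs[0], max(lrs), lrs[-1])  # approximately 0.004, 0.1, 4e-07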