Utility

Logger

mindcv.utils.logger.set_logger(name=None, output_dir=None, rank=0, log_level=logging.INFO, color=True)

Initialize the logger.

If the logger has not been initialized, this method initializes it by adding one or two handlers; otherwise the already-initialized logger is returned directly. During initialization, only the logger of the master process (rank 0) is given a console handler. If output_dir is specified, every rank's logger is also given a file handler.

PARAMETER DESCRIPTION
name

Logger name. Defaults to None, which sets up the root logger.

TYPE: Optional[str] DEFAULT: None

output_dir

The directory to save the log file to. If None, no file handler is added.

TYPE: Optional[str] DEFAULT: None

rank

Process rank in the distributed training. Defaults to 0.

TYPE: int DEFAULT: 0

log_level

Verbosity level of the logger. Defaults to logging.INFO.

TYPE: int DEFAULT: INFO

color

If True, color the output. Defaults to True.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
Logger

logging.Logger: An initialized logger.

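A minimal usage sketch (a single-process run, so rank 0 gets the console handler; color=False avoids the optional rich/termcolor dependency):

import logging

from mindcv.utils.logger import set_logger

# Root logger: rank 0 gets a console handler; because output_dir is given,
# a file handler writing to ./logs/rank0.log is added as well.
logger = set_logger(name=None, output_dir="./logs", rank=0, log_level=logging.INFO, color=False)
logger.info("logger is ready")

# A second call with the same name returns the cached logger unchanged.
assert set_logger(name=None) is logger
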
Source code in mindcv/utils/logger.py
def set_logger(
    name: Optional[str] = None,
    output_dir: Optional[str] = None,
    rank: int = 0,
    log_level: int = logging.INFO,
    color: bool = True,
) -> logging.Logger:
    """Initialize the logger.

    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, only logger of the master
    process is added console handler. If ``output_dir`` is specified, all loggers
    will be added file handler.

    Args:
        name: Logger name. Defaults to None to set up root logger.
        output_dir: The directory to save log.
        rank: Process rank in the distributed training. Defaults to 0.
        log_level: Verbosity level of the logger. Defaults to ``logging.INFO``.
        color: If True, color the output. Defaults to True.

    Returns:
        logging.Logger: A initialized logger.
    """
    rank = 0 if rank is None else rank
    if name in logger_initialized:
        return logger_initialized[name]

    # get root logger if name is None
    logger = logging.getLogger(name)
    logger.setLevel(log_level)
    # the messages of this logger will not be propagated to its parent
    logger.propagate = False

    fmt = "%(asctime)s %(name)s %(levelname)s - %(message)s"
    datefmt = "[%Y-%m-%d %H:%M:%S]"

    # create console handler for master process
    if rank == 0:
        if color:
            if has_rich:
                console_handler = RichHandler(level=log_level, log_time_format=datefmt)
            elif has_termcolor:
                console_handler = logging.StreamHandler(stream=sys.stdout)
                console_handler.setLevel(log_level)
                console_handler.setFormatter(_ColorfulFormatter(fmt=fmt, datefmt=datefmt))
            else:
                raise NotImplementedError("If you want color, 'rich' or 'termcolor' has to be installed!")
        else:
            console_handler = logging.StreamHandler(stream=sys.stdout)
            console_handler.setLevel(log_level)
            console_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt=datefmt))
        logger.addHandler(console_handler)

    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        file_handler = logging.FileHandler(os.path.join(output_dir, f"rank{rank}.log"))
        file_handler.setLevel(log_level)
        file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt=datefmt))
        logger.addHandler(file_handler)

    logger_initialized[name] = logger
    return logger

Callbacks

mindcv.utils.callbacks.StateMonitor

Bases: Callback

Monitors training loss and validation accuracy, and after each epoch saves the checkpoint with the highest validation accuracy as the best checkpoint.

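A hedged usage sketch: StateMonitor is meant to be passed to mindspore.Model.train as a callback, which is how MindCV's training script uses it. The toy dataset, network, loss, and optimizer below are stand-ins for illustration only; dataset_sink_mode given to StateMonitor should match the value passed to model.train.

import numpy as np
from mindspore import Model, nn
from mindspore.dataset import NumpySlicesDataset

from mindcv.utils.callbacks import StateMonitor

# Toy data and model, purely to illustrate the wiring.
x = np.random.randn(64, 16).astype(np.float32)
y = np.random.randint(0, 4, size=(64,)).astype(np.int32)
dataset = NumpySlicesDataset({"image": x, "label": y}, shuffle=True).batch(8)

network = nn.Dense(16, 4)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
model = Model(network, loss_fn=loss, optimizer=optimizer, metrics={"accuracy"})

monitor = StateMonitor(
    model,
    model_name="toy",
    dataset_sink_mode=False,       # keep consistent with model.train below
    dataset_val=dataset,           # the train set is reused here only for illustration
    metric_name=("accuracy",),
    ckpt_save_dir="./ckpt",        # result.log and checkpoint files are written here
    summary_dir="./summary",
    log_interval=1,
)
model.train(2, dataset, callbacks=[monitor], dataset_sink_mode=False)
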
Source code in mindcv/utils/callbacks.py
class StateMonitor(Callback):
    """
    Train loss and validation accuracy monitor, after each epoch save the
    best checkpoint file with the highest validation accuracy.
    """

    def __init__(
        self,
        model,
        model_name="",
        model_ema=False,
        last_epoch=0,
        dataset_sink_mode=True,
        dataset_val=None,
        metric_name=("accuracy",),
        val_interval=1,
        val_start_epoch=1,
        save_best_ckpt=True,
        ckpt_save_dir="./",
        ckpt_save_interval=1,
        ckpt_save_policy=None,
        ckpt_keep_max=10,
        summary_dir="./",
        log_interval=100,
        rank_id=None,
        device_num=None,
    ):
        super().__init__()
        # model
        self.model = model
        self.model_name = model_name
        self.model_ema = model_ema
        self.last_epoch = last_epoch
        self.dataset_sink_mode = dataset_sink_mode
        # evaluation
        self.dataset_val = dataset_val
        self.metric_name = metric_name
        self.val_interval = val_interval
        self.val_start_epoch = val_start_epoch
        # logging
        self.best_res = 0
        self.best_epoch = -1
        self.save_best_ckpt = save_best_ckpt
        self.ckpt_save_dir = ckpt_save_dir
        self.ckpt_save_interval = ckpt_save_interval
        self.ckpt_save_policy = ckpt_save_policy
        self.ckpt_keep_max = ckpt_keep_max
        self.ckpt_manager = CheckpointManager(ckpt_save_policy=self.ckpt_save_policy)
        self._need_flush_from_cache = True
        self.summary_dir = summary_dir
        self.log_interval = log_interval
        # system
        self.rank_id = rank_id if rank_id is not None else 0
        self.device_num = device_num if rank_id is not None else 1
        if self.rank_id in [0, None]:
            os.makedirs(ckpt_save_dir, exist_ok=True)
            self.log_file = os.path.join(ckpt_save_dir, "result.log")
            log_line = "".join(
                f"{s:<20}" for s in ["Epoch", "TrainLoss", *metric_name, "TrainTime", "EvalTime", "TotalTime"]
            )
            with open(self.log_file, "w", encoding="utf-8") as fp:  # writing the title of result.log
                fp.write(log_line + "\n")
        if self.device_num > 1:
            self.all_reduce = AllReduceSum()
        # timestamp
        self.step_ts = None
        self.epoch_ts = None
        self.step_time_accum = 0
        # model_ema
        if self.model_ema:
            self.hyper_map = ops.HyperMap()
            self.online_params = ParameterTuple(self.model.train_network.get_parameters())
            self.swap_params = self.online_params.clone("swap", "zeros")

    def __enter__(self):
        self.summary_record = SummaryRecord(self.summary_dir)
        return self

    def __exit__(self, *exc_args):
        self.summary_record.close()

    def apply_eval(self, run_context):
        """Model evaluation, return validation accuracy."""
        if self.model_ema:
            cb_params = run_context.original_args()
            self.hyper_map(ops.assign, self.swap_params, self.online_params)
            ema_dict = dict()
            net = self._get_network_from_cbp(cb_params)
            for param in net.get_parameters():
                if param.name.startswith("ema"):
                    new_name = param.name.split("ema.")[1]
                    ema_dict[new_name] = param.data
            load_param_into_net(self.model.train_network.network, ema_dict)
            res_dict = self.model.eval(self.dataset_val, dataset_sink_mode=False)
            self.hyper_map(ops.assign, self.online_params, self.swap_params)
        else:
            res_dict = self.model.eval(self.dataset_val, dataset_sink_mode=False)
        res_array = ms.Tensor(list(res_dict.values()), ms.float32)
        if self.device_num > 1:
            res_array = self.all_reduce(res_array)
            res_array /= self.device_num
        res_array = res_array.asnumpy()
        return res_array

    def on_train_step_begin(self, run_context):
        self.step_ts = time()

    def on_train_epoch_begin(self, run_context):
        self.epoch_ts = time()

    def on_train_step_end(self, run_context):
        cb_params = run_context.original_args()
        num_epochs = cb_params.epoch_num
        num_batches = cb_params.batch_num
        # num_steps = num_batches * num_epochs
        # cur_x start from 1, end at num_xs, range: [1, num_xs]
        cur_step = cb_params.cur_step_num + self.last_epoch * num_batches
        cur_epoch = cb_params.cur_epoch_num + self.last_epoch
        cur_batch = (cur_step - 1) % num_batches + 1

        self.step_time_accum += time() - self.step_ts
        if cur_batch % self.log_interval == 0 or cur_batch == num_batches or cur_batch == 1:
            lr = self._get_lr_from_cbp(cb_params)
            loss = self._get_loss_from_cbp(cb_params)
            _logger.info(
                f"Epoch: [{cur_epoch}/{num_epochs}], "
                f"batch: [{cur_batch}/{num_batches}], "
                f"loss: {loss.asnumpy():.6f}, "
                f"lr: {lr.asnumpy():.6f}, "
                f"time: {self.step_time_accum:.6f}s"
            )
            self.step_time_accum = 0

    def on_train_epoch_end(self, run_context):
        """
        After epoch, print train loss and val accuracy,
        save the best ckpt file with the highest validation accuracy.
        """
        cb_params = run_context.original_args()
        num_epochs = cb_params.epoch_num
        num_batches = cb_params.batch_num
        cur_step = cb_params.cur_step_num + self.last_epoch * num_batches
        cur_epoch = cb_params.cur_epoch_num + self.last_epoch
        cur_batch = (cur_step - 1) % num_batches + 1

        train_time = time() - self.epoch_ts
        loss = self._get_loss_from_cbp(cb_params)

        val_time = 0
        res = np.zeros(len(self.metric_name), dtype=np.float32)
        # val while training if validation loader is not None
        if (
            self.dataset_val is not None
            and cur_epoch >= self.val_start_epoch
            and (cur_epoch - self.val_start_epoch) % self.val_interval == 0
        ):
            val_time = time()
            res = self.apply_eval(run_context)
            val_time = time() - val_time
            # record val acc
            metric_str = "Validation "
            for i in range(len(self.metric_name)):
                metric_str += f"{self.metric_name[i]}: {res[i]:.4%}, "
            metric_str += f"time: {val_time:.6f}s"
            _logger.info(metric_str)
            # save the best ckpt file
            if res[0] > self.best_res:
                self.best_res = res[0]
                self.best_epoch = cur_epoch
                _logger.info(f"=> New best val acc: {res[0]:.4%}")

        # save checkpoint
        if self.rank_id in [0, None]:
            if self.save_best_ckpt and self.best_epoch == cur_epoch:  # always save ckpt if cur epoch got best acc
                best_ckpt_save_path = os.path.join(self.ckpt_save_dir, f"{self.model_name}_best.ckpt")
                save_checkpoint(cb_params.train_network, best_ckpt_save_path, async_save=True)
            if (cur_epoch % self.ckpt_save_interval == 0) or (cur_epoch == num_epochs):
                if self._need_flush_from_cache:
                    self._flush_from_cache(cb_params)
                # save optim for resume
                optimizer = self._get_optimizer_from_cbp(cb_params)
                optim_save_path = os.path.join(self.ckpt_save_dir, f"optim_{self.model_name}.ckpt")
                save_checkpoint(optimizer, optim_save_path, async_save=True)
                # keep checkpoint files number equal max number.
                ckpt_save_path = os.path.join(self.ckpt_save_dir, f"{self.model_name}-{cur_epoch}_{cur_batch}.ckpt")
                _logger.info(f"Saving model to {ckpt_save_path}")
                self.ckpt_manager.save_ckpoint(
                    cb_params.train_network,
                    num_ckpt=self.ckpt_keep_max,
                    metric=res[0],
                    save_path=ckpt_save_path,
                )

        # logging
        total_time = time() - self.epoch_ts
        _logger.info(
            f"Total time since last epoch: {total_time:.6f}(train: {train_time:.6f}, val: {val_time:.6f})s, "
            f"ETA: {(num_epochs - cur_epoch) * total_time:.6f}s"
        )
        _logger.info("-" * 80)
        if self.rank_id in [0, None]:
            log_line = "".join(
                f"{s:<20}"
                for s in [
                    f"{cur_epoch}",
                    f"{loss.asnumpy():.6f}",
                    *[f"{i:.4%}" for i in res],
                    f"{train_time:.2f}",
                    f"{val_time:.2f}",
                    f"{total_time:.2f}",
                ]
            )
            with open(self.log_file, "a", encoding="utf-8") as fp:
                fp.write(log_line + "\n")

        # summary
        self.summary_record.add_value("scalar", f"train_loss_{self.rank_id}", loss)
        for i in range(len(res)):
            self.summary_record.add_value(
                "scalar", f"val_{self.metric_name[i]}_{self.rank_id}", Tensor(res[i], dtype=ms.float32)
            )
        self.summary_record.record(cur_step)

    def on_train_end(self, run_context):
        _logger.info("Finish training!")
        if self.dataset_val is not None:
            _logger.info(
                f"The best validation {self.metric_name[0]} is: {self.best_res:.4%} at epoch {self.best_epoch}."
            )
        _logger.info("=" * 80)

    def _get_network_from_cbp(self, cb_params):
        if self.dataset_sink_mode:
            network = cb_params.train_network.network
        else:
            network = cb_params.train_network
        return network

    def _get_optimizer_from_cbp(self, cb_params):
        if cb_params.optimizer is not None:
            optimizer = cb_params.optimizer
        elif self.dataset_sink_mode:
            optimizer = cb_params.train_network.network.optimizer
        else:
            optimizer = cb_params.train_network.optimizer
        return optimizer

    def _get_lr_from_cbp(self, cb_params):
        optimizer = self._get_optimizer_from_cbp(cb_params)
        if optimizer.global_step < 1:
            _logger.warning(
                "`global_step` of optimizer is less than 1. It seems to be an overflow at the first step. "
                "If you keep seeing this message, it means that the optimizer was never actually called."
            )
            optim_step = Tensor((0,), ms.int32)
        else:  # if the optimizer is successfully called, the global_step will actually be the value of next step.
            optim_step = optimizer.global_step - 1
        if optimizer.dynamic_lr:
            if isinstance(optimizer.learning_rate, ms.nn.CellList):
                # return the learning rates of the first parameter if dynamic_lr
                lr = optimizer.learning_rate[0](optim_step)[0]
            else:
                lr = optimizer.learning_rate(optim_step)[0]
        else:
            lr = optimizer.learning_rate
        return lr

    def _get_loss_from_cbp(self, cb_params):
        """
        Get loss from the network output.
        Args:
            cb_params (_InternalCallbackParam): Callback parameters.
        Returns:
            Union[Tensor, None], if parse loss success, will return a Tensor value(shape is [1]), else return None.
        """
        output = cb_params.net_outputs
        if output is None:
            _logger.warning("Can not find any output by this network, so SummaryCollector will not collect loss.")
            return None

        if isinstance(output, (int, float, Tensor)):
            loss = output
        elif isinstance(output, (list, tuple)) and output:
            # If the output is a list, since the default network returns loss first,
            # we assume that the first one is loss.
            loss = output[0]
        else:
            _logger.warning(
                "The output type could not be identified, expect type is one of "
                "[int, float, Tensor, list, tuple], so no loss was recorded in SummaryCollector."
            )
            return None

        if not isinstance(loss, Tensor):
            loss = Tensor(loss)

        loss = Tensor(np.mean(loss.asnumpy()))
        return loss

    def _flush_from_cache(self, cb_params):
        """Flush cache data to host if tensor is cache enable."""
        has_cache_params = False
        params = cb_params.train_network.get_parameters()
        for param in params:
            if param.cache_enable:
                has_cache_params = True
                Tensor(param).flush_from_cache()
        if not has_cache_params:
            self._need_flush_from_cache = False

mindcv.utils.callbacks.StateMonitor.apply_eval(run_context)

Run model evaluation and return the validation metric values (averaged across devices in distributed mode).

Source code in mindcv/utils/callbacks.py
def apply_eval(self, run_context):
    """Model evaluation, return validation accuracy."""
    if self.model_ema:
        cb_params = run_context.original_args()
        self.hyper_map(ops.assign, self.swap_params, self.online_params)
        ema_dict = dict()
        net = self._get_network_from_cbp(cb_params)
        for param in net.get_parameters():
            if param.name.startswith("ema"):
                new_name = param.name.split("ema.")[1]
                ema_dict[new_name] = param.data
        load_param_into_net(self.model.train_network.network, ema_dict)
        res_dict = self.model.eval(self.dataset_val, dataset_sink_mode=False)
        self.hyper_map(ops.assign, self.online_params, self.swap_params)
    else:
        res_dict = self.model.eval(self.dataset_val, dataset_sink_mode=False)
    res_array = ms.Tensor(list(res_dict.values()), ms.float32)
    if self.device_num > 1:
        res_array = self.all_reduce(res_array)
        res_array /= self.device_num
    res_array = res_array.asnumpy()
    return res_array

mindcv.utils.callbacks.StateMonitor.on_train_epoch_end(run_context)

After each epoch, print the training loss and validation accuracy, and save the checkpoint with the highest validation accuracy as the best checkpoint.

Source code in mindcv/utils/callbacks.py
def on_train_epoch_end(self, run_context):
    """
    After epoch, print train loss and val accuracy,
    save the best ckpt file with the highest validation accuracy.
    """
    cb_params = run_context.original_args()
    num_epochs = cb_params.epoch_num
    num_batches = cb_params.batch_num
    cur_step = cb_params.cur_step_num + self.last_epoch * num_batches
    cur_epoch = cb_params.cur_epoch_num + self.last_epoch
    cur_batch = (cur_step - 1) % num_batches + 1

    train_time = time() - self.epoch_ts
    loss = self._get_loss_from_cbp(cb_params)

    val_time = 0
    res = np.zeros(len(self.metric_name), dtype=np.float32)
    # val while training if validation loader is not None
    if (
        self.dataset_val is not None
        and cur_epoch >= self.val_start_epoch
        and (cur_epoch - self.val_start_epoch) % self.val_interval == 0
    ):
        val_time = time()
        res = self.apply_eval(run_context)
        val_time = time() - val_time
        # record val acc
        metric_str = "Validation "
        for i in range(len(self.metric_name)):
            metric_str += f"{self.metric_name[i]}: {res[i]:.4%}, "
        metric_str += f"time: {val_time:.6f}s"
        _logger.info(metric_str)
        # save the best ckpt file
        if res[0] > self.best_res:
            self.best_res = res[0]
            self.best_epoch = cur_epoch
            _logger.info(f"=> New best val acc: {res[0]:.4%}")

    # save checkpoint
    if self.rank_id in [0, None]:
        if self.save_best_ckpt and self.best_epoch == cur_epoch:  # always save ckpt if cur epoch got best acc
            best_ckpt_save_path = os.path.join(self.ckpt_save_dir, f"{self.model_name}_best.ckpt")
            save_checkpoint(cb_params.train_network, best_ckpt_save_path, async_save=True)
        if (cur_epoch % self.ckpt_save_interval == 0) or (cur_epoch == num_epochs):
            if self._need_flush_from_cache:
                self._flush_from_cache(cb_params)
            # save optim for resume
            optimizer = self._get_optimizer_from_cbp(cb_params)
            optim_save_path = os.path.join(self.ckpt_save_dir, f"optim_{self.model_name}.ckpt")
            save_checkpoint(optimizer, optim_save_path, async_save=True)
            # keep checkpoint files number equal max number.
            ckpt_save_path = os.path.join(self.ckpt_save_dir, f"{self.model_name}-{cur_epoch}_{cur_batch}.ckpt")
            _logger.info(f"Saving model to {ckpt_save_path}")
            self.ckpt_manager.save_ckpoint(
                cb_params.train_network,
                num_ckpt=self.ckpt_keep_max,
                metric=res[0],
                save_path=ckpt_save_path,
            )

    # logging
    total_time = time() - self.epoch_ts
    _logger.info(
        f"Total time since last epoch: {total_time:.6f}(train: {train_time:.6f}, val: {val_time:.6f})s, "
        f"ETA: {(num_epochs - cur_epoch) * total_time:.6f}s"
    )
    _logger.info("-" * 80)
    if self.rank_id in [0, None]:
        log_line = "".join(
            f"{s:<20}"
            for s in [
                f"{cur_epoch}",
                f"{loss.asnumpy():.6f}",
                *[f"{i:.4%}" for i in res],
                f"{train_time:.2f}",
                f"{val_time:.2f}",
                f"{total_time:.2f}",
            ]
        )
        with open(self.log_file, "a", encoding="utf-8") as fp:
            fp.write(log_line + "\n")

    # summary
    self.summary_record.add_value("scalar", f"train_loss_{self.rank_id}", loss)
    for i in range(len(res)):
        self.summary_record.add_value(
            "scalar", f"val_{self.metric_name[i]}_{self.rank_id}", Tensor(res[i], dtype=ms.float32)
        )
    self.summary_record.record(cur_step)

mindcv.utils.callbacks.ValCallback

Bases: Callback

Source code in mindcv/utils/callbacks.py
class ValCallback(Callback):
    def __init__(self, log_interval=100):
        super().__init__()
        self.log_interval = log_interval
        self.ts = time()

    def on_eval_step_end(self, run_context):
        cb_params = run_context.original_args()
        num_batches = cb_params.batch_num
        cur_step = cb_params.cur_step_num

        if cur_step % self.log_interval == 0 or cur_step == num_batches:
            print(f"batch: {cur_step}/{num_batches}, time: {time() - self.ts:.6f}s")
            self.ts = time()

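ValCallback simply prints evaluation progress (the batch index and elapsed time) every log_interval steps and at the last batch. A short sketch of attaching it to an evaluation run, assuming a model and dataset like the ones built in the StateMonitor sketch above:

from mindcv.utils.callbacks import ValCallback

# Print progress every 20 batches during evaluation.
results = model.eval(dataset, callbacks=[ValCallback(log_interval=20)], dataset_sink_mode=False)
print(results)
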
Train Step

mindcv.utils.train_step.TrainStep

Bases: TrainOneStepWithLossScaleCell

Training step with loss scale.

The customized TrainOneStepCell also supports the following algorithms (see the usage sketch after this list):
  • Exponential Moving Average (EMA)
  • Gradient Clipping
  • Gradient Accumulation
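
A minimal sketch of wiring TrainStep up by hand; normally create_trainer (documented below) does this for you. The toy network-with-loss is a stand-in, and a GPU or Ascend device is assumed since the underlying TrainOneStepWithLossScaleCell performs overflow detection:

import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

from mindcv.utils.train_step import TrainStep

# Toy network wrapped with a loss so that the cell returns a scalar loss.
net = nn.Dense(16, 4)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_with_loss = nn.WithLossCell(net, loss_fn)
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# Fixed loss scale passed as a Tensor, with EMA and gradient clipping enabled.
train_step = TrainStep(
    net_with_loss,
    optimizer,
    scale_sense=Tensor(1024.0, ms.float32),
    ema=True,
    ema_decay=0.9999,
    clip_grad=True,
    clip_value=5.0,
).set_train()

data = Tensor(np.random.randn(8, 16).astype(np.float32))
label = Tensor(np.random.randint(0, 4, size=(8,)).astype(np.int32))
loss = train_step(data, label)   # one forward/backward/update step
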
Source code in mindcv/utils/train_step.py
class TrainStep(nn.TrainOneStepWithLossScaleCell):
    """Training step with loss scale.

    The customized trainOneStepCell also supported following algorithms:
        * Exponential Moving Average (EMA)
        * Gradient Clipping
        * Gradient Accumulation
    """

    def __init__(
        self,
        network,
        optimizer,
        scale_sense=1.0,
        ema=False,
        ema_decay=0.9999,
        clip_grad=False,
        clip_value=15.0,
        gradient_accumulation_steps=1,
    ):
        super(TrainStep, self).__init__(network, optimizer, scale_sense)
        self.ema = ema
        self.ema_decay = ema_decay
        self.updates = Parameter(Tensor(0.0, ms.float32))
        self.clip_grad = clip_grad
        self.clip_value = clip_value
        if self.ema:
            self.weights_all = ms.ParameterTuple(list(network.get_parameters()))
            self.ema_weight = self.weights_all.clone("ema", init="same")

        self.accumulate_grad = gradient_accumulation_steps > 1
        if self.accumulate_grad:
            self.gradient_accumulation = GradientAccumulation(gradient_accumulation_steps, optimizer, self.grad_reducer)

    def ema_update(self):
        self.updates += 1
        # ema factor is corrected by (1 - exp(-t/T)), where `t` means time and `T` means temperature.
        ema_decay = self.ema_decay * (1 - F.exp(-self.updates / 2000))
        # update trainable parameters
        success = self.hyper_map(F.partial(_ema_op, ema_decay), self.ema_weight, self.weights_all)
        return success

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        scaling_sens = self.scale_sense

        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)

        scaling_sens_filled = ops.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
        grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)

        # todo: When to clip grad? Do we need to clip grad after grad reduction? What if grad accumulation is needed?
        if self.clip_grad:
            grads = ops.clip_by_global_norm(grads, clip_norm=self.clip_value)

        if self.loss_scaling_manager:  # scale_sense = update_cell: Cell --> TrainOneStepWithLossScaleCell.construct
            if self.accumulate_grad:
                # todo: GradientAccumulation only call grad_reducer at the step where the accumulation is completed.
                #  So checking the overflow status is after gradient reduction, is this correct?
                # get the overflow buffer
                cond = self.get_overflow_status(status, grads)
                overflow = self.process_loss_scale(cond)
                # if there is no overflow, do optimize
                if not overflow:
                    loss = self.gradient_accumulation(loss, grads)
                    if self.ema:
                        loss = F.depend(loss, self.ema_update())
            else:
                # apply grad reducer on grads
                grads = self.grad_reducer(grads)
                # get the overflow buffer
                cond = self.get_overflow_status(status, grads)
                overflow = self.process_loss_scale(cond)
                # if there is no overflow, do optimize
                if not overflow:
                    loss = F.depend(loss, self.optimizer(grads))
                    if self.ema:
                        loss = F.depend(loss, self.ema_update())
        else:  # scale_sense = loss_scale: Tensor --> TrainOneStepCell.construct
            if self.accumulate_grad:
                loss = self.gradient_accumulation(loss, grads)
            else:
                grads = self.grad_reducer(grads)
                loss = F.depend(loss, self.optimizer(grads))

            if self.ema:
                loss = F.depend(loss, self.ema_update())

        return loss

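For intuition, the corrected EMA factor in ema_update above warms up from roughly 0 toward the nominal ema_decay as training proceeds; a small numeric sketch of the correction, using the temperature T=2000 hard-coded in the source:

import math

ema_decay = 0.9999
for t in (1, 100, 2000, 10000):
    effective = ema_decay * (1 - math.exp(-t / 2000))
    print(f"step {t:>5}: effective decay ~ {effective:.6f}")
# Early in training the EMA weights track the online weights almost exactly
# (effective decay ~ 0.0005 at step 1) and approach the nominal 0.9999 value
# much later (~ 0.9932 at step 10000).
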
Trainer Factory

mindcv.utils.trainer_factory.create_trainer(network, loss, optimizer, metrics, amp_level, amp_cast_list, loss_scale_type, loss_scale=1.0, drop_overflow_update=False, ema=False, ema_decay=0.9999, clip_grad=False, clip_value=15.0, gradient_accumulation_steps=1)

Create Trainer.

PARAMETER DESCRIPTION
network

The backbone network to train, evaluate or predict.

TYPE: Cell

loss

The function of calculating loss.

TYPE: Cell

optimizer

The optimizer for training.

TYPE: Cell

metrics

The metrics for model evaluation.

TYPE: Union[dict, set]

amp_level

The level of automatic mixed precision training.

TYPE: str

amp_cast_list

Cell types to custom-cast to FP16, applied at the cell level.

TYPE: str

loss_scale_type

The type of loss scale: 'fixed', 'dynamic', or 'auto'.

TYPE: str

loss_scale

The value of loss scale.

TYPE: float DEFAULT: 1.0

drop_overflow_update

Whether to drop the parameter update when a gradient overflow is detected.

TYPE: bool DEFAULT: False

ema

Whether to use exponential moving average of model weights.

TYPE: bool DEFAULT: False

ema_decay

Decay factor for model weights moving average.

TYPE: float DEFAULT: 0.9999

clip_grad

Whether to clip gradients by global norm.

TYPE: bool DEFAULT: False

clip_value

The value at which to clip gradients.

TYPE: float DEFAULT: 15.0

gradient_accumulation_steps

The number of batches over which gradients are accumulated before an optimizer update.

TYPE: int DEFAULT: 1

RETURNS DESCRIPTION

mindspore.Model

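A hedged end-to-end sketch; the toy network, loss, and optimizer are stand-ins (in MindCV's training script they come from the corresponding factory functions). With ema, clip_grad, and gradient accumulation all off and no amp_cast_list, this takes the standard mindspore.Model path rather than the customized TrainStep:

from mindspore import nn

from mindcv.utils.trainer_factory import create_trainer

network = nn.Dense(16, 4)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)

model = create_trainer(
    network,
    loss,
    optimizer,
    metrics={"accuracy"},
    amp_level="O0",
    amp_cast_list=None,
    loss_scale_type="fixed",
    loss_scale=1024.0,
    drop_overflow_update=False,
)
# The returned object is a mindspore.Model; call model.train(...) / model.eval(...)
# as usual, e.g. together with the StateMonitor callback documented above.
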
Source code in mindcv/utils/trainer_factory.py
def create_trainer(
    network: nn.Cell,
    loss: nn.Cell,
    optimizer: nn.Cell,
    metrics: Union[dict, set],
    amp_level: str,
    amp_cast_list: str,
    loss_scale_type: str,
    loss_scale: float = 1.0,
    drop_overflow_update: bool = False,
    ema: bool = False,
    ema_decay: float = 0.9999,
    clip_grad: bool = False,
    clip_value: float = 15.0,
    gradient_accumulation_steps: int = 1,
):
    """Create Trainer.

    Args:
        network: The backbone network to train, evaluate or predict.
        loss: The function of calculating loss.
        optimizer: The optimizer for training.
        metrics: The metrics for model evaluation.
        amp_level: The level of auto mixing precision training.
        amp_cast_list: At the cell level, custom casting the cell to FP16.
        loss_scale_type: The type of loss scale.
        loss_scale: The value of loss scale.
        drop_overflow_update: Whether to execute optimizer if there is an overflow.
        ema: Whether to use exponential moving average of model weights.
        ema_decay: Decay factor for model weights moving average.
        clip_grad: whether to gradient clip.
        clip_value: The value at which to clip gradients.
        gradient_accumulation_steps: Accumulate the gradients of n batches before update.

    Returns:
        mindspore.Model

    """
    if loss_scale < 1.0:
        raise ValueError("Loss scale cannot be less than 1.0!")

    if drop_overflow_update is False and loss_scale_type.lower() == "dynamic":
        raise ValueError("DynamicLossScale ALWAYS drop overflow!")

    if gradient_accumulation_steps < 1:
        raise ValueError("`gradient_accumulation_steps` must be >= 1!")

    if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps, amp_cast_list):
        mindspore_kwargs = dict(
            network=network,
            loss_fn=loss,
            optimizer=optimizer,
            metrics=metrics,
            amp_level=amp_level,
        )
        if loss_scale_type.lower() == "fixed":
            mindspore_kwargs["loss_scale_manager"] = FixedLossScaleManager(
                loss_scale=loss_scale, drop_overflow_update=drop_overflow_update
            )
        elif loss_scale_type.lower() == "dynamic":
            mindspore_kwargs["loss_scale_manager"] = DynamicLossScaleManager(
                init_loss_scale=loss_scale, scale_factor=2, scale_window=2000
            )
        elif loss_scale_type.lower() == "auto":
            # We don't explicitly construct LossScaleManager
            _logger.warning(
                "You are using AUTO loss scale, which means the LossScaleManager isn't explicitly pass in "
                "when creating a mindspore.Model instance. "
                "NOTE: mindspore.Model may use LossScaleManager silently. See mindspore.train.amp for details."
            )
        else:
            raise ValueError(f"Loss scale type only supports ['fixed', 'dynamic', 'auto'], but got {loss_scale_type}.")
        model = Model(**mindspore_kwargs)
    else:  # require customized train step
        eval_network = nn.WithEvalCell(network, loss, amp_level in ["O2", "O3", "auto"])
        auto_mixed_precision(network, amp_level, amp_cast_list)
        net_with_loss = add_loss_network(network, loss, amp_level)
        train_step_kwargs = dict(
            network=net_with_loss,
            optimizer=optimizer,
            ema=ema,
            ema_decay=ema_decay,
            clip_grad=clip_grad,
            clip_value=clip_value,
            gradient_accumulation_steps=gradient_accumulation_steps,
        )
        if loss_scale_type.lower() == "fixed":
            loss_scale_manager = FixedLossScaleManager(loss_scale=loss_scale, drop_overflow_update=drop_overflow_update)
        elif loss_scale_type.lower() == "dynamic":
            loss_scale_manager = DynamicLossScaleManager(init_loss_scale=loss_scale, scale_factor=2, scale_window=2000)
        else:
            raise ValueError(f"Loss scale type only supports ['fixed', 'dynamic'], but got {loss_scale_type}.")
        update_cell = loss_scale_manager.get_update_cell()
        # 1. loss_scale_type="fixed", drop_overflow_update=False
        # --> update_cell=None, TrainStep=TrainOneStepCell(scale_sense=loss_scale)
        # 2. loss_scale_type: fixed, drop_overflow_update: True
        # --> update_cell=FixedLossScaleUpdateCell, TrainStep=TrainOneStepWithLossScaleCell(scale_sense=update_cell)
        # 3. loss_scale_type: dynamic, drop_overflow_update: True
        # --> update_cell=DynamicLossScaleUpdateCell, TrainStep=TrainOneStepWithLossScaleCell(scale_sense=update_cell)
        if update_cell is None:
            train_step_kwargs["scale_sense"] = Tensor(loss_scale, dtype=ms.float32)
        else:
            if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU":
                raise ValueError(
                    "Only `loss_scale_type` is `fixed` and `drop_overflow_update` is `False`"
                    "are supported on device `CPU`."
                )
            train_step_kwargs["scale_sense"] = update_cell
        train_step_cell = TrainStep(**train_step_kwargs).set_train()
        model = Model(train_step_cell, eval_network=eval_network, metrics=metrics, eval_indexes=[0, 1, 2])
        # todo: do we need to set model._loss_scale_manager
    return model
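
The three (loss_scale_type, drop_overflow_update) combinations spelled out in the comments above can be checked directly against MindSpore's loss-scale managers; a small sketch (the values are arbitrary, and top-level exports are assumed for a recent MindSpore version):

from mindspore import DynamicLossScaleManager, FixedLossScaleManager

# 1. fixed + no drop on overflow: no update cell, so TrainStep receives the scale as a plain Tensor.
assert FixedLossScaleManager(loss_scale=1024, drop_overflow_update=False).get_update_cell() is None

# 2. fixed + drop on overflow: a FixedLossScaleUpdateCell becomes TrainStep's scale_sense.
print(type(FixedLossScaleManager(loss_scale=1024, drop_overflow_update=True).get_update_cell()))

# 3. dynamic: a DynamicLossScaleUpdateCell adjusts the scale during training.
print(type(DynamicLossScaleManager(init_loss_scale=1024, scale_factor=2, scale_window=2000).get_update_cell()))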