跳转至

配置

下载Notebook

MindCV套件可以通过python的argparse库和PyYAML库解析模型的yaml文件来进行参数的配置。 下面我们以squeezenet_1.0模型为例,解释如何配置相应的参数。

基础环境

  1. 参数说明
  • mode:使用静态图模式(0)或动态图模式(1)。

  • distribute:是否使用分布式。

  1. yaml文件样例

    mode: 0
    distribute: True
    ...
    
  2. parse参数设置

    python train.py --mode 0 --distribute False ...
    
  3. 对应的代码示例

    args.mode代表参数mode, args.distribute代表参数distribute

    def train(args):
        ms.set_context(mode=args.mode)
    
        if args.distribute:
            init()
            device_num = get_group_size()
            rank_id = get_rank()
            ms.set_auto_parallel_context(device_num=device_num,
                                         parallel_mode='data_parallel',
                                         gradients_mean=True)
        else:
            device_num = None
            rank_id = None
        ...
    

数据集

  1. 参数说明
  • dataset:数据集名称。

  • data_dir:数据集文件所在路径。

  • shuffle:是否进行数据混洗。

  • dataset_download:是否下载数据集。

  • batch_size:每个批处理数据包含的数据条目。

  • drop_remainder:当最后一个批处理数据包含的数据条目小于 batch_size 时,是否将该批处理丢弃。

  • num_parallel_workers:读取数据的工作线程数。

  1. yaml文件样例

    dataset: 'imagenet'
    data_dir: './imagenet2012'
    shuffle: True
    dataset_download: False
    batch_size: 32
    drop_remainder: True
    num_parallel_workers: 8
    ...
    
  2. parse参数设置

    python train.py ... --dataset imagenet --data_dir ./imagenet2012 --shuffle True \
        --dataset_download False --batch_size 32 --drop_remainder True \
        --num_parallel_workers 8 ...
    
  3. 对应的代码示例

    def train(args):
        ...
        dataset_train = create_dataset(
            name=args.dataset,
            root=args.data_dir,
            split='train',
            shuffle=args.shuffle,
            num_samples=args.num_samples,
            num_shards=device_num,
            shard_id=rank_id,
            num_parallel_workers=args.num_parallel_workers,
            download=args.dataset_download,
            num_aug_repeats=args.aug_repeats)
    
        ...
        target_transform = transforms.OneHot(num_classes) if args.loss == 'BCE' else None
    
        loader_train = create_loader(
            dataset=dataset_train,
            batch_size=args.batch_size,
            drop_remainder=args.drop_remainder,
            is_training=True,
            mixup=args.mixup,
            cutmix=args.cutmix,
            cutmix_prob=args.cutmix_prob,
            num_classes=args.num_classes,
            transform=transform_list,
            target_transform=target_transform,
            num_parallel_workers=args.num_parallel_workers,
        )
        ...
    

数据增强

  1. 参数说明
  • image_resize:图像的输出尺寸大小。

  • scale:要裁剪的原始尺寸大小的各个尺寸的范围。

  • ratio:裁剪宽高比的范围。

  • hfilp:图像被翻转的概率。

  • interpolation:图像插值方式。

  • crop_pct:输入图像中心裁剪百分比。

  • color_jitter:颜色抖动因子(亮度调整因子,对比度调整因子,饱和度调整因子)。

  • re_prob:执行随机擦除的概率。

  1. yaml文件样例

    image_resize: 224
    scale: [0.08, 1.0]
    ratio: [0.75, 1.333]
    hflip: 0.5
    interpolation: 'bilinear'
    crop_pct: 0.875
    color_jitter: [0.4, 0.4, 0.4]
    re_prob: 0.5
    ...
    
  2. parse参数设置

    python train.py ... --image_resize 224 --scale [0.08, 1.0] --ratio [0.75, 1.333] \
        --hflip 0.5 --interpolation "bilinear" --crop_pct 0.875 \
        --color_jitter [0.4, 0.4, 0.4] --re_prob 0.5 ...
    
  3. 对应的代码示例

    def train(args):
        ...
        transform_list = create_transforms(
            dataset_name=args.dataset,
            is_training=True,
            image_resize=args.image_resize,
            scale=args.scale,
            ratio=args.ratio,
            hflip=args.hflip,
            vflip=args.vflip,
            color_jitter=args.color_jitter,
            interpolation=args.interpolation,
            auto_augment=args.auto_augment,
            mean=args.mean,
            std=args.std,
            re_prob=args.re_prob,
            re_scale=args.re_scale,
            re_ratio=args.re_ratio,
            re_value=args.re_value,
            re_max_attempts=args.re_max_attempts
        )
        ...
    

模型

  1. 参数说明
  • model:模型名称。

  • num_classes:分类的类别数。

  • pretrained:是否加载预训练模型。

  • ckpt_path:参数文件所在的路径。

  • keep_checkpoint_max:最多保存多少个checkpoint文件。

  • ckpt_save_dir:保存参数文件的路径。

  • epoch_size:训练执行轮次。

  • dataset_sink_mode:数据是否直接下沉至处理器进行处理。

  • amp_level:混合精度等级。

  1. yaml文件样例

    model: 'squeezenet1_0'
    num_classes: 1000
    pretrained: False
    ckpt_path: './squeezenet1_0_gpu.ckpt'
    keep_checkpoint_max: 10
    ckpt_save_dir: './ckpt/'
    epoch_size: 200
    dataset_sink_mode: True
    amp_level: 'O0'
    ...
    
  2. parse参数设置

    python train.py ... --model squeezenet1_0 --num_classes 1000 --pretrained False \
        --ckpt_path ./squeezenet1_0_gpu.ckpt --keep_checkpoint_max 10 \
        --ckpt_save_path ./ckpt/ --epoch_size 200 --dataset_sink_mode True \
        --amp_level O0 ...
    
  3. 对应的代码示例

    def train(args):
        ...
        network = create_model(model_name=args.model,
            num_classes=args.num_classes,
            in_channels=args.in_channels,
            drop_rate=args.drop_rate,
            drop_path_rate=args.drop_path_rate,
            pretrained=args.pretrained,
            checkpoint_path=args.ckpt_path,
            ema=args.ema
        )
        ...
    

损失函数

  1. 参数说明
  • loss:损失函数的简称。

  • label_smoothing:标签平滑值,用于计算Loss时防止模型过拟合的正则化手段。

  1. yaml文件样例

    loss: 'CE'
    label_smoothing: 0.1
    ...
    
  2. parse参数设置

    python train.py ... --loss CE --label_smoothing 0.1 ...
    
  3. 对应的代码示例

    def train(args):
        ...
        loss = create_loss(name=args.loss,
            reduction=args.reduction,
            label_smoothing=args.label_smoothing,
            aux_factor=args.aux_factor
        )
        ...
    

学习率策略

  1. 参数说明
  • scheduler:学习率策略的名称。

  • min_lr:学习率的最小值。

  • lr:学习率的最大值。

  • warmup_epochs:学习率warmup的轮次。

  • decay_epochs:进行衰减的step数。

  1. yaml文件样例

    scheduler: 'cosine_decay'
    min_lr: 0.0
    lr: 0.01
    warmup_epochs: 0
    decay_epochs: 200
    ...
    
  2. parse参数设置

    python train.py ... --scheduler cosine_decay --min_lr 0.0 --lr 0.01 \
        --warmup_epochs 0 --decay_epochs 200 ...
    
  3. 对应的代码示例

    def train(args):
        ...
        lr_scheduler = create_scheduler(num_batches,
            scheduler=args.scheduler,
            lr=args.lr,
            min_lr=args.min_lr,
            warmup_epochs=args.warmup_epochs,
            warmup_factor=args.warmup_factor,
            decay_epochs=args.decay_epochs,
            decay_rate=args.decay_rate,
            milestones=args.multi_step_decay_milestones,
            num_epochs=args.epoch_size,
            lr_epoch_stair=args.lr_epoch_stair
        )
        ...
    

优化器

  1. 参数说明
  • opt:优化器名称。

  • weight_decay_filter:权重衰减过滤器 (过滤一些参数, 使其在跟新时不做权重衰减)。

  • momentum:移动平均的动量。

  • weight_decay:权重衰减(L2 penalty)。

  • loss_scale:梯度缩放系数

  • use_nesterov:是否使用Nesterov Accelerated Gradient (NAG)算法更新梯度。

  1. yaml文件样例

    opt: 'momentum'
    weight_decay_filter: 'norm_and_bias'
    momentum: 0.9
    weight_decay: 0.00007
    loss_scale: 1024
    use_nesterov: False
    ...
    
  2. parse参数设置

    python train.py ... --opt momentum --weight_decay_filter 'norm_and_bias" --weight_decay 0.00007 \
        --loss_scale 1024 --use_nesterov False ...
    
  3. 对应的代码示例

    def train(args):
        ...
        if args.ema:
            optimizer = create_optimizer(network.trainable_params(),
                opt=args.opt,
                lr=lr_scheduler,
                weight_decay=args.weight_decay,
                momentum=args.momentum,
                nesterov=args.use_nesterov,
                weight_decay_filter=args.weight_decay_filter,
                loss_scale=args.loss_scale,
                checkpoint_path=opt_ckpt_path,
                eps=args.eps
            )
        else:
            optimizer = create_optimizer(network.trainable_params(),
                opt=args.opt,
                lr=lr_scheduler,
                weight_decay=args.weight_decay,
                momentum=args.momentum,
                nesterov=args.use_nesterov,
                weight_decay_filter=args.weight_decay_filter,
                checkpoint_path=opt_ckpt_path,
                eps=args.eps
            )
        ...
    

Yaml和Parse组合使用

使用parse设置参数可以覆盖yaml文件中的参数设置。以下面的shell命令为例,

python train.py -c ./configs/squeezenet/squeezenet_1.0_gpu.yaml --data_dir ./data

上面的命令将args.data_dir参数的值由yaml文件中的 ./imagenet2012 覆盖为 ./data