配置¶
MindCV套件可以通过python的argparse库和PyYAML库解析模型的yaml文件来进行参数的配置。 下面我们以squeezenet_1.0模型为例,解释如何配置相应的参数。
基础环境¶
- 参数说明
-
mode:使用静态图模式(0)或动态图模式(1)。
-
distribute:是否使用分布式。
-
yaml文件样例
mode: 0 distribute: True ...
-
parse参数设置
python train.py --mode 0 --distribute False ...
-
对应的代码示例
args.mode
代表参数mode
,args.distribute
代表参数distribute
。def train(args): ms.set_context(mode=args.mode) if args.distribute: init() device_num = get_group_size() rank_id = get_rank() ms.set_auto_parallel_context(device_num=device_num, parallel_mode='data_parallel', gradients_mean=True) else: device_num = None rank_id = None ...
数据集¶
- 参数说明
-
dataset:数据集名称。
-
data_dir:数据集文件所在路径。
-
shuffle:是否进行数据混洗。
-
dataset_download:是否下载数据集。
-
batch_size:每个批处理数据包含的数据条目。
-
drop_remainder:当最后一个批处理数据包含的数据条目小于 batch_size 时,是否将该批处理丢弃。
-
num_parallel_workers:读取数据的工作线程数。
-
yaml文件样例
dataset: 'imagenet' data_dir: './imagenet2012' shuffle: True dataset_download: False batch_size: 32 drop_remainder: True num_parallel_workers: 8 ...
-
parse参数设置
python train.py ... --dataset imagenet --data_dir ./imagenet2012 --shuffle True \ --dataset_download False --batch_size 32 --drop_remainder True \ --num_parallel_workers 8 ...
-
对应的代码示例
def train(args): ... dataset_train = create_dataset( name=args.dataset, root=args.data_dir, split='train', shuffle=args.shuffle, num_samples=args.num_samples, num_shards=device_num, shard_id=rank_id, num_parallel_workers=args.num_parallel_workers, download=args.dataset_download, num_aug_repeats=args.aug_repeats) ... target_transform = transforms.OneHot(num_classes) if args.loss == 'BCE' else None loader_train = create_loader( dataset=dataset_train, batch_size=args.batch_size, drop_remainder=args.drop_remainder, is_training=True, mixup=args.mixup, cutmix=args.cutmix, cutmix_prob=args.cutmix_prob, num_classes=args.num_classes, transform=transform_list, target_transform=target_transform, num_parallel_workers=args.num_parallel_workers, ) ...
数据增强¶
- 参数说明
-
image_resize:图像的输出尺寸大小。
-
scale:要裁剪的原始尺寸大小的各个尺寸的范围。
-
ratio:裁剪宽高比的范围。
-
hfilp:图像被翻转的概率。
-
interpolation:图像插值方式。
-
crop_pct:输入图像中心裁剪百分比。
-
color_jitter:颜色抖动因子(亮度调整因子,对比度调整因子,饱和度调整因子)。
-
re_prob:执行随机擦除的概率。
-
yaml文件样例
image_resize: 224 scale: [0.08, 1.0] ratio: [0.75, 1.333] hflip: 0.5 interpolation: 'bilinear' crop_pct: 0.875 color_jitter: [0.4, 0.4, 0.4] re_prob: 0.5 ...
-
parse参数设置
python train.py ... --image_resize 224 --scale [0.08, 1.0] --ratio [0.75, 1.333] \ --hflip 0.5 --interpolation "bilinear" --crop_pct 0.875 \ --color_jitter [0.4, 0.4, 0.4] --re_prob 0.5 ...
-
对应的代码示例
def train(args): ... transform_list = create_transforms( dataset_name=args.dataset, is_training=True, image_resize=args.image_resize, scale=args.scale, ratio=args.ratio, hflip=args.hflip, vflip=args.vflip, color_jitter=args.color_jitter, interpolation=args.interpolation, auto_augment=args.auto_augment, mean=args.mean, std=args.std, re_prob=args.re_prob, re_scale=args.re_scale, re_ratio=args.re_ratio, re_value=args.re_value, re_max_attempts=args.re_max_attempts ) ...
模型¶
- 参数说明
-
model:模型名称。
-
num_classes:分类的类别数。
-
pretrained:是否加载预训练模型。
-
ckpt_path:参数文件所在的路径。
-
keep_checkpoint_max:最多保存多少个checkpoint文件。
-
ckpt_save_dir:保存参数文件的路径。
-
epoch_size:训练执行轮次。
-
dataset_sink_mode:数据是否直接下沉至处理器进行处理。
-
amp_level:混合精度等级。
-
yaml文件样例
model: 'squeezenet1_0' num_classes: 1000 pretrained: False ckpt_path: './squeezenet1_0_gpu.ckpt' keep_checkpoint_max: 10 ckpt_save_dir: './ckpt/' epoch_size: 200 dataset_sink_mode: True amp_level: 'O0' ...
-
parse参数设置
python train.py ... --model squeezenet1_0 --num_classes 1000 --pretrained False \ --ckpt_path ./squeezenet1_0_gpu.ckpt --keep_checkpoint_max 10 \ --ckpt_save_path ./ckpt/ --epoch_size 200 --dataset_sink_mode True \ --amp_level O0 ...
-
对应的代码示例
def train(args): ... network = create_model(model_name=args.model, num_classes=args.num_classes, in_channels=args.in_channels, drop_rate=args.drop_rate, drop_path_rate=args.drop_path_rate, pretrained=args.pretrained, checkpoint_path=args.ckpt_path, ema=args.ema ) ...
损失函数¶
- 参数说明
-
loss:损失函数的简称。
-
label_smoothing:标签平滑值,用于计算Loss时防止模型过拟合的正则化手段。
-
yaml文件样例
loss: 'CE' label_smoothing: 0.1 ...
-
parse参数设置
python train.py ... --loss CE --label_smoothing 0.1 ...
-
对应的代码示例
def train(args): ... loss = create_loss(name=args.loss, reduction=args.reduction, label_smoothing=args.label_smoothing, aux_factor=args.aux_factor ) ...
学习率策略¶
- 参数说明
-
scheduler:学习率策略的名称。
-
min_lr:学习率的最小值。
-
lr:学习率的最大值。
-
warmup_epochs:学习率warmup的轮次。
-
decay_epochs:进行衰减的step数。
-
yaml文件样例
scheduler: 'cosine_decay' min_lr: 0.0 lr: 0.01 warmup_epochs: 0 decay_epochs: 200 ...
-
parse参数设置
python train.py ... --scheduler cosine_decay --min_lr 0.0 --lr 0.01 \ --warmup_epochs 0 --decay_epochs 200 ...
-
对应的代码示例
def train(args): ... lr_scheduler = create_scheduler(num_batches, scheduler=args.scheduler, lr=args.lr, min_lr=args.min_lr, warmup_epochs=args.warmup_epochs, warmup_factor=args.warmup_factor, decay_epochs=args.decay_epochs, decay_rate=args.decay_rate, milestones=args.multi_step_decay_milestones, num_epochs=args.epoch_size, lr_epoch_stair=args.lr_epoch_stair ) ...
优化器¶
- 参数说明
-
opt:优化器名称。
-
weight_decay_filter:权重衰减过滤器 (过滤一些参数, 使其在跟新时不做权重衰减)。
-
momentum:移动平均的动量。
-
weight_decay:权重衰减(L2 penalty)。
-
loss_scale:梯度缩放系数
-
use_nesterov:是否使用Nesterov Accelerated Gradient (NAG)算法更新梯度。
-
yaml文件样例
opt: 'momentum' weight_decay_filter: 'norm_and_bias' momentum: 0.9 weight_decay: 0.00007 loss_scale: 1024 use_nesterov: False ...
-
parse参数设置
python train.py ... --opt momentum --weight_decay_filter 'norm_and_bias" --weight_decay 0.00007 \ --loss_scale 1024 --use_nesterov False ...
-
对应的代码示例
def train(args): ... if args.ema: optimizer = create_optimizer(network.trainable_params(), opt=args.opt, lr=lr_scheduler, weight_decay=args.weight_decay, momentum=args.momentum, nesterov=args.use_nesterov, weight_decay_filter=args.weight_decay_filter, loss_scale=args.loss_scale, checkpoint_path=opt_ckpt_path, eps=args.eps ) else: optimizer = create_optimizer(network.trainable_params(), opt=args.opt, lr=lr_scheduler, weight_decay=args.weight_decay, momentum=args.momentum, nesterov=args.use_nesterov, weight_decay_filter=args.weight_decay_filter, checkpoint_path=opt_ckpt_path, eps=args.eps ) ...
Yaml和Parse组合使用¶
使用parse设置参数可以覆盖yaml文件中的参数设置。以下面的shell命令为例,
python train.py -c ./configs/squeezenet/squeezenet_1.0_gpu.yaml --data_dir ./data
上面的命令将args.data_dir
参数的值由yaml文件中的 ./imagenet2012
覆盖为 ./data
。