Efficient Multi-Scale Training训练代码解析

Efficient Multi-Scale Training训练代码解析

  • def parser() 使用argparse模块实现命令行解析[1]
  1. 导入argparse模块
  2. 创建解析器对象ArgumentParser，之后即可向其添加参数
  3. 调用add_argument()，指定程序需要接受的命令行参数

add_argument()定义位置参数（必选）：
parser.add_argument("echo", help="echo the string")
add_argument()定义可选参数：
parser.add_argument("--verbosity", help="increase output verbosity")

  4. 调用arg_parser.parse_args()解析命令行参数，返回包含各参数值的Namespace对象
def parser(argv=None):
    """Parse the command-line arguments of the SNIPER training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``sys.argv[1:]`` is parsed, as before.

    Returns:
        argparse.Namespace with the attributes ``cfg``, ``display``,
        ``momentum``, ``save_prefix`` and ``set_cfg_list``.
    """
    # The first positional argument of ArgumentParser is ``prog`` (the program
    # name printed in the usage line), so the title belongs in ``description``.
    arg_parser = argparse.ArgumentParser(description='SNIPER training module')
    arg_parser.add_argument('--cfg', dest='cfg', help='Path to the config file',
                            default='configs/faster/pvalite_b5.yml', type=str)
    # This value is used as the logging frequency of mx.callback.Speedometer,
    # which is counted in batches, not epochs.
    arg_parser.add_argument('--display', dest='display',
                            help='Number of batches between displaying loss info',
                            default=100, type=int)
    arg_parser.add_argument('--momentum', dest='momentum', help='BN momentum',
                            default=0.995, type=float)
    arg_parser.add_argument('--save_prefix', dest='save_prefix',
                            help='Prefix used for snapshotting the network',
                            default='SNIPER', type=str)
    # argparse.REMAINDER collects every token after --set into a list, so any
    # option written after --set is swallowed rather than parsed.
    arg_parser.add_argument('--set', dest='set_cfg_list',
                            help='Set the configuration fields from command line',
                            default=None, nargs=argparse.REMAINDER)

    return arg_parser.parse_args(argv)
def main():
    """Entry point of the SNIPER training script.

    Parses the command line and loads the YAML config named by ``--cfg``;
    later statements of this function build the data iterator, network,
    metrics and callbacks and then call ``mod.fit``.
    """
    args = parser()
    # Presumably merges the --cfg YAML file into the module-level ``config``
    # object read by the following statements -- confirm in update_config.
    update_config(args.cfg)
  • mx.gpu 为配置中列出的每块GPU创建设备上下文
context = [mx.gpu(int(gpu)) for gpu in config.gpus.split(',')]

config.gpus：即--cfg参数所指定的yml配置文件中的gpus字段（以逗号分隔的GPU编号）
Efficient Multi-Scale Training训练代码解析_第1张图片

    # Global batch size = number of GPUs x config.TRAIN.BATCH_IMAGES
    # (presumably images per GPU -- confirm in the config). Both values are
    # reused by the data iterator, the optimizer params and the Speedometer.
    nGPUs = len(context) # number of GPUs in use
    batch_size = nGPUs * config.TRAIN.BATCH_IMAGES # total batch size over all GPUs
  • Create Roidb 创建数据集
    # Build one roidb per image set ('+'-separated in the config), then merge
    # them into a single training roidb.
    image_sets = config.dataset.image_set.split('+')
    roidbs = [
        load_proposal_roidb(
            config.dataset.dataset, image_set,
            config.dataset.root_path, config.dataset.dataset_path,
            proposal=config.dataset.proposal, append_gt=True,
            flip=config.TRAIN.FLIP, result_path=config.output_path,
            proposal_path=config.proposal_path,
            load_mask=config.TRAIN.WITH_MASK,
            only_gt=not config.TRAIN.USE_NEG_CHIPS)
        for image_set in image_sets
    ]

    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb, config)
    # Adds bbox-regression targets to the roidb (per the helper's name); the
    # returned means/stds are later handed to the checkpoint callback.
    bbox_means, bbox_stds = add_bbox_regression_targets(roidb, config)

    # End-to-end training iterator. pad_rois_to=400 presumably has to agree
    # with n_proposals=400 used when building the symbol -- confirm.
    train_iter = MNIteratorE2E(roidb=roidb, config=config, batch_size=batch_size,
                               nGPUs=nGPUs, threads=config.TRAIN.NUM_THREAD,
                               pad_rois_to=400)

  • Create the Logger 创建日志
    # Create the training logger; output_path is the directory under which
    # snapshots are written (it is the base of the checkpoint prefix).
    logger, output_path = create_logger(config.output_path, args.cfg, config.dataset.image_set)
  • 构建网络符号，并获取训练时保持固定（不更新）的参数列表
    # Instantiate the network class named config.symbol, looked up as
    # <symbol>.<symbol> (presumably a module exposing a class of the same name).
    # NOTE(review): eval() executes the config-supplied string; if config files
    # can be untrusted, getattr on the imported module would be safer.
    sym_inst = eval('{}.{}'.format(config.symbol, config.symbol))(n_proposals=400, momentum=args.momentum)
    sym = sym_inst.get_symbol_rcnn(config)

    # Parameters matching config.network.FIXED_PARAMS; they are passed to the
    # Module as fixed_param_names and are therefore not updated by training.
    fixed_param_names = get_fixed_param_names(config.network.FIXED_PARAMS, sym)
  • Create the module 创建模型
    # Print the names of the per-device data inputs. print() with a single
    # argument behaves the same on Python 2 and 3, unlike the print statement.
    for k in train_iter.provide_data_single:
        print(k[0])
    mod = mx.mod.Module(symbol=sym,
                        context=context,
                        data_names=[k[0] for k in train_iter.provide_data_single],
                        label_names=[k[0] for k in train_iter.provide_label_single],
                        fixed_param_names=fixed_param_names)

    # Infer shapes from the single-device (name, shape) input/label pairs,
    # then load the pretrained weights and initialise the network parameters
    # (per the helper names).
    shape_dict = dict(train_iter.provide_data_single + train_iter.provide_label_single)
    sym_inst.infer_shape(shape_dict)
    arg_params, aux_params = load_param(config.network.pretrained, config.network.pretrained_epoch, convert=True)
    sym_inst.init_weight_rcnn(config, arg_params, aux_params)
  • Create the metrics 创建指标
    # RPN-stage and R-CNN-stage metrics (accuracy, log-loss and L1 bbox loss,
    # going by the class names), combined into a single composite metric.
    eval_metric = metric.RPNAccMetric()
    cls_metric = metric.RPNLogLossMetric()
    bbox_metric = metric.RPNL1LossMetric()
    rceval_metric = metric.RCNNAccMetric(config)
    rccls_metric = metric.RCNNLogLossMetric(config)
    rcbbox_metric = metric.RCNNL1LossCRCNNMetric(config)
    eval_metrics = mx.metric.CompositeEvalMetric()
    # The composite keeps its children in insertion order, which is the
    # order in which the metrics are reported.
    for child_metric in (eval_metric, cls_metric, bbox_metric,
                         rceval_metric, rccls_metric, rcbbox_metric):
        eval_metrics.add(child_metric)

    # len(train_iter) is taken on the raw MNIteratorE2E, before the iterator
    # is wrapped in PrefetchingIter.
    optimizer_params = get_optim_params(config, len(train_iter), batch_size)
  • Checkpoint 设置回调（训练日志与检查点保存）并开始训练
    # Snapshot prefix inside the experiment output directory.
    prefix = os.path.join(output_path, args.save_prefix)
    # Log training speed and metrics every args.display batches.
    batch_end_callback = mx.callback.Speedometer(batch_size, args.display)
    # After every epoch: save the module (including optimizer states), then run
    # the symbol-specific checkpoint callback, which receives the bbox
    # regression parameter names and the bbox_means / bbox_stds of the roidb.
    # NOTE(review): eval() executes the config-supplied symbol name; getattr on
    # the imported module would avoid running config text as code.
    epoch_end_callback = [mx.callback.module_checkpoint(mod, prefix, period=1, save_optimizer_states=True),
                          eval('{}.checkpoint_callback'.format(config.symbol))(sym_inst.get_bbox_param_names(), prefix, bbox_means, bbox_stds)]

    # Wrap the iterator so batches are prefetched in a background thread.
    train_iter = PrefetchingIter(train_iter)
    # Train with SGD, starting from the loaded/initialised arg and aux params,
    # until config.TRAIN.end_epoch.
    mod.fit(train_iter, optimizer='sgd', optimizer_params=optimizer_params,
            eval_metric=eval_metrics, num_epoch=config.TRAIN.end_epoch, kvstore=config.default.kvstore,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback, arg_params=arg_params, aux_params=aux_params)

参考博文
[1]Python命令行解析argparse常用语法使用简介

你可能感兴趣的:(机器学习,计算机视觉,无人驾驶,计算机视觉)