add_argument() 位置参数(positional argument,必选):
# Illustrative snippet: `parser` stands for an argparse.ArgumentParser instance.
# "echo" has no leading dash, so argparse registers it as a required positional argument.
parser.add_argument("echo", help="echo the string")
add_argument() 可选参数(optional argument,以 -- 开头,可省略):
# Illustrative snippet: the leading "--" (two ASCII hyphens, not an en dash) is what
# makes argparse treat this as an optional flag instead of a positional argument.
parser.add_argument("--verbosity", help="increase output verbosity")
def parser():
    """Parse the command-line options of the SNIPER training script.

    Returns:
        argparse.Namespace with attributes ``cfg`` (str), ``display`` (int),
        ``momentum`` (float), ``save_prefix`` (str) and ``set_cfg_list``
        (list of str, or None when ``--set`` is not given).
    """
    # The text is a description; passed positionally it would become ``prog``
    # and replace the script name in the usage line.
    arg_parser = argparse.ArgumentParser(description='SNIPER training module')
    arg_parser.add_argument('--cfg', dest='cfg', help='Path to the config file',
                            default='configs/faster/pvalite_b5.yml', type=str)
    # This value is used as mx.callback.Speedometer's logging frequency,
    # which counts batches, not epochs.
    arg_parser.add_argument('--display', dest='display',
                            help='Number of batches between displaying loss info',
                            default=100, type=int)
    arg_parser.add_argument('--momentum', dest='momentum', help='BN momentum',
                            default=0.995, type=float)
    arg_parser.add_argument('--save_prefix', dest='save_prefix',
                            help='Prefix used for snapshotting the network',
                            default='SNIPER', type=str)
    # REMAINDER collects every remaining token verbatim, so --set must be the
    # last option on the command line.
    arg_parser.add_argument('--set', dest='set_cfg_list',
                            help='Set the configuration fields from command line',
                            default=None, nargs=argparse.REMAINDER)
    return arg_parser.parse_args()
def _load_train_roidb(cfg):
    """Load the proposal roidb of every '+'-separated image set, then merge and filter them."""
    roidbs = [load_proposal_roidb(cfg.dataset.dataset, image_set, cfg.dataset.root_path,
                                  cfg.dataset.dataset_path,
                                  proposal=cfg.dataset.proposal, append_gt=True,
                                  flip=cfg.TRAIN.FLIP, result_path=cfg.output_path,
                                  proposal_path=cfg.proposal_path,
                                  load_mask=cfg.TRAIN.WITH_MASK,
                                  only_gt=not cfg.TRAIN.USE_NEG_CHIPS)
              for image_set in cfg.dataset.image_set.split('+')]
    return filter_roidb(merge_roidb(roidbs), cfg)


def _build_eval_metrics(cfg):
    """Return a CompositeEvalMetric holding the RPN and RCNN training metrics.

    Children are added in this order: RPN accuracy, RPN log loss, RPN L1 loss,
    RCNN accuracy, RCNN log loss, RCNN L1 loss.
    """
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child in (metric.RPNAccMetric(),
                  metric.RPNLogLossMetric(),
                  metric.RPNL1LossMetric(),
                  metric.RCNNAccMetric(cfg),
                  metric.RCNNLogLossMetric(cfg),
                  metric.RCNNL1LossCRCNNMetric(cfg)):
        eval_metrics.add(child)
    return eval_metrics


def main():
    """Train a SNIPER detector with MXNet's Module API.

    Steps: parse CLI args, load the YAML config given by --cfg, build the
    training roidb and iterator, construct the symbol, initialise weights from
    the pretrained checkpoint, then run ``mod.fit`` with SGD, logging every
    ``--display`` batches and checkpointing after every epoch.
    """
    args = parser()
    update_config(args.cfg)
    # NOTE(review): args.set_cfg_list (the --set overrides) is parsed but never
    # applied here, so command-line config overrides are silently ignored.
    # TODO: apply them to the config after update_config.

    # config.gpus comes from the YAML file passed via --cfg: comma-separated GPU ids.
    context = [mx.gpu(int(gpu)) for gpu in config.gpus.split(',')]
    n_gpus = len(context)  # number of GPUs actually used
    # Global batch size: TRAIN.BATCH_IMAGES images on each GPU.
    batch_size = n_gpus * config.TRAIN.BATCH_IMAGES

    roidb = _load_train_roidb(config)
    # Returns the bbox regression normalisation statistics, which the checkpoint
    # callback below needs (presumably also annotates roidb in place -- confirm).
    bbox_means, bbox_stds = add_bbox_regression_targets(roidb, config)
    train_iter = MNIteratorE2E(roidb=roidb, config=config, batch_size=batch_size, nGPUs=n_gpus,
                               threads=config.TRAIN.NUM_THREAD, pad_rois_to=400)

    # Called for its side effects; only the output path is used afterwards.
    _, output_path = create_logger(config.output_path, args.cfg, config.dataset.image_set)

    # config.symbol names a module and a class of the same name inside it.
    # NOTE(review): eval() on a config value executes arbitrary code if the YAML
    # file is untrusted; a getattr-based lookup would be safer.
    sym_inst = eval('{}.{}'.format(config.symbol, config.symbol))(n_proposals=400, momentum=args.momentum)
    sym = sym_inst.get_symbol_rcnn(config)
    fixed_param_names = get_fixed_param_names(config.network.FIXED_PARAMS, sym)

    data_names = [desc[0] for desc in train_iter.provide_data_single]
    for name in data_names:
        print(name)  # debug aid: list the network's input data names
    label_names = [desc[0] for desc in train_iter.provide_label_single]
    mod = mx.mod.Module(symbol=sym,
                        context=context,
                        data_names=data_names,
                        label_names=label_names,
                        fixed_param_names=fixed_param_names)

    # Infer shapes from the iterator's (name, shape) descriptors before weight init.
    shape_dict = dict(train_iter.provide_data_single + train_iter.provide_label_single)
    sym_inst.infer_shape(shape_dict)
    arg_params, aux_params = load_param(config.network.pretrained, config.network.pretrained_epoch, convert=True)
    sym_inst.init_weight_rcnn(config, arg_params, aux_params)

    eval_metrics = _build_eval_metrics(config)
    # len() is taken on the raw iterator, before it is wrapped below.
    optimizer_params = get_optim_params(config, len(train_iter), batch_size)

    prefix = os.path.join(output_path, args.save_prefix)
    # Speedometer logs speed and metrics every `args.display` batches.
    batch_end_callback = mx.callback.Speedometer(batch_size, args.display)
    epoch_end_callback = [
        mx.callback.module_checkpoint(mod, prefix, period=1, save_optimizer_states=True),
        eval('{}.checkpoint_callback'.format(config.symbol))(
            sym_inst.get_bbox_param_names(), prefix, bbox_means, bbox_stds),
    ]

    # Prefetch the next batch while the current one is being processed.
    train_iter = PrefetchingIter(train_iter)
    mod.fit(train_iter, optimizer='sgd', optimizer_params=optimizer_params,
            eval_metric=eval_metrics, num_epoch=config.TRAIN.end_epoch, kvstore=config.default.kvstore,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback, arg_params=arg_params, aux_params=aux_params)
参考博文
[1]Python命令行解析argparse常用语法使用简介