Training is launched with the shell script, e.g.:
tools/dist_train.sh projects/configs/obj_dgcnn/pillar.py 8
#!/usr/bin/env bash
# DETR3D passes in config_path and gpus; PORT falls back to its default
CONFIG=$1
GPUS=$2
PORT=${PORT:-28500}
# This is single-node multi-GPU distributed training: specify gpus, port and train.py.
# Multi-node training additionally requires the number of nodes, the node rank, etc.
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}
For more details on torch.distributed.launch see: https://blog.csdn.net/magic_ll/article/details/122359490
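As a quick illustration of what the launcher does: torch.distributed.launch spawns one process per GPU and hands each process its identity via the RANK/WORLD_SIZE environment variables (the local rank arrives as a --local_rank argument or, on newer PyTorch, the LOCAL_RANK env var). A minimal sketch of the worker side, assuming the env-var style; this mirrors what mmcv's init_dist does under the 'pytorch' launcher:
import os
import torch
import torch.distributed as dist

# set by torch.distributed.launch for every spawned worker process
local_rank = int(os.environ['LOCAL_RANK'])   # GPU index on this node
rank = int(os.environ['RANK'])               # global process index
world_size = int(os.environ['WORLD_SIZE'])   # total number of processes

torch.cuda.set_device(local_rank)
# MASTER_ADDR / MASTER_PORT are also provided by the launcher,
# and init_process_group reads them via the default 'env://' method
dist.init_process_group(backend='nccl')
print(f'worker {rank}/{world_size} bound to cuda:{local_rank}')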
Next, train.py sets the config file and the work dir. The work dir stores the final config, logs, and other outputs; it defaults to path/to/user/work_dir/
def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    # ... (remaining arguments omitted) ...
    args = parser.parse_args()
    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both specified, '
            '--options is deprecated in favor of --cfg-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options')
        args.cfg_options = args.options
    return args
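To make the precedence concrete, here is a minimal sketch of how --cfg-options key=value pairs override nested config fields via mmcv's Config.merge_from_dict (the config values below are made up for illustration):
from mmcv import Config

cfg = Config(dict(optimizer=dict(type='AdamW', lr=2e-4), total_epochs=24))
# `--cfg-options optimizer.lr=1e-4 total_epochs=12` arrives as a flat
# dict with dotted keys; merge_from_dict applies it over the nested cfg
cfg.merge_from_dict({'optimizer.lr': 1e-4, 'total_epochs': 12})
assert cfg.optimizer.lr == 1e-4
assert cfg.total_epochs == 12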
The authors put all custom code under 'projects/mmdet3d_plugin/' and register the modules through the registry mechanism; here importlib is used to import the plugin package, which initializes the custom classes (see the toy sketch after the listing below).
args = parse_args()
cfg = Config.fromfile(args.config)
# update the loaded config from args: args take precedence over cfg.
# args supply parameters the cfg file does not define (e.g. work_dir)
# and override some parameters that it does.
if args.cfg_options is not None:
    cfg.merge_from_dict(args.cfg_options)
# import modules from plugin/xx, registry will be updated
if hasattr(cfg, 'plugin'):
    if cfg.plugin:
        # import the whole plugin package into the module environment
        # plugin_dir = 'projects/mmdet3d_plugin/'
        import importlib
        if hasattr(cfg, 'plugin_dir'):
            plugin_dir = cfg.plugin_dir
            # _module_dir = 'projects/mmdet3d_plugin'
            _module_dir = os.path.dirname(plugin_dir)
            _module_dir = _module_dir.split('/')
            _module_path = _module_dir[0]
            # turn the directory path into Python's dotted module path
            for m in _module_dir[1:]:
                _module_path = _module_path + '.' + m
            plg_lib = importlib.import_module(_module_path)
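The key point is that merely importing the plugin package is enough: the @XXX.register_module() decorators run at import time and add the custom classes to the registries, after which build_model/build_dataset can look them up by their 'type' string. A toy sketch of the mechanism (not mmcv's actual implementation):
class Registry:
    """Toy registry: maps a class name to the class object."""

    def __init__(self):
        self._modules = {}

    def register_module(self):
        def _register(cls):
            self._modules[cls.__name__] = cls
            return cls
        return _register

    def build(self, cfg):
        cfg = dict(cfg)                      # don't mutate the caller's cfg
        return self._modules[cfg.pop('type')](**cfg)


DETECTORS = Registry()


@DETECTORS.register_module()                 # runs when the module is imported
class ToyDetector:
    def __init__(self, num_classes):
        self.num_classes = num_classes


# importlib.import_module('projects.mmdet3d_plugin') triggers exactly this
# kind of registration; afterwards a config dict can be built by name:
model = DETECTORS.build(dict(type='ToyDetector', num_classes=10))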
This part sets the model's runtime environment: the output/save paths, the GPUs to use, and so on.
# carry over the resume checkpoint and GPU settings from args
if args.resume_from is not None:
    cfg.resume_from = args.resume_from
if args.gpu_ids is not None:
    cfg.gpu_ids = args.gpu_ids
else:
    cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
if args.autoscale_lr:
    # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
    cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
    distributed = False
else:
    distributed = True
    init_dist(args.launcher, **cfg.dist_params)
    # re-set gpu_ids with distributed training mode
    _, world_size = get_dist_info()
    cfg.gpu_ids = range(world_size)
# create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# dump config and save to work_dir
cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
if cfg.model.type in ['EncoderDecoder3D']:
    logger_name = 'mmseg'
else:
    logger_name = 'mmdet'
logger = get_root_logger(
    log_file=log_file, log_level=cfg.log_level, name=logger_name)
# meta: records environment info, random seed, etc.
meta = dict()
# log env info (env_info is collected via collect_env(); code omitted)
meta['env_info'] = env_info
meta['config'] = cfg.pretty_text
set_random_seed(args.seed, deterministic=args.deterministic)
cfg.seed = args.seed
meta['seed'] = args.seed
meta['exp_name'] = osp.basename(args.config)
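To make the --autoscale-lr branch above concrete: the linear scaling rule assumes the config's learning rate was tuned for 8 GPUs with a fixed per-GPU batch size, so it rescales the rate proportionally to the actual GPU count (numbers below are illustrative):
base_lr = 2e-4    # cfg.optimizer['lr'], assumed tuned for 8 GPUs
num_gpus = 2      # len(cfg.gpu_ids) on this run
scaled_lr = base_lr * num_gpus / 8
print(scaled_lr)  # 5e-05: 1/4 of the GPUs -> 1/4 of the total batch -> 1/4 of the lr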
Here the model is built and the training dataset (and, when the workflow includes a val phase, the validation dataset) is initialized.
model = build_model(
    cfg.model,
    train_cfg=cfg.get('train_cfg'),
    test_cfg=cfg.get('test_cfg'))
model.init_weights()
# dataset initialization; inputs: pipeline, class_names, modality
# returns the train Dataset consumed by the DataLoader later on
datasets = [build_dataset(cfg.data.train)]
# setting cfg.workflow = [('train', 1), ('val', 1)] runs the validation
# set after every training epoch (code omitted)
# set checkpoint config (code omitted)
# initialize config, hooks, dataloaders and the runner, then run the
# runner to train according to the workflow
train_model(
    model,
    datasets,
    cfg,
    distributed=distributed,
    validate=(not args.no_validate),
    timestamp=timestamp,
    meta=meta)
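For reference, cfg.data.train is an ordinary config dict carrying the pipeline, class_names and modality mentioned above. A hedged sketch of its typical shape in mmdet3d (dataset type, paths, classes and pipeline steps are illustrative, not taken from this repo):
train = dict(
    type='NuScenesDataset',
    data_root='data/nuscenes/',
    ann_file='data/nuscenes/nuscenes_infos_train.pkl',
    pipeline=[
        dict(type='LoadPointsFromFile', coord_type='LIDAR',
             load_dim=5, use_dim=5),
        dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
        dict(type='DefaultFormatBundle3D',
             class_names=['car', 'truck', 'pedestrian']),
        dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']),
    ],
    classes=['car', 'truck', 'pedestrian'],           # class_names
    modality=dict(use_lidar=True, use_camera=False),  # sensor modality
)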
This part builds the DataLoaders, initializes the runner and its hooks, and runs the runner according to the workflow.
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    cfg = compat_cfg(cfg)
    logger = get_root_logger(log_level=cfg.log_level)
    # prepare data loaders; dataset may be a single dataset or a list,
    # since a two-phase workflow also carries the val dataset
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    # pick the runner type; default: EpochBasedRunner
    runner_type = 'EpochBasedRunner' if 'runner' not in cfg else cfg.runner[
        'type']
    train_dataloader_default_args = dict(
        samples_per_gpu=2,
        workers_per_gpu=2,
        # `num_gpus` will be ignored if distributed
        num_gpus=len(cfg.gpu_ids),
        dist=distributed,
        seed=cfg.seed,
        runner_type=runner_type,
        persistent_workers=False)
    # merge the dataloader settings: the defaults above, overridden by the
    # config file; if the cfg defines no train_dataloader this merges {}
    train_loader_cfg = {
        **train_dataloader_default_args,
        **cfg.data.get('train_dataloader', {})
    }
    # build the dataloaders
    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
    # put model on gpus
    model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)
    # build optimizer
    auto_scale_lr(cfg, distributed, logger)
    optimizer = build_optimizer(model, cfg.optimizer)
    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))
    # register training hooks; optimizer_config is derived from
    # cfg.optimizer_config (fp16 handling omitted here)
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))
    # register eval hooks
    if validate:
        val_dataloader_default_args = dict(
            samples_per_gpu=1,
            workers_per_gpu=2,
            dist=distributed,
            shuffle=False,
            persistent_workers=False)
        val_dataloader_args = {
            **val_dataloader_default_args,
            **cfg.data.get('val_dataloader', {})
        }
        # Support batch_size > 1 in validation
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(val_dataset, **val_dataloader_args)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(
            eval_hook(val_dataloader, **eval_cfg), priority='LOW')
    # resume from the latest checkpoint if auto_resume is enabled
    resume_from = None
    if cfg.resume_from is None and cfg.get('auto_resume'):
        resume_from = find_latest_checkpoint(cfg.work_dir)
    if resume_from is not None:
        cfg.resume_from = resume_from
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    # run the runner according to the workflow
    runner.run(data_loaders, cfg.workflow)
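Finally, the workflow passed to runner.run is a list of (mode, epochs) pairs that is paired positionally with data_loaders. A toy sketch of the EpochBasedRunner scheduling loop (simplified; not the actual mmcv code):
def run(data_loaders, workflow, max_epochs):
    """Toy version of EpochBasedRunner.run: cycle through the workflow
    phases, consuming data_loaders[i] for workflow[i][1] epochs each."""
    epoch = 0
    while epoch < max_epochs:
        for (mode, num_epochs), loader in zip(workflow, data_loaders):
            for _ in range(num_epochs):
                for batch in loader:
                    pass              # train_step() / val_step() per batch
                if mode == 'train':
                    epoch += 1        # only train epochs count toward max_epochs
                    if epoch >= max_epochs:
                        return

# With workflow = [('train', 1), ('val', 1)] and data_loaders =
# [train_loader, val_loader], one training epoch alternates with one
# full pass over the validation set.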