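一般完整的网络训练过程包含以下五个部分:构建数据集(build dataset)、构建模型(build model)、设计损失函数(design loss)、构建优化器(build optimizer),以及训练/验证的工作流(workflow)。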
前面四个过程是在main和train_detector这两个函数中实现的,各个组件全部采用Registry工厂模式(具体可参考here1和here2),在这两个函数里根据cfg构建(build)出来。(另外,在mmdetection中,design loss是作为build model的一部分实现的。)
它们的代码简要如下:
# tools/train.py
def main():
    """Entry point of ``tools/train.py``: build the model and dataset, then train.

    Steps visible in this excerpt:

    1. Parse the command-line arguments, load the config file named by
       ``args.config`` and merge ``args.cfg_options`` into it when given.
    2. Build the detector from ``cfg.model`` and initialise its weights.
    3. Build the training dataset from ``cfg.data.train``.
    4. Hand model, datasets and config over to ``train_detector``.

    ``...`` lines mark code omitted from the excerpt. ``distributed``,
    ``timestamp`` and ``meta`` are not assigned in the visible code and are
    expected to come from that omitted part.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # 1. merge cfg; 2. create work_dir 3. create distribution train world with given GPU ids 4. collect env ...
    ...
    model = build_detector(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    model.init_weights()

    # A list is used because train_detector builds one data loader per entry.
    datasets = [build_dataset(cfg.data.train)]
    ...
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    """Build data loaders, runner and hooks from ``cfg``, then run training.

    ``...`` lines mark code omitted from the excerpt; ``logger``,
    ``optimizer_config``, ``val_samples_per_gpu`` and ``eval_cfg`` are not
    assigned in the visible code and are expected to come from those parts.

    Args:
        model: The detector to train. It is moved to GPU and wrapped in
            ``MMDistributedDataParallel`` or ``MMDataParallel``.
        dataset: Iterable of datasets; one data loader is built per element.
        cfg: The config object. Fields read in the visible code include
            ``data``, ``gpu_ids``, ``seed``, ``optimizer``, ``runner`` (or
            ``total_epochs`` as a fallback), ``work_dir``, ``lr_config``,
            ``checkpoint_config``, ``log_config``, ``resume_from``,
            ``load_from`` and ``workflow``.
        distributed (bool): Selects the distributed model wrapper, the
            distributed sampler-seed hook and ``DistEvalHook``.
        validate (bool): When true, an evaluation hook running on
            ``cfg.data.val`` is registered.
        timestamp: Not used in the visible part of the function.
        meta: Forwarded to ``build_runner`` as a default argument.
    """
    ...
    # One data loader per dataset passed in.
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed) for ds in dataset
    ]

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)

    # Configs without a `runner` section fall back to an epoch-based runner
    # that trains for `cfg.total_epochs` epochs.
    if 'runner' not in cfg:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        }
    ...
    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))
    ...
    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        ...
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=val_samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        ...
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # user-defined hooks
    ...

    # `resume_from` takes precedence over `load_from`.
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
上面五个步骤(在mmdetection中往往将design loss作为build model的一部分)中,前四个步骤的创建方式在上一个章节已经介绍过;而最后一个步骤中的大部分操作基本是固定的,因此可以把它封装起来,同时通过在每个关键位置上添加Hook(回调函数)的方式来支持拓展——这两件事就是Runner要实现的功能。下面以默认且最常用的EpochBasedRunner为例进行说明:Runner的函数调用关系如下所示,以run为入口函数,run内部调用train/val,而train/val各自调用run_iter函数。
@RUNNERS.register_module()
class EpochBasedRunner(BaseRunner):
    """Epoch-based Runner.

    This runner trains models epoch by epoch.

    ``run`` is the entry point; it dispatches to ``train`` or ``val`` once
    per epoch, and both of them call ``run_iter`` once per batch. Hooks are
    invoked through ``self.call_hook`` at every stage boundary. ``...``
    lines mark code omitted from this excerpt.
    """

    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
            max_epochs: Not used in the visible part of this method.
            **kwargs: Forwarded to the per-epoch method (``train``/``val``).
        """
        ...
        self.call_hook('before_run')

        # Cycle through the workflow until enough epochs have been run.
        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                ...
                # `mode` names the method to run for this phase,
                # e.g. 'train' -> self.train, 'val' -> self.val.
                epoch_runner = getattr(self, mode)
                ...
                for _ in range(epochs):
                    # Do not start another training epoch once the epoch
                    # budget is used up, even in the middle of a phase.
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    # The i-th data loader belongs to the i-th phase.
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

    def train(self, data_loader, **kwargs):
        """Run one training epoch over ``data_loader``.

        Puts the model in training mode, fires the ``*_train_epoch`` and
        ``*_train_iter`` hooks around the epoch and each batch, and
        advances ``self._iter`` per batch and ``self._epoch`` per epoch.

        Args:
            data_loader: Iterable with a length, yielding the batches.
            **kwargs: Forwarded to ``run_iter``.
        """
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        # Total iteration budget, derived from this loader's length.
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        """Run one validation epoch over ``data_loader`` without gradients.

        Puts the model in eval mode and fires the ``*_val_epoch`` and
        ``*_val_iter`` hooks. Unlike ``train``, it changes neither
        ``self._iter`` nor ``self._epoch``, and ``**kwargs`` is not
        forwarded to ``run_iter``.
        """
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def run_iter(self, data_batch, train_mode, **kwargs):
        """Process one batch and store the result in ``self.outputs``.

        ``self.batch_processor`` is used when it is set; otherwise the
        model's ``train_step`` or ``val_step`` is called depending on
        ``train_mode``.

        Raises:
            TypeError: If the step function does not return a dict.
        """
        if self.batch_processor is not None:
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=train_mode, **kwargs)
        elif train_mode:
            outputs = self.model.train_step(data_batch, self.optimizer,
                                            **kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            # The trailing space keeps the two concatenated literals from
            # running together in the message.
            raise TypeError('"batch_processor()" or "model.train_step()" '
                            'and "model.val_step()" must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        # Body omitted from this excerpt.
        ...
上面代码里的self.call_hook就是在依次调用所有已注册Hook中对应名字的函数,各个调用位置如下:
self.call_hook('before_run')
self.call_hook('before_train_epoch') / self.call_hook('before_val_epoch')
self.call_hook('before_train_iter') / self.call_hook('before_val_iter')
self.call_hook('after_train_iter') / self.call_hook('after_val_iter')
self.call_hook('after_train_epoch') / self.call_hook('after_val_epoch')
self.call_hook('after_run')
Runner中关于Hook的注册与调用都是在Runner的基类BaseRunner中实现的,对应函数为register_hook和call_hook:
class BaseRunner(metaclass=ABCMeta):
    """Base class of the runners (excerpt: hook registration and dispatch).

    Only the hook-related methods are shown; ``...`` marks code omitted from
    this excerpt. The methods below read and update ``self._hooks``, a list
    that is expected to be created by the omitted initialisation code.
    """
    ...

    def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list.

        ``self._hooks`` is kept sorted in ascending order of the value
        returned by ``get_priority``. ``call_hook`` walks the list front to
        back, so hooks with a smaller value are called first, and hooks with
        an equal value are called in registration order.

        Args:
            hook (Hook): The hook instance to register.
            priority: Priority specifier, converted with ``get_priority``.

        Raises:
            ValueError: If ``hook`` already has a ``priority`` attribute.
        """
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority
        # insert the hook to a sorted list
        inserted = False
        # Scan from the back and insert right after the last hook whose
        # priority value is not greater than the new one.
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                inserted = True
                break
        if not inserted:
            # Smaller than every existing value (or the list is empty).
            self._hooks.insert(0, hook)

    def register_hook_from_cfg(self, hook_cfg):
        """Register a hook from its cfg.

        Args:
            hook_cfg (dict): Hook config. An optional ``priority`` key
                (default ``'NORMAL'``) is removed before the rest is passed
                to ``mmcv.build_from_cfg`` with the ``HOOKS`` registry. The
                dict is copied first, so the caller's object is untouched.
        """
        hook_cfg = hook_cfg.copy()
        priority = hook_cfg.pop('priority', 'NORMAL')
        hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
        self.register_hook(hook, priority=priority)

    def call_hook(self, fn_name):
        """Call all hooks.

        Args:
            fn_name (str): The function name in each hook to be called, such as
                "before_train_epoch".
        """
        # Hooks are invoked in list order, each receiving this runner.
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def get_hook_info(self):
        # Body omitted from this excerpt.
        ...

    def resume(self, *args, **kwargs):
        # Placeholder: the real parameter list and body are omitted from
        # this excerpt.
        ...

    def register_lr_hook(self, lr_config):
        """Register the learning-rate updater hook with ``'VERY_HIGH'`` priority.

        Args:
            lr_config: ``None`` to register nothing, a dict describing the
                hook, or an object that is registered as the hook directly.
                A dict must contain a ``'policy'`` key. That key is popped
                from the dict passed in and replaced by a ``'type'`` key
                named ``<Policy>LrUpdaterHook``.
        """
        if lr_config is None:
            return
        elif isinstance(lr_config, dict):
            assert 'policy' in lr_config
            policy_type = lr_config.pop('policy')
            # Only an all-lowercase policy name is title-cased
            # ('step' -> 'Step'); a mixed-case name is kept as written,
            # because `title()` would lowercase its inner capitals.
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'LrUpdaterHook'
            lr_config['type'] = hook_type
            hook = mmcv.build_from_cfg(lr_config, HOOKS)
        else:
            hook = lr_config
        self.register_hook(hook, priority='VERY_HIGH')
Hook类的定义如下,它是一个用于在各个关键位置插入回调函数的类,我们可以根据自己的需要继承Hook类并重写对应位置的函数:
class Hook:
    """Base class for runner callbacks.

    A runner calls the method named after each stage in ``stages`` and
    passes itself as ``runner``. Every method is a no-op by default. The
    train/val specific methods delegate to the generic ``before_epoch``,
    ``after_epoch``, ``before_iter`` and ``after_iter``, so a subclass can
    override one generic method to cover both phases.
    """

    # Names of all stage methods a runner may invoke.
    stages = ('before_run', 'before_train_epoch', 'before_train_iter',
              'after_train_iter', 'after_train_epoch', 'before_val_epoch',
              'before_val_iter', 'after_val_iter', 'after_val_epoch',
              'after_run')

    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass

    def before_epoch(self, runner):
        pass

    def after_epoch(self, runner):
        pass

    def before_iter(self, runner):
        pass

    def after_iter(self, runner):
        pass

    def before_train_epoch(self, runner):
        self.before_epoch(runner)

    def before_val_epoch(self, runner):
        self.before_epoch(runner)

    def after_train_epoch(self, runner):
        self.after_epoch(runner)

    def after_val_epoch(self, runner):
        self.after_epoch(runner)

    def before_train_iter(self, runner):
        self.before_iter(runner)

    def before_val_iter(self, runner):
        self.before_iter(runner)

    def after_train_iter(self, runner):
        self.after_iter(runner)

    def after_val_iter(self, runner):
        self.after_iter(runner)

    # The predicates below add 1 to the runner's counters before comparing,
    # i.e. they treat those counters as zero-based.

    def every_n_epochs(self, runner, n):
        """Return True when ``runner.epoch + 1`` is a multiple of ``n``.

        Always False when ``n`` is not positive.
        """
        return (runner.epoch + 1) % n == 0 if n > 0 else False

    def every_n_inner_iters(self, runner, n):
        """Return True when ``runner.inner_iter + 1`` is a multiple of ``n``.

        Always False when ``n`` is not positive.
        """
        return (runner.inner_iter + 1) % n == 0 if n > 0 else False

    def every_n_iters(self, runner, n):
        """Return True when ``runner.iter + 1`` is a multiple of ``n``.

        Always False when ``n`` is not positive.
        """
        return (runner.iter + 1) % n == 0 if n > 0 else False

    def end_of_epoch(self, runner):
        """Return True on the last batch of ``runner.data_loader``."""
        return runner.inner_iter + 1 == len(runner.data_loader)

    def is_last_epoch(self, runner):
        """Return True when ``runner.epoch + 1`` equals ``runner._max_epochs``."""
        return runner.epoch + 1 == runner._max_epochs

    def is_last_iter(self, runner):
        """Return True when ``runner.iter + 1`` equals ``runner._max_iters``."""
        return runner.iter + 1 == runner._max_iters

    def get_triggered_stages(self):
        """Return the stages this hook overrides, in ``Hook.stages`` order.

        A stage counts as triggered when its own method is overridden, or
        when the generic method it delegates to is overridden.
        """
        trigger_stages = set()
        for stage in Hook.stages:
            if is_method_overridden(stage, Hook, self):
                trigger_stages.add(stage)

        # some methods will be triggered in multi stages
        # use this dict to map method to stages.
        method_stages_map = {
            'before_epoch': ['before_train_epoch', 'before_val_epoch'],
            'after_epoch': ['after_train_epoch', 'after_val_epoch'],
            'before_iter': ['before_train_iter', 'before_val_iter'],
            'after_iter': ['after_train_iter', 'after_val_iter'],
        }

        for method, map_stages in method_stages_map.items():
            if is_method_overridden(method, Hook, self):
                trigger_stages.update(map_stages)

        # Filter `Hook.stages` so the result has a stable, canonical order.
        return [stage for stage in Hook.stages if stage in trigger_stages]
在mmdet/apis/train.py的train_detector函数中(该函数由训练入口tools/train.py的main函数调用),有以下与Hook注册相关的实现:
# mmdet/apis/train.py
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    """Hook-registration part of ``train_detector`` (excerpt).

    ``...`` lines mark code omitted from the excerpt; ``runner``,
    ``optimizer_config``, ``val_dataloader`` and ``eval_cfg`` are not
    assigned in the visible code and are expected to come from those parts.

    Three groups of hooks are registered on the runner:

    1. The standard training hooks, built from ``cfg.lr_config``,
       ``optimizer_config``, ``cfg.checkpoint_config``, ``cfg.log_config``
       and the optional ``cfg.momentum_config``.
    2. The evaluation hook, only when ``validate`` is true.
    3. The user-defined hooks listed in ``cfg.custom_hooks``.
    """
    ...
    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        ...
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            # Work on a copy so that popping `priority` leaves the config
            # object itself unchanged.
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)