文章目录
- 0. 前言
- 1. 入口函数详解
- 1.1. 构建模型
- 1.2. 构建数据集
- 1.3. 执行训练。
- 2. Runner 介绍
- 2.1. BaseRunner
- 2.2. EpochBaseRunner
0. 前言
1. 入口函数详解
- 入口代码:
tools/train.py
- 实现的功能:
- 第一步:构建参数,包括命令行参数以及配置文件参数。
- 第二步:初始化一堆东西,比如创建输出路径、logger、random seed等。
- 第三步:构建模型。
- 第四步:构建数据集。
- 第五步:执行训练。
1.1. 构建模型
- 入口函数:
mmaction/models/builder.py
中的 build_model
函数。
- 实现的功能:
- 根据配置文件中的
cfg.model['type']
判断模型类型。
- 通过注册机制,根据模型类型字符串选择对应的类。
- 通过
cfg.model
中除 type
外的其他参数作为模型初始化参数,构建模型对象。
- 更多内容请参考 我的笔记
1.2. 构建数据集
- 入口函数:
mmaction/datasets/builder.py
中的 build_dataset
函数。
- 实现功能:
- 根据配置文件中的
cfg.data.train['type']
判断数据集类型。
- 通过注册机制,根据数据集类型字符串选择对应的类。
- 通过
cfg.data.train
中除 type
外的其他参数作为数据集初始化参数,构建最终数据集。
- 更多内容请参考 我的笔记
1.3. 执行训练。
- 入口函数:
mmaction/apis/train.py
中的 train_model
函数。
- 从流程看:
- 第一步:构建logger
- 第二步:根据参数构建dataloader
- 第三步:根据需求,构建分布式模型
- 第四步:构建optimizer
- 第五步:初始化
EpochBaseRunner
,简称为 runner。
- 第六步:根据需求设置 fp16 量化。
- 第七步:构建各类hooks,包括学习率、log、优化器、保存模型、分布式sampler等。
- 第八步:根据需求设置validate参数,包括构建val数据集以及对应dataloader,以及对应的 eval hook。
- 第九步:根据需求初始化模型参数。
- 第十步:实际执行训练。
- 从实现机制看:
- 训练细节都是通过
EpochBaseRunner
实现的。
- 一些具体细节都是通过 runner 中的hook实现。
2. Runner 介绍
2.1. BaseRunner
- 代码位于
mmcv.runner.base_runner.py
中
- 作用:pytorch训练相关代码。
- 构造函数:
- 输入参数
- model
batch_processor
(callable
方法,调用方法是 batch_processor(model, data, train_mode)
,输出一个字典)
- optimizer
work_dir
(保存模型、日志文件)
- logger
- meta(字典,包括环境信息和seed等)
- 执行的操作:
- 判断输入参数合法性。
- 将输入数据保存为成员变量。
- 初始化其他成员变量。
- 支持的成员变量:
model_name
、rank
、world_size
、hooks
、epoch
、iter
、inner_iter
、max_epochs
、max_iters
- 抽象方法:
train
、val
、run
、save_checkpoint
- 实现的功能
- 获取optimizer中每个
param_groups
中lr、momentum、betas的数值。
- 注册hook、定义一些默认hook。
- resume 模型权重。
- hooks 相关功能:
- 在
register_hook
时会根据输入的 priority 获得具体的优先级数值,内部保存hooks时会根据优先级数值进行排序。
- 定义 hook 的 helper function,用来运行所有hook的某个方法。
- 定义training中默认用到的六种hook
LrUpdaterHook
:详见 mmcv.runner.hooks.lr_updater.py
MomentumUpdaterHook
:optimizer中momentum的更新,详见 mmcv.runner.hooks.momentum_updater.py
OptimizerStepperHook
:更新参数hook,详见 mmcv.runner.hooks.optimizer.py
CheckpointSaverHook
:保存模型hook,详见 mmcv.runner.hooks.checkpoint.py
IterTimerHook
:为logger增加计时功能,在logger中增加了两个参数data_time
和time
,前者表示数据获取时间,后者表示iter总时间,详见 mmcv.runner.hooks.iter_timer.py
LoggerHook(s)
:详见 mmcv.runner.hooks.logger
中
2.2. EpochBaseRunner
- 定义了BaseRunner中
train/val/run/save_checkpoint
四个抽象方法。
- 训练代码主要功能:
- 在响应位置调用hooks对应的方法,这里就包含了更新参数、logger、更新学习率、模型保存等功能。
- 遍历一遍dataloader,分别执行前向过程,得到损失函数。
- 如果设置了
batch_processor
,则通过该函数计算损失函数。
- 如果没有设置
batch_processor
,则通过 model.train_step
获得损失函数。
def train(self, data_loader, **kwargs):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(data_loader)
self.call_hook('before_train_epoch')
time.sleep(2)
for i, data_batch in enumerate(data_loader):
self._inner_iter = i
self.call_hook('before_train_iter')
if self.batch_processor is None:
outputs = self.model.train_step(data_batch, self.optimizer,
**kwargs)
else:
outputs = self.batch_processor(
self.model, data_batch, train_mode=True, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('"batch_processor()" or "model.train_step()"'
' must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'],
outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_train_iter')
self._iter += 1
self.call_hook('after_train_epoch')
self._epoch += 1
- 验证相关代码,主要功能包括:
- 执行各类hooks。
- 遍历 dataloader,通过
self.batch_processor
或 model.val_step
执行前向操作,得到模型输出结果
def val(self, data_loader, **kwargs):
self.model.eval()
self.mode = 'val'
self.data_loader = data_loader
self.call_hook('before_val_epoch')
time.sleep(2)
for i, data_batch in enumerate(data_loader):
self._inner_iter = i
self.call_hook('before_val_iter')
with torch.no_grad():
if self.batch_processor is None:
outputs = self.model.val_step(data_batch, self.optimizer,
**kwargs)
else:
outputs = self.batch_processor(
self.model, data_batch, train_mode=False, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('"batch_processor()" or "model.val_step()"'
' must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'],
outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_val_iter')
self.call_hook('after_val_epoch')
- run,包括训练/验证工作
- 输入参数包括
data_loaders/workflow
,两者的长度相同,分别对应。
workflow
加入是 [('train', 2), ('val', 1)]
,则表示train 2 epoch then val 1 epoch,按照这个顺序依次进行训练,作为一个epoch。
- 后续会根据
workflow
根据 mode 选择对应的 train/val
方法。
def run(self, data_loaders, workflow, max_epochs, **kwargs):
"""Start running.
Args:
data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
and validation.
workflow (list[tuple]): A list of (phase, epochs) to specify the
running order and epochs. E.g, [('train', 2), ('val', 1)] means
running 2 epochs for training and 1 epoch for validation,
iteratively.
max_epochs (int): Total training epochs.
"""
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
self._max_epochs = max_epochs
for i, flow in enumerate(workflow):
mode, epochs = flow
if mode == 'train':
self._max_iters = self._max_epochs * len(data_loaders[i])
break
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
self.call_hook('before_run')
while self.epoch < max_epochs:
for i, flow in enumerate(workflow):
mode, epochs = flow
if isinstance(mode, str):
if not hasattr(self, mode):
raise ValueError(
f'runner has no method named "{mode}" to run an '
'epoch')
epoch_runner = getattr(self, mode)
else:
raise TypeError(
'mode in workflow must be a str, but got {}'.format(
type(mode)))
for _ in range(epochs):
if mode == 'train' and self.epoch >= max_epochs:
return
epoch_runner(data_loaders[i], **kwargs)
time.sleep(1)
self.call_hook('after_run')