mmaction2 训练相关源码概览

文章目录

    • 0. 前言
    • 1. 入口函数详解
      • 1.1. 构建模型
      • 1.2. 构建数据集
      • 1.3. 执行训练。
    • 2. Runner 介绍
      • 2.1. BaseRunner
      • 2.2. EpochBaseRunner


0. 前言

  • 目标:整理 mmaction2 训练的实现过程。

1. 入口函数详解

  • 入口代码:tools/train.py
  • 实现的功能:
    • 第一步:构建参数,包括命令行参数以及配置文件参数。
    • 第二步:初始化一堆东西,比如创建输出路径、logger、random seed等。
    • 第三步:构建模型。
    • 第四步:构建数据集。
    • 第五步:执行训练。

1.1. 构建模型

  • 入口函数:mmaction/models/builder.py 中的 build_model 函数。
  • 实现的功能:
    • 根据配置文件中的 cfg.model['type'] 判断模型类型。
    • 通过注册机制,根据模型类型字符串选择对应的类。
    • 通过 cfg.model 中除 type 外的其他参数作为模型初始化参数,构建模型对象。
  • 更多内容请参考 我的笔记

1.2. 构建数据集

  • 入口函数:mmaction/datasets/builder.py 中的 build_dataset 函数。
  • 实现功能:
    • 根据配置文件中的 cfg.data.train['type'] 判断数据集类型。
    • 通过注册机制,根据数据集类型字符串选择对应的类。
    • 通过 cfg.data.train 中除 type 外的其他参数作为数据集初始化参数,构建最终数据集。
  • 更多内容请参考 我的笔记

1.3. 执行训练。

  • 入口函数:mmaction/apis/train.py 中的 train_model 函数。
  • 从流程看:
    • 第一步:构建logger
    • 第二步:根据参数构建dataloader
    • 第三步:根据需求,构建分布式模型
    • 第四步:构建optimizer
    • 第五步:初始化 EpochBaseRunner,简称为 runner。
    • 第六步:根据需求设置 fp16 量化。
    • 第七步:构建各类hooks,包括学习率、log、优化器、保存模型、分布式sampler等。
    • 第八步:根据需求设置validate参数,包括构建val数据集以及对应dataloader,以及对应的 eval hook。
    • 第九步:根据需求初始化模型参数。
    • 第十步:实际执行训练。
  • 从实现机制看:
    • 训练细节都是通过 EpochBaseRunner 实现的。
    • 一些具体细节都是通过 runner 中的hook实现。

2. Runner 介绍

2.1. BaseRunner

  • 代码位于 mmcv.runner.base_runner.py
  • 作用:pytorch训练相关代码。
  • 构造函数:
    • 输入参数
      • model
      • batch_processorcallable方法,调用方法是 batch_processor(model, data, train_mode),输出一个字典)
      • optimizer
      • work_dir(保存模型、日志文件)
      • logger
      • meta(字典,包括环境信息和seed等)
    • 执行的操作:
      • 判断输入参数合法性。
      • 将输入数据保存为成员变量。
      • 初始化其他成员变量。
  • 支持的成员变量:model_namerankworld_sizehooksepochiterinner_itermax_epochsmax_iters
  • 抽象方法:trainvalrunsave_checkpoint
  • 实现的功能
    • 获取optimizer中每个param_groups中lr、momentum、betas的数值。
    • 注册hook、定义一些默认hook。
    • resume 模型权重。
  • hooks 相关功能:
    • register_hook 时会根据输入的 priority 获得具体的优先级数值,内部保存hooks时会根据优先级数值进行排序。
    • 定义 hook 的 helper function,用来运行所有hook的某个方法。
    • 定义training中默认用到的六种hook
      • LrUpdaterHook:详见 mmcv.runner.hooks.lr_updater.py
      • MomentumUpdaterHook:optimizer中momentum的更新,详见 mmcv.runner.hooks.momentum_updater.py
      • OptimizerStepperHook:更新参数hook,详见 mmcv.runner.hooks.optimizer.py
      • CheckpointSaverHook:保存模型hook,详见 mmcv.runner.hooks.checkpoint.py
      • IterTimerHook:为logger增加计时功能,在logger中增加了两个参数data_timetime,前者表示数据获取时间,后者表示iter总时间,详见 mmcv.runner.hooks.iter_timer.py
      • LoggerHook(s):详见 mmcv.runner.hooks.logger

2.2. EpochBaseRunner

  • 定义了BaseRunner中 train/val/run/save_checkpoint 四个抽象方法。
  • 训练代码主要功能:
    • 在响应位置调用hooks对应的方法,这里就包含了更新参数、logger、更新学习率、模型保存等功能。
    • 遍历一遍dataloader,分别执行前向过程,得到损失函数。
      • 如果设置了 batch_processor,则通过该函数计算损失函数。
      • 如果没有设置 batch_processor,则通过 model.train_step 获得损失函数。
def train(self, data_loader, **kwargs):
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._max_iters = self._max_epochs * len(data_loader)
    self.call_hook('before_train_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(data_loader):
        self._inner_iter = i
        self.call_hook('before_train_iter')
        if self.batch_processor is None:
            outputs = self.model.train_step(data_batch, self.optimizer,
                                            **kwargs)
        else:
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=True, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('"batch_processor()" or "model.train_step()"'
                            ' must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'],
                                   outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_train_iter')
        self._iter += 1

    self.call_hook('after_train_epoch')
    self._epoch += 1
  • 验证相关代码,主要功能包括:
    • 执行各类hooks。
    • 遍历 dataloader,通过 self.batch_processormodel.val_step 执行前向操作,得到模型输出结果
def val(self, data_loader, **kwargs):
    self.model.eval()
    self.mode = 'val'
    self.data_loader = data_loader
    self.call_hook('before_val_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(data_loader):
        self._inner_iter = i
        self.call_hook('before_val_iter')
        with torch.no_grad():
            if self.batch_processor is None:
                outputs = self.model.val_step(data_batch, self.optimizer,
                                              **kwargs)
            else:
                outputs = self.batch_processor(
                    self.model, data_batch, train_mode=False, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('"batch_processor()" or "model.val_step()"'
                            ' must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'],
                                   outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_val_iter')

    self.call_hook('after_val_epoch')
  • run,包括训练/验证工作
    • 输入参数包括 data_loaders/workflow,两者的长度相同,分别对应。
      • workflow 加入是 [('train', 2), ('val', 1)],则表示train 2 epoch then val 1 epoch,按照这个顺序依次进行训练,作为一个epoch。
      • 后续会根据 workflow 根据 mode 选择对应的 train/val 方法。
def run(self, data_loaders, workflow, max_epochs, **kwargs):
    """Start running.

    Args:
        data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
            and validation.
        workflow (list[tuple]): A list of (phase, epochs) to specify the
            running order and epochs. E.g, [('train', 2), ('val', 1)] means
            running 2 epochs for training and 1 epoch for validation,
            iteratively.
        max_epochs (int): Total training epochs.
    """
    assert isinstance(data_loaders, list)
    assert mmcv.is_list_of(workflow, tuple)
    assert len(data_loaders) == len(workflow)

    self._max_epochs = max_epochs
    for i, flow in enumerate(workflow):
        mode, epochs = flow
        if mode == 'train':
            self._max_iters = self._max_epochs * len(data_loaders[i])
            break

    work_dir = self.work_dir if self.work_dir is not None else 'NONE'
    self.logger.info('Start running, host: %s, work_dir: %s',
                     get_host_info(), work_dir)
    self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
    self.call_hook('before_run')

    while self.epoch < max_epochs:
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if isinstance(mode, str):  # self.train()
                if not hasattr(self, mode):
                    raise ValueError(
                        f'runner has no method named "{mode}" to run an '
                        'epoch')
                epoch_runner = getattr(self, mode)
            else:
                raise TypeError(
                    'mode in workflow must be a str, but got {}'.format(
                        type(mode)))

            for _ in range(epochs):
                if mode == 'train' and self.epoch >= max_epochs:
                    return
                epoch_runner(data_loaders[i], **kwargs)

    time.sleep(1)  # wait for some hooks like loggers to finish
    self.call_hook('after_run')

你可能感兴趣的:(CV)