(17) mmdetection Source Code Walkthrough: EpochBasedRunner

Table of Contents

  • 1. run
  • 2. train
  • 3. val
  • 4. run_iter
  • 5. save_checkpoint

1. run

runner.run(data_loaders, cfg.workflow)
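For context, this call is issued by mmdetection's train_detector (mmdet/apis/train.py in mmdetection 2.x). A simplified excerpt (build_runner comes from mmcv.runner; hook registration is omitted):

runner = build_runner(
    cfg.runner,
    default_args=dict(
        model=model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta))
# ... training hooks are registered here ...
runner.run(data_loaders, cfg.workflow)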

The workflow only truly starts once the run method is called.
workflow = [('train', 1)] means only the training workflow is run.
workflow = [('train', 2), ('val', 1)] means train for 2 epochs, then switch to the val workflow for 1 epoch, and repeat this cycle until the number of training epochs reaches the configured maximum.
workflow = [('val', 1), ('train', 1)] means validate for 1 epoch first, then train for 1 epoch.
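In an mmdetection config these settings typically look like the following (a minimal sketch; the values are illustrative):

workflow = [('train', 1)]  # only the training workflow
# workflow = [('train', 2), ('val', 1)]  # alternate 2 train epochs with 1 val epoch
runner = dict(type='EpochBasedRunner', max_epochs=12)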

The run method only defines the generic workflow-switching logic; the actual work of an epoch is done by calling the corresponding workflow method. Currently the train and val workflows are supported, so epoch_runner(data_loaders[i], **kwargs) ends up calling either the train or the val method:

def run(self, data_loaders, workflow, max_epochs=None, **kwargs):

    assert isinstance(data_loaders, list)
    assert mmcv.is_list_of(workflow, tuple)
    assert len(data_loaders) == len(workflow)
    if max_epochs is not None:
        warnings.warn(
            'setting max_epochs in run is deprecated, '
            'please set max_epochs in runner_config', DeprecationWarning)
        self._max_epochs = max_epochs

    assert self._max_epochs is not None, (
        'max_epochs must be specified during instantiation')

    for i, flow in enumerate(workflow):
        mode, epochs = flow
        if mode == 'train':
            self._max_iters = self._max_epochs * len(data_loaders[i])
            break

    work_dir = self.work_dir if self.work_dir is not None else 'NONE'
    self.logger.info('Start running, host: %s, work_dir: %s',
                     get_host_info(), work_dir)
    self.logger.info('Hooks will be executed in the following order:\n%s',
                     self.get_hook_info())
    self.logger.info('workflow: %s, max: %d epochs', workflow,
                     self._max_epochs)
    self.call_hook('before_run')

    while self.epoch < self._max_epochs:
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if isinstance(mode, str):  # self.train()
                if not hasattr(self, mode):
                    raise ValueError(
                        f'runner has no method named "{mode}" to run an '
                        'epoch')
                epoch_runner = getattr(self, mode)
            else:
                raise TypeError(
                    'mode in workflow must be a str, but got {}'.format(
                        type(mode)))

            for _ in range(epochs):
                if mode == 'train' and self.epoch >= self._max_epochs:
                    break
                epoch_runner(data_loaders[i], **kwargs)

    time.sleep(1)  # wait for some hooks like loggers to finish
    self.call_hook('after_run')
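As a concrete example, with max_epochs = 4 and workflow = [('train', 2), ('val', 1)], the outer while loop produces the sequence train, train, val, train, train, val: only the train workflow advances self.epoch, so the loop exits once four training epochs have completed, and the val workflow runs once at the end of each cycle.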

2. train

Iterate over the data_loader and train batch by batch. Each batch is actually trained by the call self.run_iter(data_batch, train_mode=True, **kwargs), and four kinds of hooks are triggered along the way (a minimal custom-hook sketch follows the code below):
self.call_hook('before_train_epoch')
self.call_hook('before_train_iter')
self.call_hook('after_train_iter')
self.call_hook('after_train_epoch')

def train(self, data_loader, **kwargs):
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._max_iters = self._max_epochs * len(self.data_loader)
    self.call_hook('before_train_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_train_iter')
        self.run_iter(data_batch, train_mode=True, **kwargs)
        self.call_hook('after_train_iter')
        self._iter += 1

    self.call_hook('after_train_epoch')
    self._epoch += 1
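These call_hook points are what custom hooks plug into. A minimal sketch, assuming mmcv's Hook base class and HOOKS registry (the hook name and printed messages are purely illustrative):

from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class PrintProgressHook(Hook):
    """Toy hook that reacts to the call_hook points shown above."""

    def before_train_epoch(self, runner):
        print(f'starting epoch {runner.epoch + 1}/{runner.max_epochs}')

    def after_train_iter(self, runner):
        # runner.outputs is set by run_iter(), see section 4
        if self.every_n_iters(runner, 50):
            print(f'iter {runner.iter}: {runner.outputs.get("log_vars", {})}')

Such a hook would then be enabled in the config with custom_hooks = [dict(type='PrintProgressHook')].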

3. val

Iterate over the data_loader and validate batch by batch. Each batch is actually validated by the call self.run_iter(data_batch, train_mode=False), and four kinds of hooks are triggered:
self.call_hook('before_val_epoch')
self.call_hook('before_val_iter')
self.call_hook('after_val_iter')
self.call_hook('after_val_epoch')

@torch.no_grad()
def val(self, data_loader, **kwargs):
    self.model.eval()
    self.mode = 'val'
    self.data_loader = data_loader
    self.call_hook('before_val_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_val_iter')
        self.run_iter(data_batch, train_mode=False)
        self.call_hook('after_val_iter')

    self.call_hook('after_val_epoch')
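Note that val runs under @torch.no_grad() with the model switched to eval mode, and unlike train it increments neither self._iter nor self._epoch; only the train workflow advances the counters that the while loop in run compares against max_epochs.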

4. run_iter

As the train and val methods above show, each batch of training or validation ultimately goes through self.run_iter(), which calls the model's own train_step or val_step method (or a user-supplied batch_processor, if one is set):

def run_iter(self, data_batch, train_mode, **kwargs):
    if self.batch_processor is not None:
        outputs = self.batch_processor(
            self.model, data_batch, train_mode=train_mode, **kwargs)
    elif train_mode:
        outputs = self.model.train_step(data_batch, self.optimizer,
                                        **kwargs)
    else:
        outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
    if not isinstance(outputs, dict):
        raise TypeError('"batch_processor()" or "model.train_step()"'
                        'and "model.val_step()" must return a dict')
    if 'log_vars' in outputs:
        self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
    self.outputs = outputs
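For context on what run_iter expects back, mmdetection's detectors implement train_step roughly as follows (a simplified sketch based on mmdet/models/detectors/base.py; details omitted):

def train_step(self, data, optimizer):
    # forward pass returns a dict of per-branch losses
    losses = self(**data)
    # sum the losses and collect scalar values for logging
    loss, log_vars = self._parse_losses(losses)
    # note: the optimizer argument is not used here; the actual
    # optimizer step is performed by OptimizerHook in after_train_iter
    outputs = dict(
        loss=loss,
        log_vars=log_vars,
        num_samples=len(data['img_metas']))
    return outputs

This is also why run_iter only stores outputs and updates the log buffer: backpropagation and the optimizer step happen in the after_train_iter hook.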

5. save_checkpoint

Saves the model weights (and optionally the optimizer state) to disk.

def save_checkpoint(self,
                    out_dir,
                    filename_tmpl='epoch_{}.pth',
                    save_optimizer=True,
                    meta=None,
                    create_symlink=True):
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(
            f'meta should be a dict or None, but got {type(meta)}')
    if self.meta is not None:
        meta.update(self.meta)
        # Note: meta.update(self.meta) should be done before
        # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
        # there will be problems with resumed checkpoints.
        # More details in https://github.com/open-mmlab/mmcv/pull/1108
    meta.update(epoch=self.epoch + 1, iter=self.iter)

    filename = filename_tmpl.format(self.epoch + 1)
    filepath = osp.join(out_dir, filename)
    optimizer = self.optimizer if save_optimizer else None
    save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
    # in some environments, `os.symlink` is not supported, you may need to
    # set `create_symlink` to False
    if create_symlink:
        dst_file = osp.join(out_dir, 'latest.pth')
        if platform.system() != 'Windows':
            mmcv.symlink(filename, dst_file)
        else:
            shutil.copy(filepath, dst_file)
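In practice this method is rarely called by hand: CheckpointHook invokes runner.save_checkpoint() from after_train_epoch according to the interval configured in checkpoint_config (e.g. checkpoint_config = dict(interval=1)), producing epoch_N.pth files plus a latest.pth symlink (or a copy on Windows).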
