MMDetection train_detector Source Code Walkthrough

Contents

(1) Building the data loaders (mmdet/datasets/builder.py)

(2) Building the distributed wrapper

(3) Building the optimizer

(4) Creating the EpochBasedRunner and training


(1) Building the data loaders (mmdet/datasets/builder.py)

The main steps are to create a sampler and pass the sampler, the collate function, and the worker_init_fn function to DataLoader to build a PyTorch dataloader.

def build_dataloader(dataset,
                     samples_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     shuffle=True,
                     seed=None,
                     **kwargs):
    # Get the rank of this process and the total number of processes
    rank, world_size = get_dist_info()
    # Distributed training (launched with dist_train.sh) enters this branch.
    if dist:
        # DistributedGroupSampler shuffles and guarantees that the samples on each GPU come from the same group.
        if shuffle:
            sampler = DistributedGroupSampler(dataset, samples_per_gpu,
                                              world_size, rank)
        # No shuffling: use the rewritten DistributedSampler.
        # PyTorch < 1.2 has no shuffle argument,
        # so DistributedSampler is re-implemented for version compatibility.
        else:
            sampler = DistributedSampler(
                dataset, world_size, rank, shuffle=False)
        batch_size = samples_per_gpu
        num_workers = workers_per_gpu
    # Non-distributed training, i.e. launched directly with train.py.
    else:
        sampler = GroupSampler(dataset, samples_per_gpu) if shuffle else None
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu

    init_fn = partial(
        worker_init_fn, num_workers=num_workers, rank=rank,
        seed=seed) if seed is not None else None

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
        pin_memory=False,
        worker_init_fn=init_fn,
        **kwargs)

    return data_loader


def worker_init_fn(worker_id, num_workers, rank, seed):
    # The seed of each worker equals to
    # num_worker * rank + worker_id + user_seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
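
For context, this is roughly how train_detector in mmdet/apis/train.py calls build_dataloader (a simplified sketch; the exact arguments vary between MMDetection versions, and cfg, distributed and dataset come from train_detector itself):

# Simplified sketch of how train_detector builds the training loaders.
data_loaders = [
    build_dataloader(
        ds,
        cfg.data.samples_per_gpu,    # images per GPU
        cfg.data.workers_per_gpu,    # dataloader workers per GPU
        num_gpus=len(cfg.gpu_ids),   # only used in the non-distributed branch
        dist=distributed,
        seed=cfg.seed) for ds in dataset
]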

There are three samplers in total: DistributedGroupSampler, DistributedSampler, and GroupSampler.

DistributedGroupSampler and GroupSampler sample by group (note: MMDetection splits images into two groups according to whether width > height or not, and every image in a per-GPU batch should come from the same group). GroupSampler is used in the non-distributed branch and DistributedGroupSampler in the distributed one.

DistributedSampler is a rewrite of PyTorch's built-in distributed sampler. PyTorch < 1.2 has no shuffle argument, so its DistributedSampler always shuffles and cannot be switched to sequential sampling, whereas PyTorch >= 1.2 supports both. For version compatibility and a uniform interface, MMDetection rewrites DistributedSampler with an explicit shuffle flag; build_dataloader only uses it with shuffle=False.
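
The flag used for grouping is assigned when the dataset is built. Below is a simplified sketch of the idea, based on CustomDataset._set_group_flag in mmdet/datasets/custom.py (the helper name and details here are illustrative):

import numpy as np

def set_group_flag(img_infos):
    # flag[i] = 1 if image i is wider than it is tall, otherwise 0
    flag = np.zeros(len(img_infos), dtype=np.uint8)
    for i, info in enumerate(img_infos):
        if info['width'] / info['height'] > 1:
            flag[i] = 1
    return flag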

① GroupSampler (mmdet/datasets/samplers/group_sampler.py)

from __future__ import division
import math

import numpy as np
import torch
from mmcv.runner import get_dist_info
from torch.utils.data import Sampler


class GroupSampler(Sampler):
    def __init__(self, dataset, samples_per_gpu=1):
        # If an image's width > height, its flag is 1;
        # if its width <= height, its flag is 0.
        # flag is an ndarray recording this for every image in the dataset.
        assert hasattr(dataset, 'flag')
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu
        self.flag = dataset.flag.astype(np.int64)
        # np.bincount counts how many times each value occurs,
        # i.e. how many images have width > height and how many have width <= height
        self.group_sizes = np.bincount(self.flag)
        self.num_samples = 0
        for i, size in enumerate(self.group_sizes):
            # Round each group's size up so that it is divisible by samples_per_gpu
            self.num_samples += int(np.ceil(
                size / self.samples_per_gpu)) * self.samples_per_gpu

    def __iter__(self):
        indices = []
        for i, size in enumerate(self.group_sizes):
            # Skip empty groups (e.g. when no image in the dataset falls into this group).
            if size == 0:
                continue
            # Indices of all images whose flag equals i (0: width <= height, 1: width > height)
            indice = np.where(self.flag == i)[0]
            assert len(indice) == size
            # Shuffle the indices randomly
            np.random.shuffle(indice)
            # A group's size may not be divisible by samples_per_gpu, so pad it with extra samples.
            # num_extra is the number of extra samples to add.
            num_extra = int(np.ceil(size / self.samples_per_gpu)
                            ) * self.samples_per_gpu - len(indice)
            # np.concatenate(list_of_arrays, axis=0)
            # np.random.choice(array, size)
            # Build the complete index list for this group
            indice = np.concatenate(
                [indice, np.random.choice(indice, num_extra)])
            indices.append(indice)
        # Merge the indices of all groups
        indices = np.concatenate(indices)
        # The shuffle below keeps the flag identical within every chunk of samples_per_gpu indices
        indices = [
            indices[i * self.samples_per_gpu:(i + 1) * self.samples_per_gpu]
            for i in np.random.permutation(
                range(len(indices) // self.samples_per_gpu))
        ]
        indices = np.concatenate(indices)
        indices = indices.astype(np.int64).tolist()
        assert len(indices) == self.num_samples
        return iter(indices)

    def __len__(self):
        return self.num_samples
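
A quick way to see the guarantee GroupSampler provides is to run it on a dummy dataset (a hypothetical toy example, not MMDetection code): every consecutive chunk of samples_per_gpu indices shares the same flag.

import numpy as np

class DummyDataset:
    def __init__(self, flags):
        self.flag = np.array(flags)

    def __len__(self):
        return len(self.flag)

ds = DummyDataset([0, 1, 1, 0, 0, 1, 0])
sampler = GroupSampler(ds, samples_per_gpu=2)
indices = list(iter(sampler))
for i in range(0, len(indices), 2):
    chunk = indices[i:i + 2]
    # every chunk of 2 indices comes from a single group
    assert len(set(ds.flag[chunk])) == 1
print(indices)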

② DistributedGroupSampler (mmdet/datasets/samplers/group_sampler.py)

class DistributedGroupSampler(Sampler):
    def __init__(self,
                 dataset,
                 samples_per_gpu=1,
                 num_replicas=None,
                 rank=None):
        # Get rank and world_size (num_replicas)
        _rank, _num_replicas = get_dist_info()
        if num_replicas is None:
            num_replicas = _num_replicas
        if rank is None:
            rank = _rank

        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu

        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0

        assert hasattr(self.dataset, 'flag')
        self.flag = self.dataset.flag
        # Count how many images have width > height and how many have width <= height
        self.group_sizes = np.bincount(self.flag)

        # Number of samples each process draws
        self.num_samples = 0

        for i, j in enumerate(self.group_sizes):
            # group_sizes[i] / samples_per_gpu: how many per-GPU batches this group yields.
            # The line below computes how many samples each replica draws from this group.
            self.num_samples += int(
                math.ceil(self.group_sizes[i] * 1.0 / self.samples_per_gpu /
                          self.num_replicas)) * self.samples_per_gpu
        # Total number of samples drawn across all processes.
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # Use the current epoch as the random seed,
        # so experiments are reproducible within the same epoch
        # while different epochs remain random.
        g = torch.Generator()
        g.manual_seed(self.epoch)

        indices = []
        for i, size in enumerate(self.group_sizes):
            # Only process non-empty groups
            if size > 0:
                # Indices of all images belonging to this group
                indice = np.where(self.flag == i)[0]
                assert len(indice) == size
                # Shuffle the indices randomly
                indice = indice[list(torch.randperm(int(size),
                                                    generator=g))].tolist()
                # Total number of extra samples to pad
                extra = int(
                    math.ceil(
                        size * 1.0 / self.samples_per_gpu / self.num_replicas)
                ) * self.samples_per_gpu * self.num_replicas - len(indice)

                # Pad indice with full copies of the shuffled list
                tmp = indice.copy()
                for _ in range(extra // size):
                    indice.extend(tmp)
                # and append the first (extra % size) shuffled indices.
                indice.extend(tmp[:extra % size])
                indices.extend(indice)

        assert len(indices) == self.total_size

        # Shuffle the order of the samples_per_gpu chunks.
        # The elements within each group were already shuffled above,
        # so only the chunk order needs to be shuffled here.
        indices = [
            indices[j] for i in list(
                torch.randperm(
                    len(indices) // self.samples_per_gpu, generator=g))
            for j in range(i * self.samples_per_gpu, (i + 1) *
                           self.samples_per_gpu)
        ]

        # Take num_samples indices for this process; each process reads its own slice of the shuffled list.
        offset = self.num_samples * self.rank
        indices = indices[offset:offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
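
set_epoch is what makes the shuffling differ between epochs: during distributed training MMDetection registers mmcv's DistSamplerSeedHook, which calls it before every epoch. A hypothetical toy check of the behaviour (reusing the DummyDataset from the GroupSampler example above; rank and num_replicas are passed explicitly so no process group is needed):

ds = DummyDataset([0, 1, 1, 0, 0, 1, 0, 1])
sampler = DistributedGroupSampler(ds, samples_per_gpu=2, num_replicas=2, rank=0)

sampler.set_epoch(0)
order_a = list(iter(sampler))
sampler.set_epoch(0)
order_b = list(iter(sampler))
sampler.set_epoch(1)
order_c = list(iter(sampler))

assert order_a == order_b   # same epoch -> identical, reproducible order
print(order_a != order_c)   # different epoch -> (almost surely) a different shuffle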

③ DistributedSampler (mmdet/datasets/samplers/distributed_sampler.py)

import torch
from torch.utils.data import DistributedSampler as _DistributedSampler


# PyTorch < 1.2 has no shuffle argument, so this re-implementation is used for version compatibility
class DistributedSampler(_DistributedSampler):

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
        self.shuffle = shuffle

    def __iter__(self):
        # Use the current epoch as the random seed,
        # so experiments are reproducible within the same epoch
        # while different epochs remain random.
        if self.shuffle:
            # Seed a random generator with the current epoch.
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # Add extra samples to make the index list evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
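
The slice indices[self.rank:self.total_size:self.num_replicas] is what splits the padded index list across processes. For example (hypothetical numbers):

# 10 samples, 4 processes -> total_size = 12 after padding.
indices = list(range(10))
indices += indices[:12 - len(indices)]   # pad to [0..9, 0, 1]
for rank in range(4):
    print(rank, indices[rank:12:4])
# rank 0 -> [0, 4, 8]
# rank 1 -> [1, 5, 9]
# rank 2 -> [2, 6, 0]
# rank 3 -> [3, 7, 1]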

The collate function organizes and merges the data of each batch. It is passed to the DataLoader's collate_fn argument.

from collections.abc import Mapping, Sequence

import torch
import torch.nn.functional as F
from torch.utils.data.dataloader import default_collate

from .data_container import DataContainer


def collate(batch, samples_per_gpu=1):
    """Puts each data field into a tensor/DataContainer with outer dimension
    batch size.

    Extend default_collate to add support for
    :type:`~mmcv.parallel.DataContainer`. There are 3 cases.

    1. cpu_only = True, e.g., meta data
    2. cpu_only = False, stack = True, e.g., images tensors
    3. cpu_only = False, stack = False, e.g., gt bboxes
    """
    # batch is a list of length batch_size; each element is a dict describing one image.
    # Its keys are: dict_keys(['img_metas', 'img', 'gt_bboxes', 'gt_labels'])

    # Make sure batch is a sequence
    if not isinstance(batch, Sequence):
        raise TypeError(f'{batch.dtype} is not supported.')

    if isinstance(batch[0], DataContainer):
        assert len(batch) % samples_per_gpu == 0
        stacked = []

        # cpu_only means this is meta data
        if batch[0].cpu_only:
            # batch[0].stack:           False
            # batch[0].padding_value:   0
            for i in range(0, len(batch), samples_per_gpu):
                # Group every samples_per_gpu items into a list
                stacked.append(
                    [sample.data for sample in batch[i:i + samples_per_gpu]])
            # Wrap the result in a DataContainer
            return DataContainer(
                stacked, batch[0].stack, batch[0].padding_value, cpu_only=True)
        # stack=True means image-like tensor data (e.g. the img field)
        elif batch[0].stack:
            for i in range(0, len(batch), samples_per_gpu):
                assert isinstance(batch[i].data, torch.Tensor)
                # Pad the last pad_dims dimensions
                if batch[i].pad_dims is not None:
                    ndim = batch[i].dim()
                    assert ndim > batch[i].pad_dims
                    max_shape = [0 for _ in range(batch[i].pad_dims)]
                    for dim in range(1, batch[i].pad_dims + 1):
                        max_shape[dim - 1] = batch[i].size(-dim)
                    for sample in batch[i:i + samples_per_gpu]:
                        for dim in range(0, ndim - batch[i].pad_dims):
                            assert batch[i].size(dim) == sample.size(dim)
                        for dim in range(1, batch[i].pad_dims + 1):
                            max_shape[dim - 1] = max(max_shape[dim - 1],
                                                     sample.size(-dim))
                    padded_samples = []
                    for sample in batch[i:i + samples_per_gpu]:
                        pad = [0 for _ in range(batch[i].pad_dims * 2)]
                        for dim in range(1, batch[i].pad_dims + 1):
                            pad[2 * dim -
                                1] = max_shape[dim - 1] - sample.size(-dim)
                        padded_samples.append(
                            F.pad(
                                sample.data, pad, value=sample.padding_value))
                    stacked.append(default_collate(padded_samples))
                # No padding needed
                elif batch[i].pad_dims is None:
                    stacked.append(
                        default_collate([
                            sample.data
                            for sample in batch[i:i + samples_per_gpu]
                        ]))
                else:
                    raise ValueError(
                        'pad_dims should be either None or integers (1-3)')
        # stack=False, e.g. gt bboxes
        else:
            # Group every samples_per_gpu items into a list
            for i in range(0, len(batch), samples_per_gpu):
                stacked.append(
                    [sample.data for sample in batch[i:i + samples_per_gpu]])
        return DataContainer(stacked, batch[0].stack, batch[0].padding_value)
    # The elements are sequences
    elif isinstance(batch[0], Sequence):
        transposed = zip(*batch)
        return [collate(samples, samples_per_gpu) for samples in transposed]
    
    # The outermost call receives a list of dicts holding the image metas, image, gt boxes, labels, etc.,
    # so it enters the branch below first.
    elif isinstance(batch[0], Mapping):
        # Return a dict: for each key, collate the values of that key across the batch
        return {
            key: collate([d[key] for d in batch], samples_per_gpu)
            # iterate over every key
            for key in batch[0]
        }
    # Fall back to the default collate
    else:
        return default_collate(batch)
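
A small demonstration of the three cases (a sketch that assumes mmcv is installed; the shapes are made up): fields with stack=True are padded to a common size and stacked, stack=False fields stay as lists, and cpu_only fields are kept as-is.

import torch
from mmcv.parallel import collate, DataContainer as DC

batch = [
    dict(img=DC(torch.rand(3, 800, 1000), stack=True),
         gt_bboxes=DC(torch.rand(2, 4)),
         img_metas=DC({'filename': 'a.jpg'}, cpu_only=True)),
    dict(img=DC(torch.rand(3, 750, 1100), stack=True),
         gt_bboxes=DC(torch.rand(3, 4)),
         img_metas=DC({'filename': 'b.jpg'}, cpu_only=True)),
]
out = collate(batch, samples_per_gpu=2)
print(out['img'].data[0].shape)                      # torch.Size([2, 3, 800, 1100]): padded then stacked
print([b.shape for b in out['gt_bboxes'].data[0]])   # [torch.Size([2, 4]), torch.Size([3, 4])]: kept as a list
print(out['img_metas'].data[0])                      # the two meta dicts, untouched (cpu_only)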

By default, the random seed inside each DataLoader worker is not fixed across runs. For reproducibility, a worker_init_fn function is created and passed to the DataLoader's worker_init_fn argument, which is the per-worker initialization function. Using num_workers * rank + worker_id + user_seed as the seed gives every worker a deterministic seed.

def worker_init_fn(worker_id, num_workers, rank, seed):
    # Use num_workers * rank + worker_id + user_seed as the random seed,
    # which makes the per-worker seeds deterministic
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)

(2) Building the distributed wrapper

MMDetection wraps PyTorch's DistributedDataParallel and DataParallel with a thin outer layer that overrides the scatter method and additionally implements train_step and val_step. scatter distributes the data to the specified GPUs; given a batch of data, train_step and val_step call the detector's own train_step or val_step to compute the losses or the model outputs. (Every detector in MMDetection implements train_step and val_step, so no loss function has to be passed in at training time: different models can use different losses, and the same model can switch loss functions. This makes the design more flexible.)
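
For reference, a detector's train_step looks roughly like this (simplified from BaseDetector in mmdet/models/detectors/base.py; the real code has extra bookkeeping): it runs the forward pass, reduces the loss dict to a single loss plus logging scalars, and returns exactly what the runner and logger hooks expect.

# Simplified sketch of BaseDetector.train_step.
def train_step(self, data, optimizer):
    losses = self(**data)                        # forward() returns a dict of losses
    loss, log_vars = self._parse_losses(losses)  # sum the losses, keep scalars for logging
    return dict(loss=loss,
                log_vars=log_vars,
                num_samples=len(data['img_metas']))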

With multiple GPUs on one machine, the model is wrapped with MMDistributedDataParallel; with a single GPU, it is wrapped with MMDataParallel.
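
This is roughly how train_detector wraps the model (a simplified sketch of mmdet/apis/train.py; cfg and distributed come from that function):

find_unused_parameters = cfg.get('find_unused_parameters', False)
if distributed:
    model = MMDistributedDataParallel(
        model.cuda(),
        device_ids=[torch.cuda.current_device()],
        broadcast_buffers=False,
        find_unused_parameters=find_unused_parameters)
else:
    model = MMDataParallel(
        model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)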

① MMDistributedDataParallel

# Copyright (c) Open-MMLab. All rights reserved.
import torch
from torch.nn.parallel.distributed import (DistributedDataParallel,
                                           _find_tensors)

from mmcv.utils import TORCH_VERSION
from .scatter_gather import scatter_kwargs


class MMDistributedDataParallel(DistributedDataParallel):
    """The DDP module that supports DataContainer.

    MMDDP has two main differences with PyTorch DDP:

    - It supports a custom type :class:`DataContainer` which allows more
      flexible control of input data.
    - It implement two APIs ``train_step()`` and ``val_step()``.
    """

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def train_step(self, *inputs, **kwargs):
        """train_step() API for module wrapped by DistributedDataParallel.

        This method is basically the same as
        ``DistributedDataParallel.forward()``, while replacing
        ``self.module.forward()`` with ``self.module.train_step()``.
        It is compatible with PyTorch 1.1 - 1.5.
        """
        if getattr(self, 'require_forward_param_sync', True):
            self._sync_params()
        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                output = self.module.train_step(*inputs[0], **kwargs[0])
            else:
                outputs = self.parallel_apply(
                    self._module_copies[:len(inputs)], inputs, kwargs)
                output = self.gather(outputs, self.output_device)
        else:
            output = self.module.train_step(*inputs, **kwargs)

        if torch.is_grad_enabled() and getattr(
                self, 'require_backward_grad_sync', True):
            if self.find_unused_parameters:
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        else:
            if TORCH_VERSION > '1.2':
                self.require_forward_param_sync = False
        return output

    def val_step(self, *inputs, **kwargs):
        """val_step() API for module wrapped by DistributedDataParallel.

        This method is basically the same as
        ``DistributedDataParallel.forward()``, while replacing
        ``self.module.forward()`` with ``self.module.val_step()``.
        It is compatible with PyTorch 1.1 - 1.5.
        """
        if getattr(self, 'require_forward_param_sync', True):
            self._sync_params()
        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                output = self.module.val_step(*inputs[0], **kwargs[0])
            else:
                outputs = self.parallel_apply(
                    self._module_copies[:len(inputs)], inputs, kwargs)
                output = self.gather(outputs, self.output_device)
        else:
            output = self.module.val_step(*inputs, **kwargs)

        if torch.is_grad_enabled() and getattr(
                self, 'require_backward_grad_sync', True):
            if self.find_unused_parameters:
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        else:
            if TORCH_VERSION > '1.2':
                self.require_forward_param_sync = False
        return output

② MMDataParallel

# Copyright (c) Open-MMLab. All rights reserved.
from itertools import chain

from torch.nn.parallel import DataParallel

from .scatter_gather import scatter_kwargs


class MMDataParallel(DataParallel):

    def scatter(self, inputs, kwargs, device_ids):
        """Scatter the inputs to the specified GPU devices."""
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def train_step(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module.train_step(*inputs, **kwargs)

        assert len(self.device_ids) == 1, \
            ('MMDataParallel only supports single GPU training, if you need to'
             ' train with multiple GPUs, please use MMDistributedDataParallel'
             'instead.')

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError(
                    'module must have its parameters and buffers '
                    f'on device {self.src_device_obj} (device_ids[0]) but '
                    f'found one of them on device: {t.device}')

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        return self.module.train_step(*inputs[0], **kwargs[0])

    def val_step(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module.val_step(*inputs, **kwargs)

        assert len(self.device_ids) == 1, \
            ('MMDataParallel only supports single GPU training, if you need to'
             ' train with multiple GPUs, please use MMDistributedDataParallel'
             'instead.')

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError(
                    'module must have its parameters and buffers '
                    f'on device {self.src_device_obj} (device_ids[0]) but '
                    f'found one of them on device: {t.device}')

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        return self.module.val_step(*inputs[0], **kwargs[0])

(3) Building the optimizer

The optimizer is built with the build_optimizer function; as we can see, under the hood it also calls build_from_cfg.

import copy
import inspect

import torch

from ...utils import Registry, build_from_cfg

OPTIMIZERS = Registry('optimizer')
OPTIMIZER_BUILDERS = Registry('optimizer builder')


def register_torch_optimizers():
    torch_optimizers = []
    for module_name in dir(torch.optim):
        if module_name.startswith('__'):
            continue
        _optim = getattr(torch.optim, module_name)
        if inspect.isclass(_optim) and issubclass(_optim,
                                                  torch.optim.Optimizer):
            OPTIMIZERS.register_module()(_optim)
            torch_optimizers.append(module_name)
    return torch_optimizers


TORCH_OPTIMIZERS = register_torch_optimizers()


def build_optimizer_constructor(cfg):
    return build_from_cfg(cfg, OPTIMIZER_BUILDERS)


def build_optimizer(model, cfg):
    optimizer_cfg = copy.deepcopy(cfg)
    constructor_type = optimizer_cfg.pop('constructor',
                                         'DefaultOptimizerConstructor')
    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
    optim_constructor = build_optimizer_constructor(
        dict(
            type=constructor_type,
            optimizer_cfg=optimizer_cfg,
            paramwise_cfg=paramwise_cfg))
    optimizer = optim_constructor(model)
    return optimizer
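
A typical optimizer config and the resulting call look like this (a sketch assuming a recent mmcv that exports build_optimizer from mmcv.runner): 'type' selects the registered torch.optim class, the remaining keys become its constructor arguments, and the optional 'constructor' / 'paramwise_cfg' keys control how parameter groups are built.

import torch.nn as nn
from mmcv.runner import build_optimizer

model = nn.Conv2d(3, 16, 3)
optimizer_cfg = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer = build_optimizer(model, optimizer_cfg)
print(optimizer)   # a torch.optim.SGD over model.parameters()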

(4) Creating the EpochBasedRunner and training

EpochBasedRunner inherits from BaseRunner. BaseRunner provides the common attributes and methods, such as querying training state (epoch count, iteration count, etc.), registering hooks, and inspecting hooks. It also declares four abstract methods that subclasses must implement: train, val, run, and save_checkpoint.

EpochBasedRunner subclasses BaseRunner and implements the train, val, run, and save_checkpoint methods.

Training is started by calling the run method with the dataloaders, the workflow, the maximum number of epochs, and so on. At each stage (before run, before each epoch, etc.) all hooks registered for that stage are called, which makes the runner highly customizable.
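
Putting it together, train_detector drives the runner roughly like this (a simplified sketch; hook arguments are abbreviated and cfg, logger, meta, data_loaders come from train_detector):

from mmcv.runner import DistSamplerSeedHook, EpochBasedRunner

runner = EpochBasedRunner(
    model, optimizer=optimizer, work_dir=cfg.work_dir, logger=logger, meta=meta)

# Register the standard hooks: lr schedule, optimizer step, checkpointing, logging.
runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                               cfg.checkpoint_config, cfg.log_config)
if distributed:
    runner.register_hook(DistSamplerSeedHook())  # calls sampler.set_epoch before each epoch

# A workflow such as [('train', 1)] runs one training epoch per cycle.
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)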

# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import time
import warnings

import torch

import mmcv
from .base_runner import BaseRunner
from .checkpoint import save_checkpoint
from .utils import get_host_info


class EpochBasedRunner(BaseRunner):
    """Epoch-based Runner.

    This runner train models epoch by epoch.
    """

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            if self.batch_processor is None:
                outputs = self.model.train_step(data_batch, self.optimizer,
                                                **kwargs)
            else:
                outputs = self.batch_processor(
                    self.model, data_batch, train_mode=True, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('"batch_processor()" or "model.train_step()"'
                                ' must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            self.outputs = outputs
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            with torch.no_grad():
                if self.batch_processor is None:
                    outputs = self.model.val_step(data_batch, self.optimizer,
                                                  **kwargs)
                else:
                    outputs = self.batch_processor(
                        self.model, data_batch, train_mode=False, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('"batch_processor()" or "model.val_step()"'
                                ' must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            self.outputs = outputs
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow, max_epochs, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
            max_epochs (int): Total training epochs.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        self._max_epochs = max_epochs
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
        self.call_hook('before_run')

        while self.epoch < max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))

                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= max_epochs:
                        return
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        """Save the checkpoint.

        Args:
            out_dir (str): The directory that checkpoints are saved.
            filename_tmpl (str, optional): The checkpoint filename template,
                which contains a placeholder for the epoch number.
                Defaults to 'epoch_{}.pth'.
            save_optimizer (bool, optional): Whether to save the optimizer to
                the checkpoint. Defaults to True.
            meta (dict, optional): The meta information to be saved in the
                checkpoint. Defaults to None.
            create_symlink (bool, optional): Whether to create a symlink
                "latest.pth" to point to the latest checkpoint.
                Defaults to True.
        """
        if meta is None:
            meta = dict(epoch=self.epoch + 1, iter=self.iter)
        else:
            meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # in some environments, `os.symlink` is not supported, you may need to
        # set `create_symlink` to False
        if create_symlink:
            mmcv.symlink(filename, osp.join(out_dir, 'latest.pth'))


class Runner(EpochBasedRunner):
    """Deprecated name of EpochBasedRunner."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            'Runner was deprecated, please use EpochBasedRunner instead')
        super().__init__(*args, **kwargs)
