Contents
(1) Building data loaders (mmdet/datasets/builder.py)
(2) Building the distributed parallel wrappers
(3) Building the optimizer
(4) Creating the EpochBasedRunner and training
(1) Building data loaders (mmdet/datasets/builder.py)
The main steps are to create a sampler and then pass the sampler, the collate function, and the worker_init_fn function to DataLoader in order to create a PyTorch DataLoader. A short usage example follows the listing.
def build_dataloader(dataset,
samples_per_gpu,
workers_per_gpu,
num_gpus=1,
dist=True,
shuffle=True,
seed=None,
**kwargs):
# Get the rank of the current process and the total number of processes
rank, world_size = get_dist_info()
# If training is distributed (i.e. launched via dist_train.sh), this branch is taken.
if dist:
# DistributedGroupSampler shuffles the data and guarantees that the samples on each GPU come from the same group.
if shuffle:
sampler = DistributedGroupSampler(dataset, samples_per_gpu,
world_size, rank)
# Without shuffling, use a DistributedSampler in the style of torch.utils.data.
# Because PyTorch < 1.2 has no shuffle argument,
# a DistributedSampler is re-implemented here for version compatibility.
else:
sampler = DistributedSampler(
dataset, world_size, rank, shuffle=False)
batch_size = samples_per_gpu
num_workers = workers_per_gpu
# Non-distributed training, i.e. training launched directly with train.py.
else:
sampler = GroupSampler(dataset, samples_per_gpu) if shuffle else None
batch_size = num_gpus * samples_per_gpu
num_workers = num_gpus * workers_per_gpu
init_fn = partial(
worker_init_fn, num_workers=num_workers, rank=rank,
seed=seed) if seed is not None else None
data_loader = DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
pin_memory=False,
worker_init_fn=init_fn,
**kwargs)
return data_loader
def worker_init_fn(worker_id, num_workers, rank, seed):
# The seed of each worker equals to
# num_worker * rank + worker_id + user_seed
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
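For reference, here is a minimal sketch of how build_dataloader is typically called; the dataset object and the concrete numbers are hypothetical, and in MMDetection they come from the config (cfg.data.train, cfg.data.samples_per_gpu, cfg.data.workers_per_gpu).
# Hypothetical call to build_dataloader (single-GPU, non-distributed case).
data_loader = build_dataloader(
    dataset,               # any dataset that exposes a `flag` attribute
    samples_per_gpu=2,     # batch size per GPU
    workers_per_gpu=2,     # dataloader worker processes per GPU
    num_gpus=1,
    dist=False,            # True when launched via dist_train.sh
    shuffle=True,
    seed=42)
for data_batch in data_loader:
    # each data_batch is a dict produced by `collate`, e.g. with keys
    # 'img_metas', 'img', 'gt_bboxes', 'gt_labels'
    break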
There are three samplers in total: DistributedGroupSampler, DistributedSampler, and GroupSampler.
GroupSampler and DistributedGroupSampler perform group-aware sampling, DistributedGroupSampler being the distributed variant (note: MMDetection puts images with width greater than height and images with width less than height into two groups, and all images on one GPU should come from the same group).
DistributedSampler re-implements PyTorch's built-in distributed sampler, because PyTorch < 1.2 has no shuffle argument; in other words, the DistributedSampler of PyTorch < 1.2 only supports sequential distributed sampling, while PyTorch >= 1.2 supports both sequential and shuffled sampling. For version compatibility and a unified interface, a DistributedSampler is re-implemented here and used with shuffle always set to False.
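To make the grouping concrete, the flag is computed from the image aspect ratio roughly as follows (a simplified sketch of CustomDataset._set_group_flag; img_infos is a hypothetical list here):
import numpy as np

# flag = 1 when width / height > 1 (landscape), otherwise 0 (portrait).
img_infos = [dict(width=1333, height=800), dict(width=600, height=900)]
flag = np.zeros(len(img_infos), dtype=np.uint8)
for i, info in enumerate(img_infos):
    if info['width'] / info['height'] > 1:
        flag[i] = 1
# flag == array([1, 0]); np.bincount(flag) == array([1, 1])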
① GroupSampler (mmdet/datasets/samplers/group_sampler.py)
from __future__ import division
import math
import numpy as np
import torch
from mmcv.runner import get_dist_info
from torch.utils.data import Sampler
class GroupSampler(Sampler):
def __init__(self, dataset, samples_per_gpu=1):
# If an image's width > height, its flag is 1;
# if width < height, its flag is 0.
# flag is an ndarray recording this for every image in the dataset.
assert hasattr(dataset, 'flag')
self.dataset = dataset
self.samples_per_gpu = samples_per_gpu
self.flag = dataset.flag.astype(np.int64)
# np.bincount counts how many times each value occurs,
# i.e. how many images have width > height and how many have width < height.
self.group_sizes = np.bincount(self.flag)
self.num_samples = 0
for i, size in enumerate(self.group_sizes):
# Round each group up so its sample count is divisible by samples_per_gpu.
self.num_samples += int(np.ceil(
size / self.samples_per_gpu)) * self.samples_per_gpu
def __iter__(self):
indices = []
for i, size in enumerate(self.group_sizes):
# Skip this group if it is empty (e.g. every image in the dataset has width < height).
if size == 0:
continue
# Find the indices of all images with width < height (i == 0) or width > height (i == 1).
indice = np.where(self.flag == i)[0]
assert len(indice) == size
# Shuffle the indices randomly.
np.random.shuffle(indice)
# The group size may not be divisible by samples_per_gpu, so extra samples are appended.
# num_extra is the number of extra samples to add.
num_extra = int(np.ceil(size / self.samples_per_gpu)
) * self.samples_per_gpu - len(indice)
# np.concatenate(list of arrays, axis=0)
# np.random.choice(array, size)
# Build the full index list for this group.
indice = np.concatenate(
[indice, np.random.choice(indice, num_extra)])
indices.append(indice)
# Merge the indices of all groups.
indices = np.concatenate(indices)
# The shuffle below keeps the flag identical within every chunk of samples_per_gpu indices.
indices = [
indices[i * self.samples_per_gpu:(i + 1) * self.samples_per_gpu]
for i in np.random.permutation(
range(len(indices) // self.samples_per_gpu))
]
indices = np.concatenate(indices)
indices = indices.astype(np.int64).tolist()
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
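A quick sketch of GroupSampler in isolation; DummyDataset below is a hypothetical stand-in that only provides the flag attribute the sampler needs:
import numpy as np

class DummyDataset:
    """Hypothetical dataset exposing only the `flag` attribute."""

    def __init__(self, flag):
        self.flag = np.array(flag, dtype=np.uint8)

    def __len__(self):
        return len(self.flag)

dataset = DummyDataset([1, 1, 1, 0, 0])  # 3 landscape images, 2 portrait images
sampler = GroupSampler(dataset, samples_per_gpu=2)
indices = list(iter(sampler))
# len(indices) == 6: each group is padded up to a multiple of samples_per_gpu
# (3 -> 4, 2 -> 2), and the two indices in every chunk share the same flag.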
② DistributedGroupSampler (mmdet/datasets/samplers/group_sampler.py)
class DistributedGroupSampler(Sampler):
def __init__(self,
dataset,
samples_per_gpu=1,
num_replicas=None,
rank=None):
# Get rank and world_size (num_replicas).
_rank, _num_replicas = get_dist_info()
if num_replicas is None:
num_replicas = _num_replicas
if rank is None:
rank = _rank
self.dataset = dataset
self.samples_per_gpu = samples_per_gpu
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
assert hasattr(self.dataset, 'flag')
self.flag = self.dataset.flag
# Count how many images have width > height and how many have width < height.
self.group_sizes = np.bincount(self.flag)
# Number of samples each process needs to draw.
self.num_samples = 0
for i, j in enumerate(self.group_sizes):
# self.group_sizes[i] / self.samples_per_gpu: how many chunks this group splits into.
# The expression below computes how many samples each process gets for this group.
self.num_samples += int(
math.ceil(self.group_sizes[i] * 1.0 / self.samples_per_gpu /
self.num_replicas)) * self.samples_per_gpu
# Total number of samples drawn across all processes.
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
# Use the current epoch as the random seed,
# so that runs with the same epoch are reproducible
# while different epochs remain random.
g = torch.Generator()
g.manual_seed(self.epoch)
indices = []
for i, size in enumerate(self.group_sizes):
# If this group is non-empty
if size > 0:
# Find the indices of all images belonging to this group.
indice = np.where(self.flag == i)[0]
assert len(indice) == size
# Shuffle the indices randomly.
indice = indice[list(torch.randperm(int(size),
generator=g))].tolist()
# Total number of extra samples that have to be appended.
extra = int(
math.ceil(
size * 1.0 / self.samples_per_gpu / self.num_replicas)
) * self.samples_per_gpu * self.num_replicas - len(indice)
# Pad indice with the extra samples.
tmp = indice.copy()
for _ in range(extra // size):
indice.extend(tmp)
# Take the first (extra % size) shuffled indices as the remaining extras.
indice.extend(tmp[:extra % size])
indices.extend(indice)
assert len(indices) == self.total_size
# Shuffle the order of the samples_per_gpu chunks;
# the elements within each group were already shuffled above,
# so only the chunk order needs to be shuffled here.
indices = [
indices[j] for i in list(
torch.randperm(
len(indices) // self.samples_per_gpu, generator=g))
for j in range(i * self.samples_per_gpu, (i + 1) *
self.samples_per_gpu)
]
# Each process takes num_samples indices; the processes take consecutive slices of the shuffled index list.
offset = self.num_samples * self.rank
indices = indices[offset:offset + self.num_samples]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
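Usage-wise, each process builds a sampler with its own rank, and the runner calls set_epoch once per epoch (in mmcv this is done by DistSamplerSeedHook), so all ranks draw from the same epoch-specific permutation. A toy two-process sketch, reusing the hypothetical DummyDataset from the GroupSampler example:
# Hypothetical 2-GPU setup: in practice each rank only builds its own sampler.
sampler0 = DistributedGroupSampler(dataset, samples_per_gpu=2, num_replicas=2, rank=0)
sampler1 = DistributedGroupSampler(dataset, samples_per_gpu=2, num_replicas=2, rank=1)
for epoch in range(2):
    sampler0.set_epoch(epoch)   # re-seeds the generator with the epoch number
    sampler1.set_epoch(epoch)
    idx0, idx1 = list(iter(sampler0)), list(iter(sampler1))
    # Both ranks see the same shuffled index list but take different slices,
    # each of length num_samples.
    assert len(idx0) == len(idx1) == sampler0.num_samples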
③ DistributedSampler (mmdet/datasets/samplers/distributed_sampler.py)
import torch
from torch.utils.data import DistributedSampler as _DistributedSampler
# PyTorch < 1.2 has no shuffle argument; the sampler is re-implemented here for version compatibility.
class DistributedSampler(_DistributedSampler):
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
super().__init__(dataset, num_replicas=num_replicas, rank=rank)
self.shuffle = shuffle
def __iter__(self):
# Use the current epoch as the random seed,
# so that runs with the same epoch are reproducible
# while different epochs remain random.
if self.shuffle:
# Use a random number generator seeded with the current epoch.
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = torch.arange(len(self.dataset)).tolist()
# Add extra samples so the total is evenly divisible.
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
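The final subsampling line simply interleaves the shuffled indices across ranks; a toy illustration with two replicas (not MMDetection code):
# indices already shuffled and padded to total_size
indices = [4, 0, 3, 1, 5, 2]
num_replicas, num_samples = 2, 3
rank0 = indices[0:len(indices):num_replicas]  # [4, 3, 5]
rank1 = indices[1:len(indices):num_replicas]  # [0, 1, 2]
assert len(rank0) == len(rank1) == num_samples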
The collate function organizes and merges the data of each batch and is passed to the DataLoader's collate_fn argument. A short example of its behavior follows the listing.
from collections.abc import Mapping, Sequence
import torch
import torch.nn.functional as F
from torch.utils.data.dataloader import default_collate
from .data_container import DataContainer
def collate(batch, samples_per_gpu=1):
"""Puts each data field into a tensor/DataContainer with outer dimension
batch size.
Extend default_collate to add support for
:type:`~mmcv.parallel.DataContainer`. There are 3 cases.
1. cpu_only = True, e.g., meta data
2. cpu_only = False, stack = True, e.g., images tensors
3. cpu_only = False, stack = False, e.g., gt bboxes
"""
# batch is a list of length batch_size; each element is a dict describing one image.
# The dict keys are: dict_keys(['img_metas', 'img', 'gt_bboxes', 'gt_labels'])
# Make sure batch is a sequence.
if not isinstance(batch, Sequence):
raise TypeError(f'{batch.dtype} is not supported.')
if isinstance(batch[0], DataContainer):
assert len(batch) % samples_per_gpu == 0
stacked = []
# cpu_only == True means this field is meta data.
if batch[0].cpu_only:
# batch[0].stack: False
# batch[0].padding_value: 0
for i in range(0, len(batch), samples_per_gpu):
# Group every samples_per_gpu samples into one list.
stacked.append(
[sample.data for sample in batch[i:i + samples_per_gpu]])
# Wrap the result in a DataContainer.
return DataContainer(
stacked, batch[0].stack, batch[0].padding_value, cpu_only=True)
# stack == True means image-like tensor data (e.g. images or stacked label maps).
elif batch[0].stack:
for i in range(0, len(batch), samples_per_gpu):
assert isinstance(batch[i].data, torch.Tensor)
# The trailing dimensions need padding.
if batch[i].pad_dims is not None:
ndim = batch[i].dim()
assert ndim > batch[i].pad_dims
max_shape = [0 for _ in range(batch[i].pad_dims)]
for dim in range(1, batch[i].pad_dims + 1):
max_shape[dim - 1] = batch[i].size(-dim)
for sample in batch[i:i + samples_per_gpu]:
for dim in range(0, ndim - batch[i].pad_dims):
assert batch[i].size(dim) == sample.size(dim)
for dim in range(1, batch[i].pad_dims + 1):
max_shape[dim - 1] = max(max_shape[dim - 1],
sample.size(-dim))
padded_samples = []
for sample in batch[i:i + samples_per_gpu]:
pad = [0 for _ in range(batch[i].pad_dims * 2)]
for dim in range(1, batch[i].pad_dims + 1):
pad[2 * dim -
1] = max_shape[dim - 1] - sample.size(-dim)
padded_samples.append(
F.pad(
sample.data, pad, value=sample.padding_value))
stacked.append(default_collate(padded_samples))
# No dimensions to pad.
elif batch[i].pad_dims is None:
stacked.append(
default_collate([
sample.data
for sample in batch[i:i + samples_per_gpu]
]))
else:
raise ValueError(
'pad_dims should be either None or integers (1-3)')
# stack == False, e.g. gt bboxes.
else:
# Group every samples_per_gpu samples into one list and return.
for i in range(0, len(batch), samples_per_gpu):
stacked.append(
[sample.data for sample in batch[i:i + samples_per_gpu]])
return DataContainer(stacked, batch[0].stack, batch[0].padding_value)
# The batch elements are sequences.
elif isinstance(batch[0], Sequence):
transposed = zip(*batch)
return [collate(samples, samples_per_gpu) for samples in transposed]
# The top-level input is a list of dicts (image meta, image, gt boxes, labels, ...),
# so the Mapping branch below is entered first.
elif isinstance(batch[0], Mapping):
# Return a dict: each key's value is the result of collating that key's values from all samples.
return {
key: collate([d[key] for d in batch], samples_per_gpu)
# iterate over every key
for key in batch[0]
}
# Fall back to the default collate.
else:
return default_collate(batch)
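A small sketch of what this does for a stacked field, assuming mmcv is installed; the image shapes below are hypothetical:
import torch
from mmcv.parallel import DataContainer, collate

# Two images of different sizes, wrapped the way MMDetection pipelines wrap them
# (stack=True with pad_dims=2 means the last two dims are padded to a common size).
batch = [
    dict(img=DataContainer(torch.zeros(3, 224, 200), stack=True, pad_dims=2)),
    dict(img=DataContainer(torch.zeros(3, 200, 224), stack=True, pad_dims=2)),
]
out = collate(batch, samples_per_gpu=2)
# out['img'] is a DataContainer whose data is [tensor of shape (2, 3, 224, 224)]:
# both images are padded to the per-chunk max shape and then stacked.
print(out['img'].data[0].shape)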
By default, the random seed of each dataloader worker is not fixed from run to run. For reproducibility, a worker_init_fn function is created and passed to the DataLoader's worker_init_fn argument (the worker initialization function). Using num_workers * rank + worker_id + user_seed as the seed makes every worker's seed deterministic and distinct; a small numeric example follows the function.
def worker_init_fn(worker_id, num_workers, rank, seed):
# Using num_workers * rank + worker_id + user_seed as the seed
# gives every worker a deterministic, distinct seed.
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
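For example, with a user seed of 42, two workers per process and two processes (hypothetical numbers), every worker ends up with its own reproducible seed:
# rank 0: 2*0 + 0 + 42 = 42,  2*0 + 1 + 42 = 43
# rank 1: 2*1 + 0 + 42 = 44,  2*1 + 1 + 42 = 45
for rank in range(2):
    for worker_id in range(2):
        print(rank, worker_id, 2 * rank + worker_id + 42)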
(2) Building the distributed parallel wrappers
MMDetection wraps PyTorch's DistributedDataParallel and DataParallel in a thin outer layer that overrides the scatter method and adds two extra methods, train_step and val_step. scatter distributes the data to the specified GPUs; for a given batch, train_step and val_step call the detector's train_step or val_step to compute the loss or obtain the model output. (Every detector in MMDetection implements train_step and val_step, so no loss function needs to be passed in during training: different models can use different losses, and the same model can switch between loss functions, which makes the design more flexible.)
With multiple GPUs on one machine, the model is wrapped with MMDistributedDataParallel; with a single GPU, it is wrapped with MMDataParallel.
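Roughly how the training entry point wraps the model (a simplified sketch modeled on mmdet/apis/train.py; the `distributed` flag and `model` are assumed to come from the surrounding script):
import torch

if distributed:
    # one process per GPU; each process only sees its own device
    model = MMDistributedDataParallel(
        model.cuda(),
        device_ids=[torch.cuda.current_device()],
        broadcast_buffers=False)
else:
    model = MMDataParallel(model.cuda(), device_ids=[0])

# Later, inside EpochBasedRunner.train():
#   outputs = model.train_step(data_batch, optimizer)
# `outputs` is a dict such as {'loss': ..., 'log_vars': ..., 'num_samples': ...}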
① MMDistributedDataParallel
# Copyright (c) Open-MMLab. All rights reserved.
import torch
from torch.nn.parallel.distributed import (DistributedDataParallel,
_find_tensors)
from mmcv.utils import TORCH_VERSION
from .scatter_gather import scatter_kwargs
class MMDistributedDataParallel(DistributedDataParallel):
"""The DDP module that supports DataContainer.
MMDDP has two main differences with PyTorch DDP:
- It supports a custom type :class:`DataContainer` which allows more
flexible control of input data.
- It implement two APIs ``train_step()`` and ``val_step()``.
"""
def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def train_step(self, *inputs, **kwargs):
"""train_step() API for module wrapped by DistributedDataParallel.
This method is basically the same as
``DistributedDataParallel.forward()``, while replacing
``self.module.forward()`` with ``self.module.train_step()``.
It is compatible with PyTorch 1.1 - 1.5.
"""
if getattr(self, 'require_forward_param_sync', True):
self._sync_params()
if self.device_ids:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
output = self.module.train_step(*inputs[0], **kwargs[0])
else:
outputs = self.parallel_apply(
self._module_copies[:len(inputs)], inputs, kwargs)
output = self.gather(outputs, self.output_device)
else:
output = self.module.train_step(*inputs, **kwargs)
if torch.is_grad_enabled() and getattr(
self, 'require_backward_grad_sync', True):
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
else:
if TORCH_VERSION > '1.2':
self.require_forward_param_sync = False
return output
def val_step(self, *inputs, **kwargs):
"""val_step() API for module wrapped by DistributedDataParallel.
This method is basically the same as
``DistributedDataParallel.forward()``, while replacing
``self.module.forward()`` with ``self.module.val_step()``.
It is compatible with PyTorch 1.1 - 1.5.
"""
if getattr(self, 'require_forward_param_sync', True):
self._sync_params()
if self.device_ids:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
output = self.module.val_step(*inputs[0], **kwargs[0])
else:
outputs = self.parallel_apply(
self._module_copies[:len(inputs)], inputs, kwargs)
output = self.gather(outputs, self.output_device)
else:
output = self.module.val_step(*inputs, **kwargs)
if torch.is_grad_enabled() and getattr(
self, 'require_backward_grad_sync', True):
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
else:
if TORCH_VERSION > '1.2':
self.require_forward_param_sync = False
return output
② MMDataParallel
# Copyright (c) Open-MMLab. All rights reserved.
from itertools import chain
from torch.nn.parallel import DataParallel
from .scatter_gather import scatter_kwargs
class MMDataParallel(DataParallel):
def scatter(self, inputs, kwargs, device_ids):
"""将数据分散到指定的 GPU设备"""
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def train_step(self, *inputs, **kwargs):
if not self.device_ids:
return self.module.train_step(*inputs, **kwargs)
assert len(self.device_ids) == 1, \
('MMDataParallel only supports single GPU training, if you need to'
' train with multiple GPUs, please use MMDistributedDataParallel'
' instead.')
for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj:
raise RuntimeError(
'module must have its parameters and buffers '
f'on device {self.src_device_obj} (device_ids[0]) but '
f'found one of them on device: {t.device}')
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
return self.module.train_step(*inputs[0], **kwargs[0])
def val_step(self, *inputs, **kwargs):
if not self.device_ids:
return self.module.val_step(*inputs, **kwargs)
assert len(self.device_ids) == 1, \
('MMDataParallel only supports single GPU training, if you need to'
' train with multiple GPUs, please use MMDistributedDataParallel'
' instead.')
for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj:
raise RuntimeError(
'module must have its parameters and buffers '
f'on device {self.src_device_obj} (device_ids[0]) but '
f'found one of them on device: {t.device}')
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
return self.module.val_step(*inputs[0], **kwargs[0])
(3) Building the optimizer
The optimizer is built with the build_optimizer function; as the code below shows, it too ultimately comes down to a call to build_from_cfg. A usage example follows the listing.
import copy
import inspect
import torch
from ...utils import Registry, build_from_cfg
OPTIMIZERS = Registry('optimizer')
OPTIMIZER_BUILDERS = Registry('optimizer builder')
def register_torch_optimizers():
torch_optimizers = []
for module_name in dir(torch.optim):
if module_name.startswith('__'):
continue
_optim = getattr(torch.optim, module_name)
if inspect.isclass(_optim) and issubclass(_optim,
torch.optim.Optimizer):
OPTIMIZERS.register_module()(_optim)
torch_optimizers.append(module_name)
return torch_optimizers
TORCH_OPTIMIZERS = register_torch_optimizers()
def build_optimizer_constructor(cfg):
return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
def build_optimizer(model, cfg):
optimizer_cfg = copy.deepcopy(cfg)
constructor_type = optimizer_cfg.pop('constructor',
'DefaultOptimizerConstructor')
paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
optim_constructor = build_optimizer_constructor(
dict(
type=constructor_type,
optimizer_cfg=optimizer_cfg,
paramwise_cfg=paramwise_cfg))
optimizer = optim_constructor(model)
return optimizer
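A typical optimizer config and the corresponding call (the config values here are hypothetical, but follow the usual MMDetection style):
import torch.nn as nn

optimizer_cfg = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)

model = nn.Conv2d(3, 8, 3)  # stand-in for a detector
optimizer = build_optimizer(model, optimizer_cfg)
# 'constructor' defaults to DefaultOptimizerConstructor, which pops `type`
# from the cfg, looks up torch.optim.SGD in the OPTIMIZERS registry and
# instantiates it with the model's parameters and the remaining kwargs.
print(type(optimizer))  # <class 'torch.optim.sgd.SGD'>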
(4) Creating the EpochBasedRunner and training
EpochBasedRunner inherits from BaseRunner. BaseRunner provides the common attributes and methods, such as the training state (epoch count, iteration count, etc.) and the registration and inspection of hooks. It also declares four abstract methods that subclasses must implement: train, val, run, and save_checkpoint.
EpochBasedRunner subclasses BaseRunner and implements the train, val, run, and save_checkpoint methods.
Training is started by calling run with the dataloaders, the workflow, the maximum number of epochs, and so on. At every stage (before run, before each epoch, etc.) all hooks registered for that stage are invoked, which makes the runner highly customizable. A minimal usage sketch is given after the source listing below.
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import time
import warnings
import torch
import mmcv
from .base_runner import BaseRunner
from .checkpoint import save_checkpoint
from .utils import get_host_info
class EpochBasedRunner(BaseRunner):
"""Epoch-based Runner.
This runner train models epoch by epoch.
"""
def train(self, data_loader, **kwargs):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(data_loader)
self.call_hook('before_train_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(data_loader):
self._inner_iter = i
self.call_hook('before_train_iter')
if self.batch_processor is None:
outputs = self.model.train_step(data_batch, self.optimizer,
**kwargs)
else:
outputs = self.batch_processor(
self.model, data_batch, train_mode=True, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('"batch_processor()" or "model.train_step()"'
' must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'],
outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_train_iter')
self._iter += 1
self.call_hook('after_train_epoch')
self._epoch += 1
def val(self, data_loader, **kwargs):
self.model.eval()
self.mode = 'val'
self.data_loader = data_loader
self.call_hook('before_val_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(data_loader):
self._inner_iter = i
self.call_hook('before_val_iter')
with torch.no_grad():
if self.batch_processor is None:
outputs = self.model.val_step(data_batch, self.optimizer,
**kwargs)
else:
outputs = self.batch_processor(
self.model, data_batch, train_mode=False, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('"batch_processor()" or "model.val_step()"'
' must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'],
outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_val_iter')
self.call_hook('after_val_epoch')
def run(self, data_loaders, workflow, max_epochs, **kwargs):
"""Start running.
Args:
data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
and validation.
workflow (list[tuple]): A list of (phase, epochs) to specify the
running order and epochs. E.g, [('train', 2), ('val', 1)] means
running 2 epochs for training and 1 epoch for validation,
iteratively.
max_epochs (int): Total training epochs.
"""
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
self._max_epochs = max_epochs
for i, flow in enumerate(workflow):
mode, epochs = flow
if mode == 'train':
self._max_iters = self._max_epochs * len(data_loaders[i])
break
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
self.call_hook('before_run')
while self.epoch < max_epochs:
for i, flow in enumerate(workflow):
mode, epochs = flow
if isinstance(mode, str): # self.train()
if not hasattr(self, mode):
raise ValueError(
f'runner has no method named "{mode}" to run an '
'epoch')
epoch_runner = getattr(self, mode)
else:
raise TypeError(
'mode in workflow must be a str, but got {}'.format(
type(mode)))
for _ in range(epochs):
if mode == 'train' and self.epoch >= max_epochs:
return
epoch_runner(data_loaders[i], **kwargs)
time.sleep(1) # wait for some hooks like loggers to finish
self.call_hook('after_run')
def save_checkpoint(self,
out_dir,
filename_tmpl='epoch_{}.pth',
save_optimizer=True,
meta=None,
create_symlink=True):
"""Save the checkpoint.
Args:
out_dir (str): The directory that checkpoints are saved.
filename_tmpl (str, optional): The checkpoint filename template,
which contains a placeholder for the epoch number.
Defaults to 'epoch_{}.pth'.
save_optimizer (bool, optional): Whether to save the optimizer to
the checkpoint. Defaults to True.
meta (dict, optional): The meta information to be saved in the
checkpoint. Defaults to None.
create_symlink (bool, optional): Whether to create a symlink
"latest.pth" to point to the latest checkpoint.
Defaults to True.
"""
if meta is None:
meta = dict(epoch=self.epoch + 1, iter=self.iter)
else:
meta.update(epoch=self.epoch + 1, iter=self.iter)
filename = filename_tmpl.format(self.epoch + 1)
filepath = osp.join(out_dir, filename)
optimizer = self.optimizer if save_optimizer else None
save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
# in some environments, `os.symlink` is not supported, you may need to
# set `create_symlink` to False
if create_symlink:
mmcv.symlink(filename, osp.join(out_dir, 'latest.pth'))
class Runner(EpochBasedRunner):
"""Deprecated name of EpochBasedRunner."""
def __init__(self, *args, **kwargs):
warnings.warn(
'Runner was deprecated, please use EpochBasedRunner instead')
super().__init__(*args, **kwargs)
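Finally, a minimal sketch of how the runner is driven, modeled on mmdet/apis/train.py; `model`, `optimizer`, `train_loader`, and `logger` are assumed to have been built as described above:
runner = EpochBasedRunner(
    model,                       # MMDataParallel / MMDistributedDataParallel
    optimizer=optimizer,
    work_dir='./work_dirs/demo',
    logger=logger,
    meta=dict())
# Register the usual hooks: LR schedule, optimizer step, checkpointing, logging.
runner.register_training_hooks(
    lr_config=dict(policy='step', step=[8, 11]),
    optimizer_config=dict(grad_clip=None),
    checkpoint_config=dict(interval=1),
    log_config=dict(interval=50, hooks=[dict(type='TextLoggerHook')]))
# workflow [('train', 1)] means: run one training epoch per outer iteration.
runner.run([train_loader], workflow=[('train', 1)], max_epochs=12)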