以训练过程为例,执行以下脚本。
python tools/train.py configs/cifar10/resnet50.py --resume-from=work_dirs/resnet50/epoch_20.pth
执行代码:
# file: apis/train.py
train_model(
model, # 实例化模型类
datasets, # 实例化数据类
cfg, # 全部配置参数
distributed=distributed, # false
validate=(not args.no_validate), # true
timestamp=timestamp, # 时间戳
meta=meta) # 系统环境参数
主要执行几个步骤:
# 1、加载数据集,构建data_loaders
data_loaders = [build_dataloader(*args, **kw) for ds in dataset]
# 2、构建优化器
optimizer = build_optimizer(model, cfg.optimizer)
build_dataloader(dataset, # 实例化数据类
samples_per_gpu, # 128
workers_per_gpu, # 2
num_gpus=1, # 1
dist=True, # false
shuffle=True, # true
round_up=True, # true
seed=None, # None
**kwargs): # {}
关键代码:
batch_size = num_gpus * samples_per_gpu # 1*128=128
num_workers = num_gpus * workers_per_gpu # 1*2=2
最后调用torch.utils.data.DataLoader:
data_loader = DataLoader(
dataset, # 数据集类
batch_size=batch_size, # 128
sampler=sampler, # None
num_workers=num_workers,
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
pin_memory=False,
shuffle=shuffle,
worker_init_fn=init_fn,
**kwargs)
optimizer = build_optimizer(model, cfg.optimizer)
model为模型结构类,cfg.optimizer配置如下:
# 来自配置文件: configs/_base_/schedules/cifar10.py
# optimizer
# SGD with momentum and weight decay; lr starts at 0.1.
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
# grad_clip=None — presumably gradient clipping is disabled; confirm in OptimizerHook.
optimizer_config = dict(grad_clip=None)
# learning policy: step decay, lr drops at epochs 100 and 150
lr_config = dict(policy='step', step=[100, 150])
total_epochs = 200  # total number of training epochs
调用mmcv中的build_optimizer函数加载optimizer。
# file: mmcv/runner/optimizer/builder.py
def build_optimizer(model, cfg):
    """Build an optimizer for ``model`` from an mmcv-style config dict.

    The config is deep-copied so the caller's dict is never mutated. Two
    optional keys are split off before construction:

    * ``constructor``   -- name of the constructor class to use
                           (defaults to ``'DefaultOptimizerConstructor'``)
    * ``paramwise_cfg`` -- per-parameter options, forwarded untouched

    The remaining keys (``type``, ``lr``, ...) describe the optimizer itself.
    """
    local_cfg = copy.deepcopy(cfg)  # never mutate the caller's config
    ctor_name = local_cfg.pop('constructor', 'DefaultOptimizerConstructor')
    per_param = local_cfg.pop('paramwise_cfg', None)
    ctor_cfg = dict(
        type=ctor_name,
        optimizer_cfg=local_cfg,
        paramwise_cfg=per_param)
    # The constructor object, once called with the model, instantiates the
    # actual torch optimizer (e.g. SGD) over the model's parameters.
    optim_constructor = build_optimizer_constructor(ctor_cfg)
    return optim_constructor(model)
1、构造runner类,runner类中含运行所需所有参数。
runner = EpochBasedRunner(
model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta) # 系统信息
# an ugly workaround to make the .log and .log.json filenames the same
runner.timestamp = timestamp
2、注册钩子过程。
runner.register_training_hooks(
cfg.lr_config, # {'policy': 'step', 'step': [100, 150]}
optimizer_config, # {'grad_clip': None}
cfg.checkpoint_config, # {'interval': 1, 'meta': {'mmcls_version': '0.1.0+dae1c86', 'config': "model = dict(\n type='ImageClassifier',\n backbone=dict(\n type='ResNet_CIFAR',\n depth=50,\n num_stages=4,\n out_indices=(3, ),\n style='pytorch'),\n neck=dict(type='GlobalAveragePooling'),\n head=dict(\n type='LinearClsHead',\n num_classes=10,\n in_channels=2048,\n loss=dict(type='CrossEntropyLoss', loss_weight=1.0)))\ndataset_type = 'CIFAR10'\nimg_norm_cfg = dict(\n mean=[125.307, 122.961, 113.8575],\n std=[51.5865, 50.847, 51.255],\n to_rgb=True)\ntrain_pipeline = [\n dict(type='RandomCrop', size=32, padding=4),\n dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),\n dict(\n type='Normalize',\n mean=[125.307, 122.961, 113.8575],\n std=[51.5865, 50.847, 51.255],\n to_rgb=True),\n dict(type='ImageToTensor', keys=['img']),\n dict(type='ToTensor', keys=['gt_label']),\n dict(type='Collect', keys=['img', 'gt_label'])\n]\ntest_pipeline = [\n dict(\n type='Normalize',\n mean=[125.307, 122.961, 113.8575],\n std=[51.5865, 50.847, 51.255],\n to_rgb=True),\n dict(type='ImageToTensor', keys=['img']),\n dict(type='ToTensor', keys=['gt_label']),\n dict(type='Collect', keys=['img', 'gt_label'])\n]\ndata = dict(\n samples_per_gpu=128,\n workers_per_gpu=2,\n train=dict(\n type='CIFAR10',\n data_prefix='../data/cifar10',\n pipeline=[\n dict(type='RandomCrop', size=32, padding=4),\n dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),\n dict(\n type='Normalize',\n mean=[125.307, 122.961, 113.8575],\n std=[51.5865, 50.847, 51.255],\n to_rgb=True),\n dict(type='ImageToTensor', keys=['img']),\n dict(type='ToTensor', keys=['gt_label']),\n dict(type='Collect', keys=['img', 'gt_label'])\n ]),\n val=dict(\n type='CIFAR10',\n data_prefix='../data/cifar10',\n pipeline=[\n dict(\n type='Normalize',\n mean=[125.307, 122.961, 113.8575],\n std=[51.5865, 50.847, 51.255],\n to_rgb=True),\n dict(type='ImageToTensor', keys=['img']),\n dict(type='ToTensor', keys=['gt_label']),\n 
dict(type='Collect', keys=['img', 'gt_label'])\n ]),\n test=dict(\n type='CIFAR10',\n data_prefix='../data/cifar10',\n pipeline=[\n dict(\n type='Normalize',\n mean=[125.307, 122.961, 113.8575],\n std=[51.5865, 50.847, 51.255],\n to_rgb=True),\n dict(type='ImageToTensor', keys=['img']),\n dict(type='ToTensor', keys=['gt_label']),\n dict(type='Collect', keys=['img', 'gt_label'])\n ]))\noptimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)\noptimizer_config = dict(grad_clip=None)\nlr_config = dict(policy='step', step=[100, 150])\ntotal_epochs = 200\ncheckpoint_config = dict(interval=1)\nlog_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])\ndist_params = dict(backend='nccl')\nlog_level = 'INFO'\nload_from = None\nresume_from = '../work_dirs/resnet50/epoch_27.pth'\nworkflow = [('train', 1)]\nwork_dir = './work_dirs\\resnet50'\ngpu_ids = range(0, 1)\nseed = None\n", 'CLASSES': ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']}}
cfg.log_config, # {'interval': 100,
'hooks': [{'type': 'TextLoggerHook'}]
}
cfg.get('momentum_config', None) # None
)
按照以下顺序注册:
self.register_lr_hook(lr_config)
self.register_momentum_hook(momentum_config)
self.register_optimizer_hook(optimizer_config)
self.register_checkpoint_hook(checkpoint_config)
self.register_hook(IterTimerHook())
self.register_logger_hooks(log_config)
3、运行训练过程
在循环过程中,依据钩子过程执行训练。
def run(self, data_loaders, workflow, max_epochs, **kwargs):
    """Run the training workflow for ``max_epochs`` epochs.

    ``workflow`` is a list of ``(mode, epochs)`` pairs, e.g.
    ``[('train', 1)]``: each pair runs the runner method named ``mode``
    for ``epochs`` epochs, and the whole list is cycled until
    ``self.epoch`` reaches ``max_epochs``.

    Raises:
        TypeError: if a workflow mode is not a string.
        ValueError: if the runner has no method named after a mode.
    """
    self._max_epochs = max_epochs
    # Total train iterations = epochs * batches per epoch, derived from
    # the first 'train' entry found in the workflow.
    for idx, (mode, _) in enumerate(workflow):
        if mode == 'train':
            self._max_iters = self._max_epochs * len(data_loaders[idx])
            break
    self.logger.info('Start running, host: %s, work_dir: %s',
                     get_host_info(),
                     self.work_dir if self.work_dir is not None else 'NONE')
    self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
    self.call_hook('before_run')
    while self.epoch < max_epochs:
        for idx, (mode, epochs) in enumerate(workflow):
            # Guard clauses: mode must be a string naming a runner method
            # (e.g. 'train' -> self.train, 'val' -> self.val).
            if not isinstance(mode, str):
                raise TypeError(
                    'mode in workflow must be a str, but got {}'.format(
                        type(mode)))
            if not hasattr(self, mode):
                raise ValueError(
                    f'runner has no method named "{mode}" to run an '
                    'epoch')
            epoch_runner = getattr(self, mode)
            for _ in range(epochs):
                # Stop training as soon as the epoch budget is spent.
                if mode == 'train' and self.epoch >= max_epochs:
                    break
                epoch_runner(data_loaders[idx], **kwargs)
    time.sleep(1)  # wait for some hooks like loggers to finish
    self.call_hook('after_run')
传送门:mmclassification项目阅读系列文章目录
源码阅读:
1、setup.py工程环境配置(一)
2、mmcls库组织结构说明(二)
3、registry类注册机制(三)
4、模型加载过程(四)
5、数据加载过程(五)
6、train_model执行过程(六)