mmclassification Source Code Reading (6): The train_model Execution Process

Taking the training process as an example, run the following script:

python tools/train.py configs/cifar10/resnet50.py --resume-from=work_dirs/resnet50/epoch_20.pth
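As a side note, tools/train.py forwards --resume-from into the config before anything else runs. A minimal sketch of that argument handling (simplified; the real script accepts many more options):

# Minimal sketch of the argument handling in tools/train.py (simplified).
import argparse
from mmcv import Config

parser = argparse.ArgumentParser(description='Train a classifier')
parser.add_argument('config', help='train config file path')
parser.add_argument('--resume-from', help='checkpoint to resume from')
args = parser.parse_args()

cfg = Config.fromfile(args.config)
if args.resume_from is not None:
    cfg.resume_from = args.resume_from  # consumed later by runner.resume()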

1. Overall flow

The entry point is the following call:

# file: apis/train.py
train_model(
    model,  # instantiated model
    datasets,  # instantiated dataset(s)
    cfg,  # the full config
    distributed=distributed,  # False in this run
    validate=(not args.no_validate),  # True
    timestamp=timestamp,  # timestamp string
    meta=meta)  # environment info
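The model and datasets arguments are built in tools/train.py before this call; a rough sketch (simplified, relying on the mmcls builders covered in earlier parts of this series):

from mmcv import Config
from mmcls.models import build_classifier
from mmcls.datasets import build_dataset

cfg = Config.fromfile('configs/cifar10/resnet50.py')
model = build_classifier(cfg.model)          # the ImageClassifier instance
datasets = [build_dataset(cfg.data.train)]   # the CIFAR10 dataset instance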

Inside train_model, the main steps are:

# 1. Build a data loader for each dataset
data_loaders = [
    build_dataloader(ds, cfg.data.samples_per_gpu, cfg.data.workers_per_gpu,
                     num_gpus=len(cfg.gpu_ids), dist=distributed,
                     round_up=True, seed=cfg.seed) for ds in dataset
]

# 2. Build the optimizer
optimizer = build_optimizer(model, cfg.optimizer)
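Not shown above: between these two steps train_model also moves the model onto the GPU. For the single-GPU, non-distributed case traced here that is roughly (the distributed branch uses MMDistributedDataParallel instead):

from mmcv.parallel import MMDataParallel

model = MMDataParallel(model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)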

2. The build_dataloader construction process

build_dataloader(dataset,  # instantiated dataset
                 samples_per_gpu,  # 128
                 workers_per_gpu,  # 2
                 num_gpus=1,  # 1
                 dist=True,  # False in this run
                 shuffle=True,  # True
                 round_up=True,  # True
                 seed=None,  # None
                 **kwargs):  # {}

Key code (the non-distributed branch, since dist is False here):

batch_size = num_gpus * samples_per_gpu  # 1*128=128
num_workers = num_gpus * workers_per_gpu  # 1*2=2
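For comparison, the distributed branch (not taken in this run) picks a sampler and per-process sizes instead; a simplified sketch of that part of build_dataloader:

rank, world_size = get_dist_info()  # from mmcv.runner
if dist:
    # mmcls ships its own DistributedSampler that supports round_up
    sampler = DistributedSampler(
        dataset, world_size, rank, shuffle=shuffle, round_up=round_up)
    shuffle = False                  # the sampler already shuffles
    batch_size = samples_per_gpu     # per-process batch size
    num_workers = workers_per_gpu
else:
    sampler = None
    batch_size = num_gpus * samples_per_gpu
    num_workers = num_gpus * workers_per_gpu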

Finally, a torch.utils.data.DataLoader is constructed:

data_loader = DataLoader(
    dataset, # the dataset instance
    batch_size=batch_size, # 128
    sampler=sampler, # None
    num_workers=num_workers,
    collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
    pin_memory=False,
    shuffle=shuffle,
    worker_init_fn=init_fn,
    **kwargs)
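Two arguments are worth noting: collate_fn uses mmcv.parallel.collate so that DataContainer fields are batched correctly, and worker_init_fn gives every dataloader worker a reproducible seed when a seed is set (seed is None in this run, so init_fn is None). A simplified sketch of how init_fn is prepared in build_dataloader:

import random
from functools import partial

import numpy as np

def worker_init_fn(worker_id, num_workers, rank, seed):
    # give every dataloader worker a distinct, reproducible seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# inside build_dataloader, just before the DataLoader call:
init_fn = partial(
    worker_init_fn, num_workers=num_workers, rank=rank,
    seed=seed) if seed is not None else None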

3. The build_optimizer construction process

optimizer = build_optimizer(model, cfg.optimizer)

Here model is the instantiated model, and cfg.optimizer is configured as follows:

# from the config file: configs/_base_/schedules/cifar10.py
# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[100, 150])
total_epochs = 200
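For intuition, the optimizer and lr_config above amount to the following plain-PyTorch setup (a hedged equivalent, not the code path mmcls actually takes: the real run goes through build_optimizer and mmcv's step LR hook, whose default decay factor is 0.1):

import torch

model = torch.nn.Linear(2048, 10)  # stand-in for the real classifier
optimizer = torch.optim.SGD(
    model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001)
# lr_config: decay the lr by 10x after epoch 100 and again after epoch 150
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[100, 150], gamma=0.1)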

The build_optimizer function in mmcv is called to construct the optimizer:

# file: mmcv/runner/optimizer/builder.py
def build_optimizer(model, cfg): 
    optimizer_cfg = copy.deepcopy(cfg)  # a copy of cfg.optimizer above
    # constructor_type = 'DefaultOptimizerConstructor'
    constructor_type = optimizer_cfg.pop('constructor', 'DefaultOptimizerConstructor')    
    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None) # None
    optim_constructor = build_optimizer_constructor(
        dict(
            type=constructor_type,
            optimizer_cfg=optimizer_cfg,
            paramwise_cfg=paramwise_cfg))
    optimizer = optim_constructor(model)  # constructs the SGD optimizer
    return optimizer
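With paramwise_cfg left as None, DefaultOptimizerConstructor essentially looks up the optimizer class (registered from torch.optim) and calls it with model.parameters(). A minimal sketch of what the call reduces to (simplified; the real constructor goes through the OPTIMIZERS registry rather than getattr):

import copy

import torch

def simple_build_optimizer(model, optimizer_cfg):
    cfg = copy.deepcopy(optimizer_cfg)
    optim_cls = getattr(torch.optim, cfg.pop('type'))  # e.g. torch.optim.SGD
    return optim_cls(model.parameters(), **cfg)

optimizer = simple_build_optimizer(
    torch.nn.Linear(2048, 10),
    dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001))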

4. The runner construction process

1. Construct the runner; it holds everything needed to drive the run.

runner = EpochBasedRunner(
    model,
    optimizer=optimizer,
    work_dir=cfg.work_dir,
    logger=logger,
    meta=meta)  # system info
# an ugly walkaround to make the .log and .log.json filenames the same
runner.timestamp = timestamp

2. Register the hooks.

runner.register_training_hooks(
    cfg.lr_config,  # {'policy': 'step', 'step': [100, 150]}
    optimizer_config,  # {'grad_clip': None}
    cfg.checkpoint_config,  # {'interval': 1, 'meta': {'mmcls_version': ..., 'config': <the full merged config text>, 'CLASSES': [the 10 CIFAR-10 class names]}}
    cfg.log_config,  # {'interval': 100, 'hooks': [{'type': 'TextLoggerHook'}]}
    cfg.get('momentum_config', None))  # None

Inside register_training_hooks, the hooks are registered in the following order:

self.register_lr_hook(lr_config)
self.register_momentum_hook(momentum_config)
self.register_optimizer_hook(optimizer_config)
self.register_checkpoint_hook(checkpoint_config)
self.register_hook(IterTimerHook())
self.register_logger_hooks(log_config)
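After these training hooks, train_model also registers a validation EvalHook (validate=True in this run), resumes from the --resume-from checkpoint, and finally calls runner.run. A rough sketch of that tail of train_model (simplified; EvalHook/DistEvalHook come from mmcls.core and details vary between versions):

if validate:
    val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
    val_dataloader = build_dataloader(
        val_dataset,
        samples_per_gpu=cfg.data.samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False,
        round_up=True)
    eval_hook = DistEvalHook if distributed else EvalHook
    runner.register_hook(eval_hook(val_dataloader, **cfg.get('evaluation', {})))

if cfg.resume_from:                 # set by --resume-from
    runner.resume(cfg.resume_from)
elif cfg.load_from:
    runner.load_checkpoint(cfg.load_from)

runner.run(data_loaders, cfg.workflow, cfg.total_epochs)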

3. Run the training loop

runner.run loops over the workflow epoch by epoch, and the registered hooks are called at the corresponding stages as training proceeds.

def run(self, data_loaders, workflow, max_epochs, **kwargs):
    self._max_epochs = max_epochs
    for i, flow in enumerate(workflow):
        mode, epochs = flow
        if mode == 'train':
            self._max_iters = self._max_epochs * len(data_loaders[i])
            break

    work_dir = self.work_dir if self.work_dir is not None else 'NONE'
    self.logger.info('Start running, host: %s, work_dir: %s',
                     get_host_info(), work_dir)
    self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
    self.call_hook('before_run')

    while self.epoch < max_epochs:
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if isinstance(mode, str):  # self.train()
                if not hasattr(self, mode):
                    raise ValueError(
                        f'runner has no method named "{mode}" to run an '
                        'epoch')
                epoch_runner = getattr(self, mode)
            else:
                raise TypeError(
                    'mode in workflow must be a str, but got {}'.format(
                        type(mode)))

            for _ in range(epochs):
                if mode == 'train' and self.epoch >= max_epochs:
                    break
                epoch_runner(data_loaders[i], **kwargs)

    time.sleep(1)  # wait for some hooks like loggers to finish
    self.call_hook('after_run')
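When mode is 'train', epoch_runner resolves to EpochBasedRunner.train, which iterates the data loader, calls the model's train_step via run_iter, and fires the per-iteration hooks; call_hook simply forwards the stage name to every registered hook. Simplified sketches of both (trimmed from mmcv.runner; details vary between mmcv versions):

def train(self, data_loader, **kwargs):
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self.call_hook('before_train_epoch')
    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_train_iter')
        # run_iter ends up calling model.train_step(data_batch, self.optimizer)
        self.run_iter(data_batch, train_mode=True, **kwargs)
        self.call_hook('after_train_iter')
        self._iter += 1
    self.call_hook('after_train_epoch')
    self._epoch += 1

def call_hook(self, fn_name):
    for hook in self._hooks:
        getattr(hook, fn_name)(self)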

Portal: index of the mmclassification source-reading series

Source code reading:

1. setup.py project environment configuration (1)

2. Organization of the mmcls package (2)

3. The Registry registration mechanism (3)

4. The model loading process (4)

5. The data loading process (5)

6. The train_model execution process (6)

 
