mmclassification源码阅读(五) 数据加载过程

以训练过程为例,执行以下脚本。

python tools/train.py configs/cifar10/resnet50.py --resume-from=work_dirs/resnet50/epoch_20.pth

1、整体流程

执行代码:

datasets = [build_dataset(cfg.data.train)]

其中cfg.data.train为:

# file: configs/_base_/datasets/cifar10.py
train_pipeline = [
    dict(type='RandomCrop', size=32, padding=4),
    ...
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]

data = dict(
    ...
    train=dict(
        type=dataset_type,  # 'CIFAR10'
        data_prefix='data/cifar10',
        pipeline=train_pipeline),

build_dataset(cfg.data.train)函数传参DATASETS注册器,最后调用mmcv.utils.registry.py中的build_from_cfg(cfg, registry, default_args)函数。最终加载mmcls/datasets/cifar.py中CIFAR10类,并实例化。

2、CIFAR10类实例化

CIFAR10类实现标记、数据管理,图像数据增广pipeline结构。继承自BaseDataset,CIFAR10类本身没有提供__init__函数,CIFAR10类初始化时,会默认调用基类的__init__构造函数。基类函数接收关键字参数data_prefix、pipeline,ann_file、test_mode采用默认值。

class CIFAR10(BaseDataset):
    pass

class BaseDataset(Dataset, metaclass=ABCMeta):
     CLASSES = None

    def __init__(self, data_prefix, pipeline, ann_file=None, test_mode=False):
        super(BaseDataset, self).__init__()

        self.ann_file = ann_file
        self.data_prefix = data_prefix  # 'data/cifar10'
        self.test_mode = test_mode
        self.pipeline = Compose(pipeline)  # train_pipeline
        self.data_infos = self.load_annotations()    

self.pipeline = Compose(pipeline) 执行中,调用PIPELINES注册器,依次加载配置的对应type类型的处理类。

传送门:mmclassification项目阅读系列文章目录

源码阅读:

1、setup.py工程环境配置(一)

2、mmcls库组织结构说明(二)

3、registry类注册机制(三)

4、模型加载过程(四)

5、数据加载过程(五)

6、train_model执行过程(六)

你可能感兴趣的:(4.1,python,人工智能,pytorch,mmcls,数据记载)