mmclassification源码阅读(四) 模型加载过程

以训练过程为例,执行以下脚本。

python tools/train.py configs/cifar10/resnet50.py --resume-from=work_dirs/resnet50/epoch_20.pth

1、整体流程

首先加载配置,args为用户输入参数,cfg为配置文件配置参数。只有将参数统一合并至cfg管理。

args = parse_args()
cfg = Config.fromfile(args.config)
... # 参数合并,预处理过程

2、cfg参数解析

经过mmcv.Config解析后的变量以类对象形式存在,变量保存在保护成员变量_cfg_dict中,访问时直接采用cfg.key形式访问。其中:

1、args参数,放在最外层。如:

args = {'gpu_ids': range(0, 1), 'work_dir': './work_dirs/resnet50', ...}

解析至cfg后为:

cfg = {'gpu_ids': range(0, 1), 'work_dir': './work_dirs/resnet50', ...}

2、配置文件解析,configs/cifar10/resnet50.py解析过程:

# file: configs/cifar10/resnet50.py
_base_ = [
    '../_base_/models/resnet50_cifar.py',  # 1、模型结构配置
    '../_base_/datasets/cifar10.py',  # 2、数据集配置
    '../_base_/schedules/cifar10.py',   # 3、训练参数配置
        '../_base_/default_runtime.py'  # 4、运行时,模型保存,显卡等配置
]

解析__base__内文件配置,对文件内容依次展开,合并放至cfg最外层。如运行时default_runtime.py文件参数为:

# file: ../_base_/models/resnet50_cifar.py,  model settings
model = dict(  # 1字典
    type='ImageClassifier',  # 1类型 -注册器:CLASSIFIERS 
    backbone=dict(  # 1.1字典
        type='ResNet_CIFAR',  # 1.1类型  -注册器:BACKBONES 
        depth=50,  # **kwargs
        num_stages=4,# **kwargs
        out_indices=(3, ),# **kwargs
        style='pytorch'),# **kwargs
    neck=dict(  # 1.2字典
        type='GlobalAveragePooling'), # 1.2类型  -注册器:NECKS 
    head=dict(  # 1.3字典
        type='LinearClsHead', # 1.3类型  -注册器:HEADS 
        num_classes=10,  # **kwargs
        in_channels=2048,  # **kwargs
        loss=dict(  # 2..1字典
            type='CrossEntropyLoss', # 2.1类型  -注册器:LOSSES 
            loss_weight=1.0),  # **kwargs
        )
)

解析至cfg后为:

cfg = {..., 'model': xx, ...}

3、构建模型结构

执行代码:

model = build_classifier(cfg.model)

model字典为两级结构,第一层为任务类型type和四个基本结构(backbone、neck、head、loss),第二层为各个结构配置参数,字典型,包含type和额外参数**kwargs。每一个type都对应一种注册器,按层级结构递归调用注册器进行模型结构生成。注册器在mmcls/models/builder.py中创建,models下对应的包中进行注册。

# file: mmcls/models/builder.py
BACKBONES = Registry('backbone')
CLASSIFIERS = Registry('classifier')
HEADS = Registry('head')
NECKS = Registry('neck')
LOSSES = Registry('loss')

step1、build_classifier函数调用

build_classifier函数传参CLASSIFIERS注册器,最后调用mmcv.utils.registry.py中的build_from_cfg(cfg, registry, default_args)函数。

def build_classifier(cfg):
    return build(cfg, CLASSIFIERS)
    
def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)    

build_from_cfg函数实现:

# file: mmcv.utils.registry.py
def build_from_cfg(cfg, registry, default_args=None):
     ...     
    args = cfg.copy()
    obj_type = args.pop('type')  # 1、获取dict中type字段(注册器名称)'ImageClassifier'
    if is_str(obj_type):
        obj_cls = registry.get(obj_type)  # 2、获取注册器内对应的类
        ...
 return obj_cls(**args)  # 3、实例化类对象,type以外的参数作为关键字参数传递于类初始化-'ImageClassifier'类

实际加载过程:

调用CLASSIFIERS('classifier')注册器,加载其中的ImageClassifier类,type以外的参数作为关键字参数,实例化ImageClassifier类。

step2、ImageClassifier类实例化

ImageClassifier构造函数如下,接受关键字参数backbone、neck、head,loss未声明此忽略, pretrained=None。继承自BaseClassifier,首先调用基类BaseClassifier构造函数。之后依次加载self.backbone、self.neck、self.head、及权重初始化。

@CLASSIFIERS.register_module()
class ImageClassifier(BaseClassifier):

    def __init__(self, backbone, neck=None, head=None, pretrained=None):
        super(ImageClassifier, self).__init__()
        self.backbone = build_backbone(backbone)

        if neck is not None:
            self.neck = build_neck(neck)

        if head is not None:
            self.head = build_head(head)

        self.init_weights(pretrained=pretrained)

step3、self.backbone、self.neck、self.head类实例化

backbone构造,调用BACKBONES('backbone')注册器,加载其中ResNet_CIFAR类,type以外的参数作为关键字参数,实例化ResNet_CIFAR类。(继承自基类ResNet,实际上基类中实现初始化,子类中覆盖_make_stem_layer、forward函数)。

neck构造,调用NECKS('backneckbone')注册器,加载其中'GlobalAveragePooling'类,type以外的参数作为关键字参数,实例化GlobalAveragePooling类。

head构造,调用HEADS('head')注册器,加载其中'LinearClsHead'类,type以外的参数作为关键字参数,实例化LinearClsHead类。LinearClsHead类接收loss关键字参数。

# file: mmcls/models/heads/cls_head.py
@HEADS.register_module()
class ClsHead(BaseHead):
    def __init__(self,
                 loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
                 topk=(1, )):
        super(ClsHead, self).__init__()
        ...
        
        self.compute_loss = build_loss(loss)
        self.compute_accuracy = Accuracy(topk=self.topk)

loss构造,调用LOSSES('loss')注册器,加载其中'CrossEntropyLoss'类,type以外的参数作为关键字参数,实例化CrossEntropyLoss类。loss类实例化对象存在与self.head对象中。

传送门:mmclassification项目阅读系列文章目录

源码阅读:

1、setup.py工程环境配置(一)

2、mmcls库组织结构说明(二)

3、registry类注册机制(三)

4、模型加载过程(四)

5、数据加载过程(五)

6、train_model执行过程(六)

你可能感兴趣的:(4.1,python,pytorch,深度学习,mmcls,模型加载)