上一篇博客对MMDetection中的配置文件进行了介绍,其中提到,我们在配置文件中配置到模型、数据集、训练策略等后,通过Config类可以将配置文件中的参数信息以字典的形式进行管理,然后MMDetection框架就会对其自动进行解析,帮助我们构建整个算法流程。MMDetection使用注册机制来实现从配置参数到算法模块的构建。 本篇博客将从源码出发,对MMCV中的注册机制进行详细介绍。
- 官方文档 - MMCV
- 官方知乎 - MMCV 核心组件分析(五): Registry
注册机制是MMCV中非常重要的一个概念,在MMDetection中如果你想要增加自己的算法模块或流程,都需要通过注册机制来实现。
介绍注册机制之前先介绍一下Registry类。
MMCV使用注册器(Registry)来管理具有相似功能的不同模块,比如ResNet、FPN、RoIHead都属于模型结构,SGD、Adam都属于优化器。注册器内部其实是在维护一个全局的查询表,key是字符串,value是类。
简单来说,注册器可以看做字符串到类(Class)的映射。借助注册器,用户可以通过字符串查询到对应的类,并实例化该类。有了这个认知后,我再看Registry类的源码就很容易理解了,先看下构造函数,其功能主要是初始化注册器的名字、实例化函数,并初始化一张字典类型的查询表_module_dict
:
from mmcv.utils import Registry
class Registry:
# 构造函数
def __init__(self, name, build_func=None, parent=None, scope=None):
"""
name (str): 注册器的名字
build_func(func): 从注册器构建实例的函数句柄
parent (Registry): 父类注册器
scope (str): 注册器的域名
"""
self._name = name
# 使用module_dict管理字符串到类的映射
self._module_dict = dict()
self._children = dict()
# 如果scope未指定, 默认使用类定义位置所在的包名, 比如mmdet, mmseg
self._scope = self.infer_scope() if scope is None else scope
# build_func按照如下优先级初始化:
# 1. build_func: 优先使用指定的函数
# 2. parent.build_func: 其次使用父类的build_func
# 3. build_from_cfg: 默认从config dict中实例化对象
if build_func is None:
if parent is not None:
self.build_func = parent.build_func
else:
self.build_func = build_from_cfg
else:
self.build_func = build_func
# 设置父类-子类的从属关系
if parent is not None:
assert isinstance(parent, Registry)
parent._add_children(self)
self.parent = parent
else:
self.parent = None
比如说,我们现在想要使用注册器来管理我们的模型,首先初始化一个Registry实例MODELS
,然后调用Registry类的register_module()
方法完成ResNet和VGG类的注册,可以看到最后MODELS
的打印结果中包含了这两个类的信息(打印信息中items对应的其实就是self._module_dict
),表示注册成功。为了代码简洁,这里推荐使用python的函数装饰器@
实现register_module()
的调用。然后就可以通过build()
函数来实例化我们的模型了。
# 实例化一个注册器用来管理模型
MODELS = Registry('myModels')
# 方式1: 在类的创建过程中, 使用函数装饰器进行注册(推荐)
@MODELS.register_module()
class ResNet(object):
def __init__(self, depth):
self.depth = depth
print('Initialize ResNet{}'.format(depth))
# 方式2: 完成类的创建后, 再显式调用register_module进行注册(不推荐)
class FPN(object):
def __init__(self, in_channel):
self.in_channel= in_channel
print('Initialize FPN{}'.format(in_channel))
MODELS.register_module(name='FPN', module=FPN)
print(MODELS)
""" 打印结果为:
Registry(name=myModels, items={'ResNet': , 'FPN': })
"""
# 配置参数, 一般cfg从配置文件中获取
backbone_cfg = dict(type='ResNet', depth=101)
neck_cfg = dict(type='FPN', in_channel=256)
# 实例化模型(将配置参数传给模型的构造函数), 得到实例化对象
my_backbone = MODELS.build(backbone_cfg)
my_neck = MODELS.build(neck_cfg)
print(my_backbone, my_neck)
""" 打印结果为:
Initialize ResNet101
Initialize FPN256
<__main__.ResNet object at 0x000001E68E99E198> <__main__.FPN object at 0x000001E695044B38>
"""
在实例化一个Registry对象后,类的注册和实例化分别通过register_module
和build
函数完成的,下面来看看这两个函数的源码。
register_module()
内部实际上调用的是self._register_module()
函数,功能也很简单,就是将当前要注册的模块名称和模块类型以键值对key->value的形式保存到_module_dict
查询表中。
def _register_module(self, module_class, module_name=None, force=False):
"""
module_class (class): 要注册的模块类型
module_name (str): 要注册的模块名称
force (bool): 是否强制注册
"""
if not inspect.isclass(module_class):
raise TypeError('module must be a class, '
f'but got {type(module_class)}')
# 如果未指定模块名称则使用默认名称
if module_name is None:
module_name = module_class.__name__
# module_name为list形式, 从而支持在nn.Sequentail中构建pytorch模块
if isinstance(module_name, str):
module_name = [module_name]
for name in module_name:
# 如果force=False, 则不允许注册相同名称的模块
# 如果force=True, 则用后一次的注册覆盖前一次
if not force and name in self._module_dict:
raise KeyError(f'{name} is already registered '
f'in {self.name}')
# 将当前注册的模块加入到查询表中
self._module_dict[name] = module_class
build
函数是指向build_func()
函数的(见Registry的构造函数),可以在模块注册的时候由用户手动指定,但由于模块一般都是用函数装饰器的方式来注册,所以build_func()
实际上调用的都是build_from_cfg()
函数。build_from_cfg()
根据配置参数中的type值找到对应的模块类型obj_cls
,然后使用cfg和default_args中的参数实例化对应的模块,并返回实例化对象给上级的build()
函数调用。
def build_from_cfg(cfg, registry, default_args=None):
"""
cfg (dict): 配置参数信息
registry (Registry): 注册器
"""
# cfg类型校验, 必须为字典类型
if not isinstance(cfg, dict):
raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
# cfg中必须要有type字段
if 'type' not in cfg:
if default_args is None or 'type' not in default_args:
raise KeyError(
'`cfg` or `default_args` must contain the key "type", '
f'but got {cfg}\n{default_args}')
# registry类型校验, 必须为Registry类型
if not isinstance(registry, Registry):
raise TypeError('registry must be an mmcv.Registry object, '
f'but got {type(registry)}')
# default_args以字典的形式传入
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError('default_args must be a dict or None, '
f'but got {type(default_args)}')
args = cfg.copy()
# 将cfg以外的外部传入参数也加入到args中
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
# 获取模块名称
obj_type = args.pop('type')
if isinstance(obj_type, str):
# 根据模块名称获取到模块类型
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(
f'{obj_type} is not in the {registry.name} registry')
elif inspect.isclass(obj_type):
# type值是模块本身
obj_cls = obj_type
else:
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')
try:
return obj_cls(**args)
except Exception as e:
# Normal TypeError does not print class name.
raise type(e)(f'{obj_cls.__name__}: {e}')
MMCV使用注册器来管理具有相似功能的不同模块,一个注册器内部会维护一个查询表,使用该注册器注册的模块都会以键值对的形式保存在这个查询表中,注册器还提供实例化方法,根据模块名称返回对应的实例化对象。
MMDetection内部已经构建了许多常用的注册器,并实现了对应的接口函数,比如DETECTORS对应build_detector()
,DATASETS对应build_dataset()
,无论是什么样的xxx_build(),最终都是调用Registry.build()
函数。我们绝大多数时候只需要使用现成的注册器即可。
# MMDetection中的Registry
MODELS = Registry('models', parent=MMCV_MODELS) # 从MMCV继承得到
BACKBONES = MODELS
NECKS = MODELS
ROI_EXTRACTORS = MODELS
SHARED_HEADS = MODELS
HEADS = MODELS
LOSSES = MODELS
DETECTORS = MODELS
DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')