mmdetection组件构成与注册表分析

mmdetection 使用模块化设计,将一般的目标检测算法分成了几个不同的模块,在使用时只需要在配置文件中声明各个模块使用的组件名称和参数,就可以像搭建积木一样搭建一个完整的目标检测模型;

基本组件

mmdetection的组件大多数以类的形式定义:

  • BACKBONES 对应目标检测模型的主干网络,用以对图片进行特征抽取.如常用的Resnet,ResNeXt,HRNet等.
  • NECKS 对主干网络产生的特征图做一些特定的处理,最常见的就是fpn多尺度抽取信息.现有(FPN,BFP,HRFPN等)
  • Heads 目标检测的头部,包含了目标检测的主要算法逻辑,包括bbox的产生,回归target的计算,loss的计算等
  • LOSS 损失函数的定义
  • DETECTOR 前面所介绍的组件搭建而成的一个整体,通过加载detector来运行整体算法
  • PIPELINES 数据增强管道类.定义了数据预处理和后处理部分

mmdetection中提供了类似注册表的实现方式,对各个组件进行注册和使用:
首先我们来看Registry类的定义:
mmdet/utils/registry.py

class Registry(object):
    #初始化name是什么组件,组件里面是一个dict,保存name跟它的具体类
    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    def __repr__(self):
        format_str = self.__class__.__name__ + '(name={}, items={})'.format(
            self._name, list(self._module_dict.keys()))
        return format_str

    @property
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def get(self, key):
        return self._module_dict.get(key, None)

    #把组件类与类名注册到注册表中,方便从config文件构建类
    def _register_module(self, module_class):
        """Register a module.

        Args:
            module (:obj:`nn.Module`): Module to be registered.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, but got {}'.format(
                type(module_class)))
        module_name = module_class.__name__
        if module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class

    def register_module(self, cls):
        self._register_module(cls)
        return cls

我们看到Registry类其实底层保存一个dict,用于保存组件名字跟具体的类.方便从注册表中找到相应的类进行初始化.接着,定义了全局注册表:
mmdet/models/registry.py

BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')

我们来看,注册表如何使用:
如果我们自定义了一个resnet的backbone类,我们将这样使用Registry类的register_module装饰函数,将resnet注册到BACKBONES注册表中;

@BACKBONES.register_module
class ResNet(nn.Module):

那么我们该如何从config中构建起一个类呢:

#mmdet/models/builder.py
def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)

def build_backbone(cfg):
    return build(cfg, BACKBONES)

##mmdet/utils/registry.py
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        obj: The constructed object.
    """
    #type即注册表中类名字,代表了要从注册表中根据type的name来获得类
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    args = cfg.copy()
    obj_type = args.pop('type')
    if mmcv.is_str(obj_type):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    ##进行类的实例化,并传入config中的参数
    return obj_cls(**args)

build_from_cfg函数的作用是,根据config文件中的type与传入的注册表来获取需要实例化的具体类,然后再将config中的参数传入类初始化函数中,得到一个实例化的组件类.

resnet为例,整体流程如下所示:

  • (1)resnet类编写完成后,用@BACKBONES.register_module装饰器将自身注册到BACKBONES注册表中.

  • (2)在config中定义backbone,并指明了具体参数

#config/faster_rcnn_r50_fpn_1x.py
backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        style='pytorch')
  • (3)通过build_from_cfg()函数,传入的分别是backbone这个dict和BACKBONES注册表类
  • (4)通过'type'为ResNet找到resnet的类,并初始化参数depth,num_stages,out_indices.

mmdetection这样通过注册表的方式实现了数据与实现的分离;能更好地对组件进行抽象.

你可能感兴趣的:(mmdetection组件构成与注册表分析)