理解MMLAB开源代码就从这开始:注册器Registry

之前写过对Registry,Runner,Hook的介绍,但是Registry这个有读者反映没特别理解,我重新回顾了下发现之前的文章写得有点冗长,加上自己后续还有一些其他系列干货会分享出来,因此今天就特地重新整理下,写一篇新的文章出来,力保大家看了醍醐灌顶。

工作日趋繁忙,上一篇讲HOOK的文章获得了大家的赞同受宠若惊,虽然才疏学浅,今天就接着上次的内容,讲一下MMDetection中注册器Registry,希望能帮助到大家理解这个框架。

首先,要明确注册器的使用目的就是为了在算法训练、调参中通过直接更改配置文件。因为注册器其实就是一个字典,完成了字符串->类的映射,它是一个算法训练框架实现模块化手段,不会在代码中引入hardcode。比如在MMDetection中有下面这些定义的组件。

from mmcv.utils import Registry, build_from_cfg
from torch import nn

BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')

其中,模型被拆解成了BackbonesNecksRoi_extractorsShared_HeadsHeadsLossesDetectors几个部分,当时如果我们需要添加更多的模块也是可以的,只需要实例化我们各自的注册器就行。

mmcv中的registry代码很长,因为开源代码需要做各种类型检查,还需要提供很多种接口。我自己手写了一个最迷你的注册器,并且用实际的代码来带大家看下注册器是怎么做成组件化的。

class Registry:
    def __init__(self, name):
        self.name = name
        self.module_dict = {} # 记录一个名字->类的字典

    def _register_module(self, module_class):
        module_name = module_class.__name__
        self.module_dict[module_name] = module_class

    def register_module(self):
        def _register(cls):
            self._register_module(module_class=cls)
            return cls
        return _register

有个注册器这个类,那么我们就可以将其组件化。比如我组件化出BackboneHead这两个东西。就可以这么做:

# 定义一个根据参数实例化对象的方法,这含义就是说配置文件里的Type字段需要和你实际的类名对应上。然后实例化出一个对象。
def build_from_cfg(cfg, registry):
    obj_type = cfg.pop('type')
    obj_cls = registry.module_dict[obj_type]
    return obj_cls(**cfg)

# 生成BACKBONE和HEAD组件
BACKBONE = Registry('backbone')
HEAD = Registry('head')

def build_backbone(cfg):
    return build_from_cfg(cfg, BACKBONE)

def build_head(cfg):
    return build_from_cfg(cfg, HEAD)

通过上述代码,我们就有了两个组件,backbonehead,并且有了对应的build方法去实例化相应的对象。比如我们新增backbonehead的时候,可以这么操作。

@BACKBONE.register_module()
class ResNet(object):
    def __init__(self, new_name):
        self.new_name = new_name

@HEAD.register_module()
class NewHead(object):
    def __init__(self, new_name):
        self.new_name = new_name

通过注册器来实现组件化在MMLAB代码里是最常见的,当然原始的代码不仅仅是实现了注册器的方式,感兴趣的同学可以去琢磨下源码,但是到这一步基本上其精髓已经写出来了。可以看下这个组件里是啥东西

for k, v in BACKBONE.module_dict.items():
    print(k, v)
    # 输出:ResNet 
    
for k, v in HEAD.module_dict.items():
    print(k, v)
    # 输出:NewHead 

可以看到注册器中已经有了这个映射了,那么我们就可以实现通过配置文件来构件模型。

from utils.registry import build_backbone, build_head

cfg = dict(
    backbone=dict(type='ResNet', new_name='test_backbone'),
    head=dict(type='NewHead', new_name='test_newhead')
)

class SampleModel():
    def __init__(self, backbone, head):
        self.backbone = build_backbone(backbone)
        self.head = build_head(head)

    def run(self):
        print('Backbone的输出:',self.backbone.new_name)
        print('Head的输出', self.head.new_name)

cook = SampleModel(**cfg)
cook.run()
# 输出结果为:
Backbone的输出: test_backbone
Head的输出 test_newhead

我想这个应该是把注册器的原理,逻辑,应用方式是解释得很到位了吧,如果还有不懂的可以评论、私信我~

注册器其实就是一个字典

字典谁都懂,但是怎么用到实际工程项目中还是需要一些功底的。这是我自己实现的一个demo版的注册器,只是给大家理解下它的核心逻辑和用法而已,感兴趣的人可以看下我的这篇文章或者去扒拉下mmcv的源码。

如果这篇文章对你有帮助,麻烦点赞、收藏、关注一波,也可以关注我的微信公众号:CV伍六七,不定期分享工作总结,论文解读,考研经历,学习成长等,全网同名。

CSDN博客:CV伍六七

知乎:CV伍六七

掘金:CV伍六七

我的微信公众号:CV伍六七
理解MMLAB开源代码就从这开始:注册器Registry_第1张图片

你可能感兴趣的:(深度学习,计算机视觉,人工智能)