之前写过对Registry,Runner,Hook的介绍,但是Registry这个有读者反映没特别理解,我重新回顾了下发现之前的文章写得有点冗长,加上自己后续还有一些其他系列干货会分享出来,因此今天就特地重新整理下,写一篇新的文章出来,力保大家看了醍醐灌顶。
工作日趋繁忙,上一篇讲HOOK
的文章获得了大家的赞同受宠若惊,虽然才疏学浅,今天就接着上次的内容,讲一下MMDetection
中注册器Registry
,希望能帮助到大家理解这个框架。
首先,要明确注册器的使用目的就是为了在算法训练、调参中通过直接更改配置文件。因为注册器其实就是一个字典,完成了字符串->类的映射,它是一个算法训练框架实现模块化手段,不会在代码中引入hardcode
。比如在MMDetection
中有下面这些定义的组件。
from mmcv.utils import Registry, build_from_cfg
from torch import nn
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')
其中,模型被拆解成了Backbones
、Necks
、Roi_extractors
、Shared_Heads
、Heads
、Losses
、Detectors
几个部分,当时如果我们需要添加更多的模块也是可以的,只需要实例化我们各自的注册器就行。
mmcv
中的registry
代码很长,因为开源代码需要做各种类型检查,还需要提供很多种接口。我自己手写了一个最迷你的注册器,并且用实际的代码来带大家看下注册器是怎么做成组件化的。
class Registry:
def __init__(self, name):
self.name = name
self.module_dict = {} # 记录一个名字->类的字典
def _register_module(self, module_class):
module_name = module_class.__name__
self.module_dict[module_name] = module_class
def register_module(self):
def _register(cls):
self._register_module(module_class=cls)
return cls
return _register
有个注册器这个类,那么我们就可以将其组件化。比如我组件化出Backbone
和Head
这两个东西。就可以这么做:
# 定义一个根据参数实例化对象的方法,这含义就是说配置文件里的Type字段需要和你实际的类名对应上。然后实例化出一个对象。
def build_from_cfg(cfg, registry):
obj_type = cfg.pop('type')
obj_cls = registry.module_dict[obj_type]
return obj_cls(**cfg)
# 生成BACKBONE和HEAD组件
BACKBONE = Registry('backbone')
HEAD = Registry('head')
def build_backbone(cfg):
return build_from_cfg(cfg, BACKBONE)
def build_head(cfg):
return build_from_cfg(cfg, HEAD)
通过上述代码,我们就有了两个组件,backbone
和head
,并且有了对应的build
方法去实例化相应的对象。比如我们新增backbone
和head
的时候,可以这么操作。
@BACKBONE.register_module()
class ResNet(object):
def __init__(self, new_name):
self.new_name = new_name
@HEAD.register_module()
class NewHead(object):
def __init__(self, new_name):
self.new_name = new_name
通过注册器来实现组件化在MMLAB
代码里是最常见的,当然原始的代码不仅仅是实现了注册器的方式,感兴趣的同学可以去琢磨下源码,但是到这一步基本上其精髓已经写出来了。可以看下这个组件里是啥东西
for k, v in BACKBONE.module_dict.items():
print(k, v)
# 输出:ResNet
for k, v in HEAD.module_dict.items():
print(k, v)
# 输出:NewHead
可以看到注册器中已经有了这个映射了,那么我们就可以实现通过配置文件来构件模型。
from utils.registry import build_backbone, build_head
cfg = dict(
backbone=dict(type='ResNet', new_name='test_backbone'),
head=dict(type='NewHead', new_name='test_newhead')
)
class SampleModel():
def __init__(self, backbone, head):
self.backbone = build_backbone(backbone)
self.head = build_head(head)
def run(self):
print('Backbone的输出:',self.backbone.new_name)
print('Head的输出', self.head.new_name)
cook = SampleModel(**cfg)
cook.run()
# 输出结果为:
Backbone的输出: test_backbone
Head的输出 test_newhead
我想这个应该是把注册器的原理,逻辑,应用方式是解释得很到位了吧,如果还有不懂的可以评论、私信我~
注册器其实就是一个字典
字典谁都懂,但是怎么用到实际工程项目中还是需要一些功底的。这是我自己实现的一个demo
版的注册器,只是给大家理解下它的核心逻辑和用法而已,感兴趣的人可以看下我的这篇文章或者去扒拉下mmcv
的源码。
如果这篇文章对你有帮助,麻烦点赞、收藏、关注一波,也可以关注我的微信公众号:CV伍六七,不定期分享工作总结,论文解读,考研经历,学习成长等,全网同名。
CSDN博客:CV伍六七
知乎:CV伍六七
掘金:CV伍六七