mmdetecion中,每一个类使用注册器(Registry)来进行注册,只要引入过该文件,就可以自动注册,在配置文件中使用(type=‘XXX’)的方式引用,例如
@DETECTORS.register_module()
class FasterRCNN(TwoStageDetector):
"""Implementation of `Faster R-CNN `_"""
def __init__(self,
backbone,
rpn_head,
roi_head,
train_cfg,
test_cfg,
neck=None,
pretrained=None,
init_cfg=None):
super(FasterRCNN, self).__init__(
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
roi_head=roi_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained,
init_cfg=init_cfg)
上述是fasterRCNN的实现,@DETECTORS.register_module()是它的注册语法糖
首先来看一个例子
def funcA(fn):
do_something()
fn()
do_someting()
return "haha"
def funcB():
return "B"
funcA和funcB是两个单独的函数,我们现在用@funcA来修饰B
def funcA(fn):
do_something()
fn()
do_someting()
return "haha"
@funcA
def funcB():
return "B"
经过funcA的修饰,此时的B与以下代码等价:
def funcA(fn):
do_something()
fn()
do_someting()
return "haha"
def funcB():
return "B"
funcB = funcA(funcB)
可以看到,@的作用就是修饰B,将B传入A中,将A的返回值重新赋值给B
在上述代码中,我们的funcB本身是一个返回 “B” 的函数,经过A的修饰,funcB变成了传入A之后的返回值,也就是 “haha”,函数类型变成了一个字符串类型
当然,上面只是一个举例,正常情况下我们都希望只是修饰以下B,不希望B的本身发生变化,以致于类型都不一样,我们可以将A的返回值改为一个函数
def funcA(fn):
do_something()
fn()
return fn
@funcA
def funcB():
return "B"
这样做,我们就可以在注册的时候,也就是引用funcB的时候就可以执行funcA的代码,同时不改变funcB的内容
同样,当我们修饰一个类的时候,也是同样的
def funcA(cls):
do_something()
return cls
@funcA
class classB():
def __init__(self):
return "B"
在这种方式的修饰下,classB本身不会发生变化,会在编译之后执行do_something()
注册就是在这里实现的
我们可以看原生实现了,原生代码为
@DETECTORS.register_module()
之前我们都是用函数名来修饰,在这里使用了函数调用来修饰,实现方式为:
def register_module(self, name=None, force=False, module=None):
#
省略掉了具体的注册
#
register()
# use it as a decorator: @x.register_module()
def _register(cls):
self._register_module(
module_class=cls, module_name=name, force=force)
return cls
return _register
也就是说,@DETECTORS.register_module()
实际上返回了一个函数,_register
才是实际上修饰的代码,在执行完成注册之后返回原本的类
真正的注册代码结构就比较简单了,去掉一些assert和判断以外,剩下的其实就一行
self._module_dict[name] = module_class
维护一个module_dict的列表,将class name作为key存储下来
以neck的配置文件为例:
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
在build进行时候,只需要取出type字段,将剩余的字段传入class中即可
obj_type = args.pop('type') #取出type字段
if isinstance(obj_type, str):
obj_cls = registry.get(obj_type)
if obj_cls is None: #判断在不在注册列表里
raise KeyError(
f'{obj_type} is not in the {registry.name} registry')
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')
try:
return obj_cls(**args) #使用类