mmdetecion 中类注册的实现(@x.register_module())

mmdetecion 中类注册的实现(@DETECTORS.register_module())

mmdetecion中,每一个类使用注册器(Registry)来进行注册,只要引入过该文件,就可以自动注册,在配置文件中使用(type=‘XXX’)的方式引用,例如

@DETECTORS.register_module()
class FasterRCNN(TwoStageDetector):
    """Implementation of `Faster R-CNN `_"""

    def __init__(self,
                 backbone,
                 rpn_head,
                 roi_head,
                 train_cfg,
                 test_cfg,
                 neck=None,
                 pretrained=None,
                 init_cfg=None):
        super(FasterRCNN, self).__init__(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)

上述是fasterRCNN的实现,@DETECTORS.register_module()是它的注册语法糖

python@语法糖:

首先来看一个例子

def funcA(fn):
	do_something()
	fn()
	do_someting()
	return "haha"

def funcB():
	return "B"

funcA和funcB是两个单独的函数,我们现在用@funcA来修饰B

def funcA(fn):
	do_something()
	fn()
	do_someting()
	return "haha"
	
@funcA
def funcB():
	return "B"

经过funcA的修饰,此时的B与以下代码等价:

def funcA(fn):
	do_something()
	fn()
	do_someting()
	return "haha"

def funcB():
	return "B"

funcB = funcA(funcB)

可以看到,@的作用就是修饰B,将B传入A中,将A的返回值重新赋值给B
在上述代码中,我们的funcB本身是一个返回 “B” 的函数,经过A的修饰,funcB变成了传入A之后的返回值,也就是 “haha”,函数类型变成了一个字符串类型

当然,上面只是一个举例,正常情况下我们都希望只是修饰以下B,不希望B的本身发生变化,以致于类型都不一样,我们可以将A的返回值改为一个函数

def funcA(fn):
    do_something()
    fn()
    return fn

@funcA
def funcB():
	return "B"

这样做,我们就可以在注册的时候,也就是引用funcB的时候就可以执行funcA的代码,同时不改变funcB的内容

同样,当我们修饰一个类的时候,也是同样的

def funcA(cls):
    do_something()
    return cls

@funcA
class classB():
	def __init__(self):
			return "B"

在这种方式的修饰下,classB本身不会发生变化,会在编译之后执行do_something()
注册就是在这里实现的

我们可以看原生实现了,原生代码为

@DETECTORS.register_module()

之前我们都是用函数名来修饰,在这里使用了函数调用来修饰,实现方式为:

def register_module(self, name=None, force=False, module=None):
        
        #
        省略掉了具体的注册
        #
		register()
		
        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register

也就是说,@DETECTORS.register_module() 实际上返回了一个函数,_register才是实际上修饰的代码,在执行完成注册之后返回原本的类

注册

真正的注册代码结构就比较简单了,去掉一些assert和判断以外,剩下的其实就一行

self._module_dict[name] = module_class

维护一个module_dict的列表,将class name作为key存储下来

build

以neck的配置文件为例:

neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),

在build进行时候,只需要取出type字段,将剩余的字段传入class中即可

obj_type = args.pop('type') #取出type字段
if isinstance(obj_type, str):
    obj_cls = registry.get(obj_type)
    if obj_cls is None: #判断在不在注册列表里
        raise KeyError(
            f'{obj_type} is not in the {registry.name} registry')
elif inspect.isclass(obj_type):
    obj_cls = obj_type
else:
    raise TypeError(
        f'type must be a str or valid type, but got {type(obj_type)}')
try:
    return obj_cls(**args) #使用类

你可能感兴趣的:(mmdetection,深度学习,python,机器学习)