detecron2中的注册机制

文章目录

    • 一、为什么使用注册类
    • 二、注册类的实现

传送门 [Detectron2] 01-注册机制 Registry 实现

一、为什么使用注册类

以下转自知乎 https://zhuanlan.zhihu.com/p/93835858

对于detectron2这种,需要支持许多不同的模型的大型框架,理想情况下所有的模型的参数都希望写在配置文件中,那问题来了,如果我希望根据我的配置文件,决定我是需要用VGG还是用ResNet ,我要怎么写呢?

如果是我,我可能会写出这种可扩展性超级低的暴搓的代码:

if class_name == 'VGG':
    model = build_VGG(args)
elif class_name == 'ResNet':
    model = build_ResNet(args)

但是如果用了注册类,代码就是这样的:

class_name = 'VGG' # 'ResNet'
model = model_registry(class_name)(args)

可以看到代码的可扩展性变得非常强了

二、注册类的实现

注册类的源码

class Registry(object):
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name: str) -> None:
        """
        Args:
            name (str): the name of this registry
        """
        self._name: str = name
        self._obj_map: Dict[str, object] = {
     } #创建字典,str对应一个函数

    def _do_register(self, name: str, obj: object) -> None:
        assert (
            name not in self._obj_map
        ), "An object named '{}' was already registered in '{}' registry!".format(
            name, self._name
        )
        self._obj_map[name] = obj

    def register(self, obj: object = None) -> Optional[object]:
        """
        Register the given object under the the name `obj.__name__`.
        Can be used as either a decorator or not. See docstring of this class for usage.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class: object) -> object:
                name = func_or_class.__name__  # pyre-ignore
                self._do_register(name, func_or_class)
                return func_or_class

            return deco

        # used as a function call,回调函数可以获取函数名
        name = obj.__name__  # pyre-ignore
        print("name: ", name)
        self._do_register(name, obj) # 添加字典

    def get(self, name: str) -> object:
        ret = self._obj_map.get(name)
        if ret is None:
            raise KeyError(
                "No object named '{}' found in '{}' registry!".format(name, self._name)
            )
        return ret

    def __contains__(self, name: str) -> bool:
        return name in self._obj_map

如何使用

# 创建一个Registry对象
registry_machine = Registry('registry_machine')
@registry_machine.register()
def print_hello_world(word):
    print("he world")
    print(word)

@registry_machine.register()
def print_hi_world(word):
    print("hi world")
    print(word)

# 其中cfg为所调用的函数名/类名
cfg = "print_hello_world"
# 相当与调用print_hello_world('hello world')
registry_machine.get(cfg)('hello world')

cfg = "print_hi_world"
registry_machine.get(cfg)('hello world2')

利用@装饰器,将函数传入到registry中,在registry内部获取回调函数的名字,并创建字典,字典对应 函数名:函数,每次装饰一个函数,字典就会添加一次。
如此,就可以通过函数名来找到对应的函数

回调函数获取函数名,可以通过 __name__获取


def hello23():
    print("hello2")

def printHello(hello):
    # hello()
    print("world")
    print(hello.__name__)

printHello(hello23)

你可能感兴趣的:(深度框架,Pytorch)