Python 注册器的使用及OpenMMDetection中的实例

Python 注册器

对于稍大一些的Python项目,很可能用到使用注册器的时候,因为创建某个模块的时候可能需要很多的参数进行输入操作,比如OpenMMlab中对于某个模型的创建,使用的就是注册器的机制。从模型名就能创建对应的模型当然可以使用字典的方法,然后读取字典中内容进行解析。但是问题是当需要创建新的模型的时候,就必须对解析代码进行相应的手动维护。

如果使用注册器机制,那么就可以只维护需要注册的模块的路径就可以了。

最简单的话来说,注册器其实就是一个负责记录着各种函数名字和对应函数对象的字典(其本身就会定义一个_dict,可以set_item和get_item),当我们通过注册器进行实例创建的时候,只需要往里面传入对应函数名字和参数。而当有新的对象的定义的需要添加到注册器,也只需要让注册器找到对应代码就可以,不需要进行任何解析代码的手动改动。

注册器的优点在于:

  • 利于扩展代码,当某段代码需要正佳新函数和类的时候,可以复用之前的逻辑
  • 创建实例时通过读取配置文件读取对应的参数进行创建。

简单例子:

class RegisterMachine(object):
    def __init__(self, name):
        # name of register
        self._name = name
        self._name_method_map = dict()
    
    def register(self, obj=None):
        # obj == None for function call register
        # otherwise for decorator way
        if obj != None:
            name = obj.__name__
            self._name_method_map[name] = obj
        
        else:
            def wrapper(func):
                name = func.__name__
                self._name_method_map[name] = func
                return func
            return wrapper

    def get(self, name):
        return self._name_method_map[name]

if __name__ == "__main__":
    register_obj = RegisterMachine("register")
    # decorate method
    @register_obj.register()
    def say_hello_with(name):
        print("Hello, {person}!".format(person=name))

    def say_hi_with(name):
        print("Hi, {person}!".format(person=name))

    register_obj.get("say_hello_with")("Peter")
    # function call method
    register_obj.register(say_hi_with)
    register_obj.get("say_hi_with")("John")

注意:@register_obj.register()是python修饰器用法,其等价于@register_obj.register(say_hello_with(name))。

从上面例子可以看出,注册器RegisterMachine是一个类,其中包含里一个字典,保存的就是各种函数对象,通过RegisterMachine.register(self, func)进行字典中对象的添加,添加后的对象可以通过RegisterMachine.get(self, name)通过函数名字进行访问。这样子就可以实现新的代码进行添加的时候,只需要将新代码通过RegisterMachine.register(self, func)就可以添加,比如代码中的say_hi_with函数。

当我们需要使用注册器的时候,既可以将其作为修饰器进行注册,也可以直接显式调用进行注册。

同时,我们可以对RegisterMachine的函数进行重载或者添加来实现更复杂的功能。

接下来看一个来自OpenMMDetection的注册器定义和使用:在OpenMMLab中使用注册器机制来对不同模型的创建(从backbone到head等都使用了注册器来进行)。

class Registry:
    """A registry to map strings to classes.

    Args:
        name (str): Registry name.
    """

    def __init__(self, name):
        self._name = name
#记录用的dict
        self._module_dict = dict()

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @property#@property修饰是保证只读属性而添加的
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class.
        """
        return self._module_dict.get(key, None)

#注册函数本体
    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        else:
            assert is_seq_of(
                module_name,
                str), ('module_name should be either of None, an '
                       f'instance of str or list, but got {type(module_name)}')
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is None:
#partial作用是冻结self.deprecated_register_module的force参数为force的值
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

#注册时候调用入口
    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)

        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a walkaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
#name可以是None或者str对象,当出现重名且force为True的时候将会强制注册新的对象
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)
#当显式调用的时候,会转到_register_module进行注册
        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # raise the error ahead of time
        if not (name is None or isinstance(name, str)):
            raise TypeError(f'name must be a str, but got {type(name)}')
#当将注册器使用修饰器方式进行使用的时候,
        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register

看完以上内容之后,可见注册器机制本质就是利用了python的语言特点(修饰器、函数其实是类的实例化对象、万物皆为对象的特点),对dict进行花式操作的过程。

你可能感兴趣的:(编程技巧,pytorch,python)