最近,接触到了MMCV框架,发现MMCV框架为了方便更换backbone,优化器,学习策略等功能模块,引入了一种注册机制(Registry)的方法,可以有效的管理深度学习框架中的内容。同时,也方便用户通过外部接口灵活的搭配来组建自己的网络。为了深入了解注册机制的原理,我对此进行了学习和整理,并分享给读者朋友。
本文主要分为四个部分:
第一部分是简单介绍下python的闭包原理,闭包是装饰器的基础。
第二部分是简单介绍下python的装饰器原理,注册机制是装饰器的应用场景。
第三部分是简单介绍注册机制,并附上python代码示例。
第四部分是简单介绍MMCV框架注册机制代码,并附上相关代码注释。
由于我学的浅显,如果理解和您有偏差,则望各位大佬及时指出。最后如果您觉得对您有帮助的话,可以给小弟一个赞。 ⌣ ¨ \ddot\smile ⌣¨
参考资料如下所示:
1.Python 函数装饰器|菜鸟教程
2.理解闭包的概念,作者:alpha_panda。
3.Python3 命名空间和作用域
4.mmsegment定制模型(八),作者:alex1801。
5. Registry注册机制,作者:~HardBoy~。
6.理解 Python 装饰器看这一篇就够了,作者:刘志军。
在开始学习之前,需要提前了解2个概念,如下所示。
(1)对于python而言,一切皆是对象,包括函数自己也是对象。因此,函数是可以赋值给变量,并通过变量进行函数调用的。(函数也是可以赋值给函数的)
def base():
print("这是基础函数")
def derived():
print("这是派生函数")
# 以变量的形式调用
Var = base
Var()
# 另外,函数也是可以赋值给函数的,以函数的形式调用
derived = base
derived()
def base():
print("这是基础函数")
def derived(func):
func()
print("这是派生函数")
# 以函数的形式调用
derived(base)
def outsideFunc():
print("这是外部函数")
# 在派生函数内定义基本函数
def insideFunc():
print("这是内部函数")
return insideFunc
Var = outsideFunc()
print(Var)
Var()
本文关于闭包的介绍相对简单,具体深入细节可以参考博客:理解闭包的概念,解释的非常详细,还列出了闭包函数容易错误的点,本文重点是注册函数,因此仅是简单介绍。
概念: 一个函数可以引用作用域外的变量,同时也可以在函数内部进行使用的函数,称之为闭包函数。而这个作用域外的变量,称之为自由变量。这个自由变量不同于C++的static修饰的静态变量,也不同于全局变量,它是和闭包函数进行绑定的。
注意两点:
1.闭包函数与闭包函数之间,自由变量不会相互干扰。
2.闭包函数的自由变量会传到下一次闭包函数中。
def outside():
# 此处obj为自由变量
obj = []
# inside就被称为闭包函数
def inside(name):
obj.append(name)
print(obj)
return inside
Lis = outside()
Lis("1")
# 对闭包函数自由变量的修改会传到下一次闭包函数中。
Lis("2")
Lis("3")
print("-"*20)
# 闭包函数和闭包函数之间的自由变量不会相互影响。
Lis2 = outside()
Lis2("4")
Lis2("5")
Lis2("6")
从上例中可以看出,obj是针对inside闭包函数的自由变量。每次调用inside闭包函数时,都会给obj添加一个新值,并且会传到下一次inside闭包函数的调用中。
此处,obj自由变量相对于inside闭包函数的作用相当于一个“全局变量”,只不过作用域仅是针对于包含自由变量和闭包函数的外部环境而言,上例中指的是outside的作用域。python将这种作用域,称为闭包函数外的作用域(Enclosing)。
补充知识:
python有四种作用域分别是:
局部作用域(Local):最内层,包含局部变量,比如一个函数/方法内部。
闭包函数外的作用域(Enclosing):包含了非局部(non-local)也非全局(non-global)的变量。比如两个嵌套函数,一个函数(或类) A 里面又包含了一个函数 B ,那么对于 B 中的名称来说 A 中的作用域就为 nonlocal。
全局作用域(Global):当前脚本的最外层,比如当前模块的全局变量。
内建作用域(Built-in):包含了内建的变量/关键字等,最后被搜索。
规则顺序如下图所示:
如上图所示,python在局部作用域找不到对应的函数变量,便会去局部外的局部找(例如闭包),再找不到就会去全局作用域找,再者去内建作用域中找。此处更详尽的内容可参考博客Python3 命名空间和作用域。
了解闭包原理后,再来看装饰器的本质。装饰器其实就是通过闭包原理来封装函数,并返回函数的高级函数。封装的目的是为了在保持原有函数功能的基础上,扩充额外的功能。具体可以参考如下示例。
def add(arr1, arr2):
return arr1 + arr2
def decorator(func):
def wrapper(arr1, arr2):
print("实现新的功能,例如数据加1的功能")
arr1 += 1
arr2 += 1
return func(arr1, arr2)
return wrapper
print(add(2, 3))
add = decorator(add)
print(add(2, 3))
从代码中可以看到,我们希望在保持数据相加功能的基础上,再添加一个参数自加一的新功能。如果按照往常,我们需要在之前的函数内添加新的内容,这样的话就会破坏原有的结构。因此,在既需要保留原有函数结构的基础上,额外的添加新功能,就可以通过装饰器来进行实现。上述decorator就是一个装饰器函数,下面add = decorator(add)就是他的调用方式,不过在python中,可以使用@语法来简化调用语句。如下所示:
def decorator(func):
def wrapper(arr1, arr2):
print("实现新的功能,例如数据加1的功能")
arr1 += 1
arr2 += 1
return func(arr1, arr2)
return wrapper
@decorator
# @decorator就等价于add = decorator(add)
def add(arr1, arr2):
return arr1 + arr2
print(add(2, 3))
可以看到@decorator就等价于add = decorator(add),”@“会把它修饰的函数作为参数传给装饰器函数
补充一个小点:
在经过装饰器函数后,其实add函数的内容已经变成wrapper函数了,所以它的其他相关描述内容也变成了wrapper的描述内容,例如__name__。因此,如果想要不修改原有描述内容的话,可以借助functools函数包中的装饰函数@wraps(func)来实现。这个函数可以复制原有函数的描述内容。
# 不加@wraps的情况下
def decorator(func):
def wrapper(arr1, arr2):
print("实现新的功能,例如数据加1的功能")
arr1 += 1
arr2 += 1
return func(arr1, arr2)
return wrapper
def add(arr1, arr2):
return arr1 + arr2
#原本add函数
print(add.__name__)
#通过装饰器修饰后
add = decorator(add)
print(add.__name__)
from functools import wraps
# 添加@wraps的情况下
def decorator(func):
@wraps(func)
#@wraps(func) 等同于 wrapper = wraps(func)(wrapper)
def wrapper(arr1, arr2):
print("实现新的功能,例如数据加1的功能")
arr1 += 1
arr2 += 1
return func(arr1, arr2)
return wrapper
def add(arr1, arr2):
return arr1 + arr2
#原本add函数
print(add.__name__)
#通过装饰器修饰后
add = decorator(add)
print(add.__name__)
概念: 注册机制主要是实现用户输入的字符串到所需函数或者类的映射,方便项目管理和用户使用。注册机制可以通过python装饰器来构建映射关系。比如:MMCV也是通过装饰器的方法来完成的。
完成注册机制主要有三个步骤:
(1)编写注册机制类。
(2)实例化一个注册机制的对象,即构建注册表。
(3)通过装饰器原理来往注册表添加内容,即实现内容注册
class Registry:
def __init__(self, name=None):
# 生成注册列表的名字, 如果没有给出,则默认是Registry。
if name == None:
self._name = "Registry"
self._name = name
#创建注册表,以字典的形式。
self._obj_list = {}
def __registry(self, obj):
"""
内部注册函数
:param obj:函数或者类的地址。
:return:
"""
#判断是否目标函数或者类已经注册,如果已经注册过则标错,如果没有则进行注册。
assert(obj.__name__ not in self._obj_list.keys()), "{} already exists in {}".format(obj.__name__, self._name)
self._obj_list[obj.__name__] = obj
def registry(self, obj=None):
"""
# 外部注册函数。注册方法分为两种。
# 1.通过装饰器调用
# 2.通过函数的方式进行调用
:param obj: 函数或者类的本身
:return:
"""
# 1.通过装饰器调用
if obj == None:
def _no_obj_registry(func__or__class, *args, **kwargs):
self.__registry(func__or__class)
# 此时被装饰的函数会被修改为该函数的返回值。
return func__or__class
return _no_obj_registry
#2.通过函数的方式进行调用
self.__registry(obj)
def get(self, name):
"""
通过字符串name获取对应的函数或者类。
:param name: 函数或者类的名称
:return: 对应的函数或者类
"""
assert (name in self._obj_list.keys()), "{} 没有注册".format(name)
return self._obj_list[name]
这个注册机制类主要包含三个成员函数,分别是__registry,registry,get和2个成员变量self._name和self._obj_list。
成员变量:
1.self._name变量:表示这个注册表的名称,如果没有给予,则默认为Registry。
2.self._obj_list变量:以字典的形式表示注册表,即字符串与对应函数名的映射关系。
成员函数:
1.registry函数:以两种方式对传入的函数进行注册,一种是通过闭包函数 _no_obj_registry,来实现对自由变量self._obj_list的修改。另一种则是直接通过传入registry函数的obj的参数完成注册。这两种方式具体实现功能是通过__registry函数来实现注册的
2.__registry函数:对传入进来的函数参数进行注册,如果存在则报错,如果不存在则完成注册。
3.get函数:通过查找注册表,实现从字符串name到对应函数或类名称的映射,返回对应名称的函数或类。
基于注册机制类,来实例化一个对象,这个对象就是我们需要的注册表。
# 生成注册表
REGISTRY_LIST = Registry("REGISTRY_LIST")
通过装饰器原理来往注册表添加内容,即实现内容注册。在下例中,通过语句@REGISTRY_LIST.registry()来实现对create_by_decorator函数的注册。
@REGISTRY_LIST.registry()等价于
test_by_decorator = REGISTRY_LIST.registry()(test_by_decorator),
即_no_obj_registry(test_by_decorator)
# 通过装饰器调用
@REGISTRY_LIST.registry()
# @REGISTRY_LIST.registry()等价于test_by_decorator = REGISTRY_LIST.registry()(test_by_decorator),即_no_obj_registry(test_by_decorator)
def create_by_decorator():
print("通过装饰器完成注册的函数")
def create_by_function():
print("直接通过registry函数进行注册")
#当然也可以直接通过传入registry函数进行注册。
REGISTRY_LIST.registry(create_by_function)
#通过字符串来获取对应函数名称的函数
test1 = REGISTRY_LIST.get("create_by_decorator")
test1()
test2 = REGISTRY_LIST.get("create_by_function")
test2()
由于本人是在windows下运行mmcv框架的。因此,我的registry.py文件的文件路径是F:\SegFormer-master\mmcv-1.2.7\mmcv\utils\registry.py,读者可以根据自身情况查找自己项目中mmcv的registty.py文件位置,Linux应该是安装的mmcv包下。整体代码如下。后面我们会拆开来细看。
import inspect
import warnings
from functools import partial
from .misc import is_seq_of
class Registry:
"""A registry to map strings to classes.
Args:
name (str): Registry name.
"""
def __init__(self, name):
self._name = name
self._module_dict = dict()
def __len__(self):
return len(self._module_dict)
def __contains__(self, key):
return self.get(key) is not None
def __repr__(self):
format_str = self.__class__.__name__ + \
f'(name={self._name}, ' \
f'items={self._module_dict})'
return format_str
@property
def name(self):
return self._name
@property
def module_dict(self):
return self._module_dict
def get(self, key):
"""Get the registry record.
Args:
key (str): The class name in string format.
Returns:
class: The corresponding class.
"""
return self._module_dict.get(key, None)
def _register_module(self, module_class, module_name=None, force=False):
if not inspect.isclass(module_class):
raise TypeError('module must be a class, '
f'but got {type(module_class)}')
if module_name is None:
module_name = module_class.__name__
if isinstance(module_name, str):
module_name = [module_name]
else:
assert is_seq_of(
module_name,
str), ('module_name should be either of None, an '
f'instance of str or list, but got {type(module_name)}')
for name in module_name:
if not force and name in self._module_dict:
raise KeyError(f'{name} is already registered '
f'in {self.name}')
self._module_dict[name] = module_class
def deprecated_register_module(self, cls=None, force=False):
warnings.warn(
'The old API of register_module(module, force=False) '
'is deprecated and will be removed, please use the new API '
'register_module(name=None, force=False, module=None) instead.')
if cls is None:
return partial(self.deprecated_register_module, force=force)
self._register_module(cls, force=force)
return cls
def register_module(self, name=None, force=False, module=None):
"""Register a module.
A record will be added to `self._module_dict`, whose key is the class
name or the specified name, and value is the class itself.
It can be used as a decorator or a normal function.
Example:
>>> backbones = Registry('backbone')
>>> @backbones.register_module()
>>> class ResNet:
>>> pass
>>> backbones = Registry('backbone')
>>> @backbones.register_module(name='mnet')
>>> class MobileNet:
>>> pass
>>> backbones = Registry('backbone')
>>> class ResNet:
>>> pass
>>> backbones.register_module(ResNet)
Args:
name (str | None): The module name to be registered. If not
specified, the class name will be used.
force (bool, optional): Whether to override an existing class with
the same name. Default: False.
module (type): Module class to be registered.
"""
if not isinstance(force, bool):
raise TypeError(f'force must be a boolean, but got {type(force)}')
# NOTE: This is a walkaround to be compatible with the old api,
# while it may introduce unexpected bugs.
if isinstance(name, type):
return self.deprecated_register_module(name, force=force)
# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
self._register_module(
module_class=module, module_name=name, force=force)
return module
# raise the error ahead of time
if not (name is None or isinstance(name, str)):
raise TypeError(f'name must be a str, but got {type(name)}')
# use it as a decorator: @x.register_module()
def _register(cls):
self._register_module(
module_class=cls, module_name=name, force=force)
return cls
return _register
具体可以看到,mmcv的Registry类包含两个成员变量self._name和self._module_dict,以及主要的6个成员函数,name函数、module_dict函数、get函数、_register_module函数、deprecated_register_module函数和register_module函数。我们分别进行简单介绍。
成员变量:
1.self._name变量:表示这个注册表的名称。
2.self._module_dict变量:以字典的形式表示注册表,即字符串与对应函数名的映射关系。
成员函数:
1.name函数、module_dict函数都是装饰器@property修饰的。Python内置的@property装饰器就是负责把一个方法变成属性调用。
2.register_module函数:完成对目标类的注册,具体代码如下,函数含义已经进行注释。
def register_module(self, name=None, force=False, module=None):
"""注册一个模型
类名称将被添加到变量self._module_dict中, 该变量的键值是类别名或者专属名字。
它可以通过装饰器或者函数直接调用。
Example:
>>> backbones = Registry('backbone')
>>> @backbones.register_module()
>>> class ResNet:
>>> pass
>>> backbones = Registry('backbone')
>>> @backbones.register_module(name='mnet')
>>> class MobileNet:
>>> pass
>>> backbones = Registry('backbone')
>>> class ResNet:
>>> pass
>>> backbones.register_module(ResNet)
Args:
name (str | None): 要注册的模块名称。如果未指定,则将使用类名。
force (bool, optional): 是否用相同的名称重写现有的类。默认值:False。
module (type): 要注册的模块类。
"""
#---------------------------------------------------------------------
#判断输入force参数是否正确。
#---------------------------------------------------------------------
if not isinstance(force, bool):
raise TypeError(f'force must be a boolean, but got {type(force)}')
#---------------------------------------------------------------------
#注意:这是一个与旧api兼容的演练,而它可能会引入意想不到的错误。
#---------------------------------------------------------------------
if isinstance(name, type):
return self.deprecated_register_module(name, force=force)
#---------------------------------------------------------------------
#判断module是否存在,如果存在则直接进行注册。并返回module
#---------------------------------------------------------------------
# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
self._register_module(
module_class=module, module_name=name, force=force)
return module
#---------------------------------------------------------------------
#判断输入的name参数是否正确
#---------------------------------------------------------------------
# raise the error ahead of time
if not (name is None or isinstance(name, str)):
raise TypeError(f'name must be a str, but got {type(name)}')
#---------------------------------------------------------------------
#如果module不存在,则通过装饰器的方式进行注册。
#---------------------------------------------------------------------
# use it as a decorator: @x.register_module()
def _register(cls):
self._register_module(
module_class=cls, module_name=name, force=force)
return cls
return _register
def _register_module(self, module_class, module_name=None, force=False):
"""
具体实现注册方法。
:param module_class:需要注册的函数本身
:param module_name:需要注册的函数名称。默认为None
:param force:是否重写已经存在的函数,默认为False
"""
#---------------------------------------------------------------------
#判断module是否是class类。不是类则报错
#---------------------------------------------------------------------
if not inspect.isclass(module_class):
raise TypeError('module must be a class, '
f'but got {type(module_class)}')
#---------------------------------------------------------------------
#判断module_name是否存在,不存在则默认函数本身名称
#---------------------------------------------------------------------
if module_name is None:
module_name = module_class.__name__
#---------------------------------------------------------------------
#判断module_name是否是个字符串,或者列表
#---------------------------------------------------------------------
if isinstance(module_name, str):
module_name = [module_name]
else:
assert is_seq_of(
module_name,
str), ('module_name should be either of None, an '
f'instance of str or list, but got {type(module_name)}')
#---------------------------------------------------------------------
#针对列表中的字符串进行注册。
#---------------------------------------------------------------------
for name in module_name:
if not force and name in self._module_dict:
raise KeyError(f'{name} is already registered '
f'in {self.name}')
#完成注册
self._module_dict[name] = module_class
4.get函数:通过查找注册表,实现从字符串name到对应函数名称的映射,返回对应名称的函数。
def get(self, key):
"""或者注册表的键值
Args:
key (str): 键值必须是字符串
Returns:
class: 键值对应的类.
"""
return self._module_dict.get(key, None)
5.deprecated_register_module函数:弃用注册模块。这个没看明白。
def deprecated_register_module(self, cls=None, force=False):
warnings.warn(
'The old API of register_module(module, force=False) '
'is deprecated and will be removed, please use the new API '
'register_module(name=None, force=False, module=None) instead.')
if cls is None:
return partial(self.deprecated_register_module, force=force)
self._register_module(cls, force=force)
return cls
通过mmcv的Registry简单注册一下。具体实现如下。当然MMCV添加注册模块可不是这样做的。具体如何添加模块,可以参考csdn博客mmsegment定制模型(八。
if __name__ == "__main__":
from torch.optim.adam import Adam
registry_list = Registry("OPTIM")
registry_list.register_module(name="registry_adam", module=Adam)
optim = registry_list.get("registry_adam")
print(optim)
print(registry_list.module_dict)
本文分别总结python闭包,装饰器和注册机制的原理,并列出了代码和输出结果。不过,本文还是仅仅只涉及了最粗浅的部分。在学习各个参考博客时,发现关于闭包和装饰器的内容远远不止这些。如果,您对这些方面有兴趣可以进入参考博客中继续学习。最后,感谢您的阅读。