在一个稍大一点的python项目中,我们很有可能会用到注册器(register)。这个注册器不是用户账号注册的模块,而是项目中注册模块的一个模块。举个例子,一个深度学习项目可能支持多种模型;具体使用哪种模型可能是用户在配置文件中指定的。最简单的实现方式,就是维护一个模型名称->模型类
的字典。但每当你增加一个模型时,这个字典就需要手动维护,比较繁琐。本文介绍一种注册器的模块,你需要维护的是需要注册的模块的代码路径(相对简介些)。
这个模块在我们的开源项目Delta中也有使用。
点这里看完整源代码
models/model.py
:
class Model: pass @Registers.model.register class Model1(Model): pass @Registers.model.register class Model2(Model): pass @Registers.model.register class Model3(Model): pass |
class Register: def __init__(self, registry_name): self._dict = {} self._name = registry_name def __setitem__(self, key, value): if not callable(value): raise Exception(f"Value of a Registry must be a callable!\nValue: {value}") if key is None: key = value.__name__ if key in self._dict: logging.warning("Key %s already in registry %s." % (key, self._name)) self._dict[key] = value def register(self, target): """Decorator to register a function or class.""" def add(key, value): self[key] = value return value if callable(target): # @reg.register return add(None, target) # @reg.register('alias') return lambda x: add(target, x) def __getitem__(self, key): return self._dict[key] def __contains__(self, key): return key in self._dict def keys(self): """key""" return self._dict.keys() |
补充一个知识点,@
是python的装饰器语法糖。
@decorate def func(): |
等价于:
func = decorate(func) |
这里,Register
类似于一个dict
(实际上是有一个_dict
属性),可以set_item
和get_item
。关键是register
函数,它可以作为装饰器,注册一个函数或者一个类。例如:
@register_obj.register class Modle1: |
等价于register_obj.register(Model1)
,最终执行的是add(None, Model1)
。
而:
@register_obj.register("model_one") class Model1: |
实际上是register_obj.register("model_one")(Model1)
,最终执行的是add("model_one", Model_1)
。
总结下:Register
类保存了名称->模块
的数据,且提供了方便的注册装饰器。
class Registers: def __init__(self): raise RuntimeError("Registries is not intended to be instantiated") model = Register('model') |
Registers
保存了所有的Register
对象。
在模块代码中加入注册装饰器之后,我们还需要把这些模块实际地导入,才能让这些子模块加入进注册器中。
一般大家会首先想到import
。比如这里可以直接import models.models
就可以让注册装饰器起作用。
但是import
子模块这种形式很有可能导致循环引用的问题。为了避免循环引用,我们可以在代码入口处,统一地动态引入所有子模块。动态导入包使用importlib
。
MODEL_MODULES = ["models"] ALL_MODULES = [("models", MODEL_MODULES)] def _handle_errors(errors): """Log out and possibly reraise errors during import.""" if not errors: return for name, err in errors: logging.warning("Module {} import failed: {}".format(name, err)) def import_all_modules_for_register(custom_module_paths=None): """Import all modules for register.""" modules = [] for base_dir, modules in ALL_MODULES: for name in modules: full_name = base_dir + "." + name modules.append(full_name) if isinstance(custom_module_paths, list): modules += custom_module_paths errors = [] for module in modules: try: importlib.import_module(module) except ImportError as error: errors.append((module, error)) _handle_errors(errors) |
最后我们使用下我们的注册器模块:
from register import import_all_modules_for_register from register import Registers print("Registers.model._dict before: ", Registers.model._dict) import_all_modules_for_register() print("Registers.model._dict after: ", Registers.model._dict) |
输出:
Registers.model._dict before: {} Registers.model._dict after: {'Model': |
可以看到,需要的模块已经加入到注册器中。
这个模块在我们的开源项目Delta中也有使用。