mmlab的注册器机制demo

# demo/registry_import_all.py
from demo.registry_ssd import Toy_SSD
from demo.registry_yolo import Toy_Yolo

类的装饰器在文件导入的时候会起作用,会对其进行装饰函数的内容(注册:在registry_model内,设置了从str到类的映射)
将所有的导入工作放入到一个文件中,使用时只需要导入这一个文件,即导入了所有的模型文件,完成所有模型的注册
使用的是类的对象的方法去装饰类本身,被修饰类将作为registry_model的参数传入到装饰函数中去

# demo/registry_ssd.py
from registry_demo_root import MODEL
import torch.nn as nn
  
@MODEL.registry_model
class Toy_SSD(nn.Module):
    def __init__(self, input, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        print(f"model of input : {input}")
        self.input = input
    
    def forward(self, x):
        return x * self.input
# demo/registry_yolo.py
from registry_demo_root import MODEL
import torch.nn as nn

@MODEL.registry_model
class Toy_Yolo(nn.Module):
    def __init__(self, input, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        print(f"model of input : {input}")
        self.input = input
    
    def forward(self, x):
        return x * self.input

根注册器:本质为注册器的一个对象
在注册模型时,导入的MODEL是统一的一个对象,其id唯一,使得注册模型时,使用的是同一个对象的_modele_dict属性,维护了同一个从str到类的映射

# demo/registry_demo_root.py
from typing import Type
class Registry:
    def __init__(self) -> None:
        self._modele_dict = dict()
        
    def registry_model(self, model: Type):
        self._modele_dict[model.__name__] = model
        
        return model
    
    def get(self, model_str: str):
        return self._modele_dict[model_str]

MODEL = Registry()

在使用模型时,使用的也是同一个MODEL对象,,调用该对象的get方法,从该对象的model_dict中获取该类作为返回值

# demo/use_model.py
import registry_import_all              # 将所有的模型都注册到root.py文件中的MODEL对象(id唯一)的_modele_dict字典中去
from registry_demo_root import MODEL    # 从root.py中导入该对象(该对象的id唯一)


registry = MODEL                        # 两者的id相同

model_str = 'Toy_SSD'

obj_cls = registry.get(model_str)       # 获取从字符串到类的映射

model = obj_cls(8888)                   # 实例化一个对象 
result = model(222)                     # 模型的前向传播
print(f"the model is ok, the foward fun's output is {result}")

你可能感兴趣的:(MMLab,python,深度学习)