疫情在家办公,新Team这边习惯用MMLab开发网络,正好趁这段时间理解一下商汤大佬们的框架。我之前其实网络开发的比较少,主要是学习用的,而且开发网络基本是靠手写或者copy,用这种架构开发我是十分赞成的,上手快,不容易出错,而且在这个网络训练网络的时代,config作为深度网络的上位机确实是王道。Anyway, 作为学习者,还是要知道网络是怎么通过config搭建好的,才能将自己的网络迁移进来,否则灵活性太差了。这期总结一下mmsegmentation的搭建网络的方法。
关于config的分析就不多说了,继承形式的config,我们这里主要关心网络是如何形成的。
网络的搭建当然要从tools/train.py开始,整个main函数前面大部分都是在解析configs的配置并存到cfg对象,直到在196行开始终于开始用build_segmentor这个函数来建立模型。
追溯这个函数,可以找到mmseg/model文件夹下的builder.py,model文件中显然存放的是模型的结构文件,包括主干网、neck、检测头等,builder.py应该算作model的一个整体对外的接口。
在builder.py文件中,我们发现首先用Registry实例化MODELS,并且感觉像是继承了MMCV_MODELS,这个基类我们线猜测是MMLAB的模型库。然后将MODELS又传递给SEGMENTORS。
在build_segmentor中,SEGMENTORS使用build方法建立了模型,到目前为之,模型算子或者模块都没有显示出来,那核心就是这个注册表Registry类作了什么操作了。
我们找到Registry类,说明里面表明Registry是为了将字符串和类进行map,那懂了,Registry确实是注册表的意思。注册表是为了做什么的呢?注册表本质上是存储设置信息的一种数据库,说明Registry其实本质目的就是把config的信息传递到网络的类中。
我们看到的Registry的使用方法是如下这种形式,可以看到’model’是传入的name, MMCV_MODELS是传入的build_func,描述中可以看到,build_func如果没有给出,但是parent参数给出了,build_func会隐式继承parent传入的参数。
MODELS = Registry('models', parent=MMCV_MODELS)
到目前为之,我们还有两个问题没有搞清楚,第一个问题,这个注册表类是怎么生成模型的;第二个问题,父类注册表MMCV_MODELS里面又说了些啥。我们先解决第一个问题,看一下注册表类里面的build方法:
def __init__(self, name, build_func=None, parent=None, scope=None):
self._name = name
self._module_dict = dict()
self._children = dict()
self._scope = self.infer_scope() if scope is None else scope
# self.build_func will be set with the following priority:
# 1. build_func
# 2. parent.build_func
# 3. build_from_cfg
if build_func is None:
if parent is not None:
self.build_func = parent.build_func
else:
self.build_func = build_from_cfg
else:
self.build_func = build_func
if parent is not None:
assert isinstance(parent, Registry)
parent._add_children(self)
self.parent = parent
else:
self.parent = None
def build(self, *args, **kwargs):
return self.build_func(*args, **kwargs, registry=self)
其他代码先不关注,只看build,我们发现build方法实际就是调用了build_func,而build_func实际就是你传入的父类注册表,现在两个问题又回到了一个问题,父类或者基类的这个注册表描述了什么。去找一下他的定义,我发现了这个父类已经到了/python3.6/site-packages/mmcv/cnn/builder.py 这个路径下了,说明我们已经接近他的核心部分了。进取以后,惊呆了,竟然是个环,没错,你没看错,又是Registry,只不过这次直接传入了build_func,而且将build_func传入的builid_model_froom_cfg同时定义好了。
MODELS = Registry('model', build_func=build_model_from_cfg)
我们主要要看一下,这个build_func在干些啥,非常清晰,这个函数直接输出的就是nn.modules,就是我们要的pytorch模型
def build_model_from_cfg(cfg, registry, default_args=None):
"""Build a PyTorch model from config dict(s). Different from
``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.
Args:
cfg (dict, list[dict]): The config of modules, is is either a config
dict or a list of config dicts. If cfg is a list, a
the built modules will be wrapped with ``nn.Sequential``.
registry (:obj:`Registry`): A registry the module belongs to.
default_args (dict, optional): Default arguments to build the module.
Defaults to None.
Returns:
nn.Module: A built nn module.
"""
if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
]
return Sequential(*modules)
else:
return build_from_cfg(cfg, registry, default_args)
回想前面的build方法(如下),传入的参数其实就是给build_model_from_cfg这个函数服务的,传入的主要是cfg,train_cfg和test_cfg,看起来应该是cfg参数是模型主参数,先做个大胆的推测,然后等待打脸(补充:回到train.py 你会发现,传入的是cfg.model,确实是模型主参数)~ 模型到底是砸建的呢?我们又可以看到,build_model_from_cfg函数里面出现了一个build_from_cfg,而且执行了一个for循环去遍历cfg,我们有理由相信这步就是为了形成模型的各个模块,查找一下build_from_cfg这个函数,这个又回到了Reigister那个类的文件中。
return SEGMENTORS.build(
cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
分析一下这个函数,忽略前面一堆if,首先将cfg传给了arg,果然习惯用arg的人也不少啊哈哈~然后将default_args传入到args中,然后从args中pop出‘type’这个key对应的value,例如‘EncodeDecoder’,再将这个value传给registry.get方法。
def build_from_cfg(cfg, registry, default_args=None):
"""Build a module from config dict.
Args:
cfg (dict): Config dict. It should at least contain the key "type".
registry (:obj:`Registry`): The registry to search the type from.
default_args (dict, optional): Default initialization arguments.
Returns:
object: The constructed object.
"""
if not isinstance(cfg, dict):
raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
if 'type' not in cfg:
if default_args is None or 'type' not in default_args:
raise KeyError(
'`cfg` or `default_args` must contain the key "type", '
f'but got {cfg}\n{default_args}')
if not isinstance(registry, Registry):
raise TypeError('registry must be an mmcv.Registry object, '
f'but got {type(registry)}')
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError('default_args must be a dict or None, '
f'but got {type(default_args)}')
args = cfg.copy()
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
obj_type = args.pop('type')
if isinstance(obj_type, str):
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(
f'{obj_type} is not in the {registry.name} registry')
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')
try:
return obj_cls(**args)
except Exception as e:
# Normal TypeError does not print class name.
raise type(e)(f'{obj_cls.__name__}: {e}')
谈起get方法就稍微有点复杂了,get其实实现了这个pop出来的value是一个什么样的任务,其中又用到了类的嵌套,我对这个一直没有搞清楚。Anyway,我们这步实现了提取对应模型的class
def get(self, key):
"""Get the registry record.
Args:
key (str): The class name in string format.
Returns:
class: The corresponding class.
"""
scope, real_key = self.split_scope_key(key)
if scope is None or scope == self._scope:
# get from self
if real_key in self._module_dict:
return self._module_dict[real_key]
else:
# get from self._children
if scope in self._children:
return self._children[scope].get(real_key)
else:
# goto root
parent = self.parent
while parent.parent is not None:
parent = parent.parent
return parent.get(key)
最后在build_from_cfg中用try方法实例化的这个类,从而生成了模型。如果我们去看对应的类的话,我们还会发现每个类的上面还有对应的装饰器方法,该装饰器方法会在实例化模型的过程中,将模型记录在对应的注册表类中。