mmsegmentation模型生成代码解析

前言

疫情在家办公,新Team这边习惯用MMLab开发网络,正好趁这段时间理解一下商汤大佬们的框架。我之前其实网络开发的比较少,主要是学习用的,而且开发网络基本是靠手写或者copy,用这种架构开发我是十分赞成的,上手快,不容易出错,而且在这个网络训练网络的时代,config作为深度网络的上位机确实是王道。Anyway, 作为学习者,还是要知道网络是怎么通过config搭建好的,才能将自己的网络迁移进来,否则灵活性太差了。这期总结一下mmsegmentation的搭建网络的方法。

框架分析

关于config的分析就不多说了,继承形式的config,我们这里主要关心网络是如何形成的。
网络的搭建当然要从tools/train.py开始,整个main函数前面大部分都是在解析configs的配置并存到cfg对象,直到在196行开始终于开始用build_segmentor这个函数来建立模型。
mmsegmentation模型生成代码解析_第1张图片
追溯这个函数,可以找到mmseg/model文件夹下的builder.py,model文件中显然存放的是模型的结构文件,包括主干网、neck、检测头等,builder.py应该算作model的一个整体对外的接口。
mmsegmentation模型生成代码解析_第2张图片
在builder.py文件中,我们发现首先用Registry实例化MODELS,并且感觉像是继承了MMCV_MODELS,这个基类我们线猜测是MMLAB的模型库。然后将MODELS又传递给SEGMENTORS。
mmsegmentation模型生成代码解析_第3张图片
在build_segmentor中,SEGMENTORS使用build方法建立了模型,到目前为之,模型算子或者模块都没有显示出来,那核心就是这个注册表Registry类作了什么操作了。

mmsegmentation模型生成代码解析_第4张图片
我们找到Registry类,说明里面表明Registry是为了将字符串和类进行map,那懂了,Registry确实是注册表的意思。注册表是为了做什么的呢?注册表本质上是存储设置信息的一种数据库,说明Registry其实本质目的就是把config的信息传递到网络的类中。

mmsegmentation模型生成代码解析_第5张图片
我们看到的Registry的使用方法是如下这种形式,可以看到’model’是传入的name, MMCV_MODELS是传入的build_func,描述中可以看到,build_func如果没有给出,但是parent参数给出了,build_func会隐式继承parent传入的参数。

MODELS = Registry('models', parent=MMCV_MODELS)

到目前为之,我们还有两个问题没有搞清楚,第一个问题,这个注册表类是怎么生成模型的;第二个问题,父类注册表MMCV_MODELS里面又说了些啥。我们先解决第一个问题,看一下注册表类里面的build方法:

def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        self._scope = self.infer_scope() if scope is None else scope

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None
def build(self, *args, **kwargs):
        return self.build_func(*args, **kwargs, registry=self)

其他代码先不关注,只看build,我们发现build方法实际就是调用了build_func,而build_func实际就是你传入的父类注册表,现在两个问题又回到了一个问题,父类或者基类的这个注册表描述了什么。去找一下他的定义,我发现了这个父类已经到了/python3.6/site-packages/mmcv/cnn/builder.py 这个路径下了,说明我们已经接近他的核心部分了。进取以后,惊呆了,竟然是个环,没错,你没看错,又是Registry,只不过这次直接传入了build_func,而且将build_func传入的builid_model_froom_cfg同时定义好了。

MODELS = Registry('model', build_func=build_model_from_cfg)

我们主要要看一下,这个build_func在干些啥,非常清晰,这个函数直接输出的就是nn.modules,就是我们要的pytorch模型

def build_model_from_cfg(cfg, registry, default_args=None):
    """Build a PyTorch model from config dict(s). Different from
    ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.

    Args:
        cfg (dict, list[dict]): The config of modules, is is either a config
            dict or a list of config dicts. If cfg is a list, a
            the built modules will be wrapped with ``nn.Sequential``.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: A built nn module.
    """
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)

回想前面的build方法(如下),传入的参数其实就是给build_model_from_cfg这个函数服务的,传入的主要是cfg,train_cfg和test_cfg,看起来应该是cfg参数是模型主参数,先做个大胆的推测,然后等待打脸(补充:回到train.py 你会发现,传入的是cfg.model,确实是模型主参数)~ 模型到底是砸建的呢?我们又可以看到,build_model_from_cfg函数里面出现了一个build_from_cfg,而且执行了一个for循环去遍历cfg,我们有理由相信这步就是为了形成模型的各个模块,查找一下build_from_cfg这个函数,这个又回到了Reigister那个类的文件中。

return SEGMENTORS.build(
        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))

分析一下这个函数,忽略前面一堆if,首先将cfg传给了arg,果然习惯用arg的人也不少啊哈哈~然后将default_args传入到args中,然后从args中pop出‘type’这个key对应的value,例如‘EncodeDecoder’,再将这个value传给registry.get方法。

def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')

谈起get方法就稍微有点复杂了,get其实实现了这个pop出来的value是一个什么样的任务,其中又用到了类的嵌套,我对这个一直没有搞清楚。Anyway,我们这步实现了提取对应模型的class

def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class.
        """
        scope, real_key = self.split_scope_key(key)
        if scope is None or scope == self._scope:
            # get from self
            if real_key in self._module_dict:
                return self._module_dict[real_key]
        else:
            # get from self._children
            if scope in self._children:
                return self._children[scope].get(real_key)
            else:
                # goto root
                parent = self.parent
                while parent.parent is not None:
                    parent = parent.parent
                return parent.get(key)

最后在build_from_cfg中用try方法实例化的这个类,从而生成了模型。如果我们去看对应的类的话,我们还会发现每个类的上面还有对应的装饰器方法,该装饰器方法会在实例化模型的过程中,将模型记录在对应的注册表类中。
mmsegmentation模型生成代码解析_第6张图片

你可能感兴趣的:(计算机视觉,python,深度学习,自动驾驶)