mmdetection之模型注册

mmdetection注册模型


开头声明,文章较长,代码较多。
train.py的开头中,已经开始注册必要的模块了

from mmcv import Config

from mmdet import __version__
from mmdet.datasets import build_dataset
from mmdet.apis import (train_detector, init_dist, get_root_logger,
                        set_random_seed)
from mmdet.models import build_detector

看mmdet文件夹下的__init__.py,以及datasets , apis , models 下的__init__.py文件,发现:
mmdet.models.__init__.py:


# Package initializer for ``mmdet.models``.
# The star-imports below import each sub-package; importing a module runs
# any ``@<REGISTRY>.register_module`` decorators it contains (e.g. the
# detectors package registers FasterRCNN into DETECTORS), so the registries
# are populated as a side effect of importing this package.
from .backbones import *  # noqa: F401,F403
from .necks import *  # noqa: F401,F403
from .roi_extractors import *  # noqa: F401,F403
from .anchor_heads import *  # noqa: F401,F403
from .shared_heads import *  # noqa: F401,F403
from .bbox_heads import *  # noqa: F401,F403
from .mask_heads import *  # noqa: F401,F403
from .losses import *  # noqa: F401,F403
from .detectors import *  # noqa: F401,F403
from .registry import (BACKBONES, NECKS, ROI_EXTRACTORS, SHARED_HEADS, HEADS,
                       LOSSES, DETECTORS)
from .builder import (build_backbone, build_neck, build_roi_extractor,
                      build_shared_head, build_head, build_loss,
                      build_detector)

# Public API: the registries plus the builder functions.
__all__ = [
    'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'LOSSES',
    'DETECTORS', 'build_backbone', 'build_neck', 'build_roi_extractor',
    'build_shared_head', 'build_head', 'build_loss', 'build_detector'
]

这个文件的第一行导入了backbones包,也就是执行了backbones.__init__.py,看一下里面的内容:
mmdet.models.backbones.__init__.py

from .resnet import ResNet, make_res_layer
from .resnext import ResNeXt
from .ssd_vgg import SSDVGG
from .hrnet import HRNet

__all__ = ['ResNet', 'make_res_layer', 'ResNeXt', 'SSDVGG', 'HRNet']

这里又导入了resnet,resnext等几个卷积神经网络,那么以resnet为例,看一下里面都有啥
mmdet.models.backbones.resnet.py
第13、14行:

from ..registry import BACKBONES
from ..utils import build_conv_layer, build_norm_layer

其中又从registry中导入了BACKBONES,那么再来看看registry和它的BACKBONES
mmdet.models.registry.py:

from mmdet.utils import Registry
# One Registry instance per category of model component. The string passed
# in becomes the registry's ``name`` (used in its repr and error messages);
# classes are added to a registry later through ``register_module``.
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')

那么这个Registry又是何方神圣?意欲何为?看看去
mmdet.utils.__init__.py:

from .registry import Registry, build_from_cfg
__all__ = ['Registry', 'build_from_cfg']

顺藤摸瓜,找到registry.py
mmdet.utils.registry.py
代码稍长,分两段看吧。只看主要代码,能帮助理解其机制的代码,删除部分不影响理解的代码,全文都是。
Registry:

import inspect
import mmcv
class Registry(object):
    """Lookup table that maps class names to classes.

    Each instance carries a human-readable ``name`` and a private dict
    keyed by ``cls.__name__``. Classes are added with ``register_module``
    (usable as a class decorator, since it returns the class unchanged)
    and looked up by name with ``get``.
    """

    def __init__(self, name):
        self._name = name
        self._module_dict = {}

    def __repr__(self):
        return '{}(name={}, items={})'.format(
            self.__class__.__name__, self._name, list(self._module_dict))

    @property
    def name(self):
        """Name given to this registry at construction."""
        return self._name

    @property
    def module_dict(self):
        """The underlying ``{class name: class}`` dict (not a copy)."""
        return self._module_dict

    def get(self, key):
        """Return the class registered under ``key``, or None if absent."""
        return self._module_dict.get(key)

    def _register_module(self, module_class):
        """Store ``module_class`` under its ``__name__``.

        Raises:
            TypeError: If ``module_class`` is not a class.
            KeyError: If a class with the same name is already registered.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, but got {}'.format(
                type(module_class)))
        key = module_class.__name__
        registered = self._module_dict
        if key in registered:
            raise KeyError('{} is already registered in {}'.format(
                key, self.name))
        registered[key] = module_class

    def register_module(self, cls):
        """Register ``cls`` and hand it back so this works as a decorator."""
        self._register_module(cls)
        return cls

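这段代码呢,为每个Registry对象维护了一个字典_module_dict,以类名为键、类本身为值;以后同一类别的模块都通过register_module(通常作为装饰器使用)挂到对应的Registry下。此时我们回过头来再看registry.py中的代码,其实是为模型的各个主要部分(backbone、neck、head等)分别生成了一个Registry,并向外提供了接口。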
这段代码到现在,暂时没有了下文,我们再来看build_from_cfg函数:

def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type",
            whose value is either a name registered in ``registry`` or a
            class. The remaining keys become constructor keyword arguments.
            ``cfg`` itself is not modified.
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.
            They only fill in keys that ``cfg`` does not already provide.

    Returns:
        obj: The constructed object.

    Raises:
        AssertionError: If ``cfg`` is not a dict containing "type", or
            ``default_args`` is neither a dict nor None.
        KeyError: If "type" is a string that is not in ``registry``.
        TypeError: If "type" is neither a string nor a class.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    # Work on a shallow copy so popping "type" does not mutate the caller's cfg.
    args = cfg.copy()
    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        # Keep the requested name separate from the lookup result; otherwise
        # the error below would report "None" instead of the missing name.
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            # Explicit cfg values win over defaults.
            args.setdefault(name, value)
    return obj_cls(**args)

这段代码比较难懂,尤其是最后那行。我们来分析一波吧,既然难懂,就先来看它是在哪里被调用的吧。回到tools.train.py:

...
# Excerpt of the command-line parser: the only argument shown is the
# positional ``config`` (path of the training config file); the rest of the
# function is elided with "...".
def parse_args():
	...
    parser.add_argument('config', help='train config file path')
...
# Excerpt of the training entry point; elided parts are marked with "...".
def main():
    args = parse_args()
    # Load the config file named on the command line.
    cfg = Config.fromfile(args.config)
    ...
    # Build the detector described by cfg.model, passing the train/test
    # settings from the same config along to the builder.
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    ...

因为其有个cfg参数,而build_detector也用到了cfg,可以算个线索(如果你使用IDE的话,可以看看build_from_cfg是被谁引用的,顺藤摸瓜,推荐)。再来看build_detector,它是在文章第一个代码段train.py的最后一行引入进来的,位于mmdet.models里,而mmdet.models.__init__.py中,有

from .builder import (build_backbone, build_neck, build_roi_extractor,
                      build_shared_head, build_head, build_loss,
                      build_detector)

我们来看这build_detector的具体内容吧,mmdet.models.builder.py:

from torch import nn
from mmdet.utils import build_from_cfg
from .registry import (BACKBONES, NECKS, ROI_EXTRACTORS, SHARED_HEADS, HEADS,
                       LOSSES, DETECTORS)

def build(cfg, registry, default_args=None):
    """Build one module, or a sequential stack of modules, from config.

    Args:
        cfg (dict | list): A single config, or a list of configs.
        registry: Registry handed to ``build_from_cfg`` for the lookup.
        default_args (dict, optional): Passed through to ``build_from_cfg``
            for every config.

    Returns:
        The object built by ``build_from_cfg`` when ``cfg`` is not a list;
        otherwise an ``nn.Sequential`` wrapping the objects built from each
        list entry, in order.
    """
    if not isinstance(cfg, list):
        return build_from_cfg(cfg, registry, default_args)
    built = []
    for item_cfg in cfg:
        built.append(build_from_cfg(item_cfg, registry, default_args))
    return nn.Sequential(*built)

def build_backbone(cfg):
    """Build from ``cfg`` using the ``BACKBONES`` registry (see ``build``)."""
    backbone = build(cfg, BACKBONES)
    return backbone

def build_neck(cfg):
    """Build from ``cfg`` using the ``NECKS`` registry (see ``build``)."""
    neck = build(cfg, NECKS)
    return neck

def build_roi_extractor(cfg):
    """Build from ``cfg`` using the ``ROI_EXTRACTORS`` registry (see ``build``)."""
    roi_extractor = build(cfg, ROI_EXTRACTORS)
    return roi_extractor

def build_shared_head(cfg):
    """Build from ``cfg`` using the ``SHARED_HEADS`` registry (see ``build``)."""
    shared_head = build(cfg, SHARED_HEADS)
    return shared_head

def build_head(cfg):
    """Build from ``cfg`` using the ``HEADS`` registry (see ``build``)."""
    head = build(cfg, HEADS)
    return head

def build_loss(cfg):
    """Build from ``cfg`` using the ``LOSSES`` registry (see ``build``)."""
    loss = build(cfg, LOSSES)
    return loss
    
def build_detector(cfg, train_cfg=None, test_cfg=None):
    """Build from ``cfg`` using the ``DETECTORS`` registry.

    ``train_cfg`` and ``test_cfg`` are handed to ``build`` as default
    arguments rather than merged into ``cfg`` here.
    """
    default_args = dict(train_cfg=train_cfg, test_cfg=test_cfg)
    return build(cfg, DETECTORS, default_args)

看到参数又传到build里,而cfg是个dict类型,所以又到了build_from_cfg,此刻我们来分析build_from_cfg:

# Condensed recap of build_from_cfg; the validation and the lookup of
# ``obj_type`` in the registry are elided with "...".
def build_from_cfg(cfg, registry, default_args=None):
	...
    args = cfg.copy()  # copy, so popping 'type' leaves cfg untouched
    obj_type = args.pop('type')  # what remains in args are constructor kwargs
    ...
    return obj_type(**args)  # call obj_type with the remaining keys as kwargs

再在你的配置文件里看到这个obj_type:
configs.faster_rcnn_r50_fpn_1x.py:

model = dict(
    type='FasterRCNN',

其实也就是执行了FasterRCNN(),那么,FasterRCNN又是从何而来呢?
答:在mmdet.models.__init__.py里,可以看到from .detectors import *这行代码,再来瞧瞧mmdet.models.detectors.__init__.py:

from .base import BaseDetector
from .single_stage import SingleStageDetector
from .two_stage import TwoStageDetector
from .rpn import RPN
from .fast_rcnn import FastRCNN
from .faster_rcnn import FasterRCNN
from .mask_rcnn import MaskRCNN
from .cascade_rcnn import CascadeRCNN
from .htc import HybridTaskCascade
from .retinanet import RetinaNet
from .fcos import FCOS
from .grid_rcnn import GridRCNN
from .mask_scoring_rcnn import MaskScoringRCNN

__all__ = [
    'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
    'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade',
    'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN'
]

可以看到这里注册了一大堆的模型,取出faster_rcnn来看,在mmdet.models.detectors.faster_rcnn.py里:

from .two_stage import TwoStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class FasterRCNN(TwoStageDetector):
    """Two-stage detector registered in ``DETECTORS``.

    The constructor adds no logic of its own: it makes ``rpn_head``,
    ``bbox_roi_extractor``, ``bbox_head``, ``train_cfg`` and ``test_cfg``
    required arguments and forwards everything to
    ``TwoStageDetector.__init__``. No mask-related arguments are passed.
    """

    def __init__(self, backbone, rpn_head, bbox_roi_extractor, bbox_head,
                 train_cfg, test_cfg, neck=None, shared_head=None,
                 pretrained=None):
        parent_kwargs = dict(
            backbone=backbone,
            neck=neck,
            shared_head=shared_head,
            rpn_head=rpn_head,
            bbox_roi_extractor=bbox_roi_extractor,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
        )
        super(FasterRCNN, self).__init__(**parent_kwargs)

看到这里以TwoStageDetector作为其父类,看TwoStageDetector(mmdet.models.detectors.two_stage.py):

import torch
import torch.nn as nn

from .base import BaseDetector
from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
from .. import builder
from ..registry import DETECTORS
from mmdet.core import bbox2roi, bbox2result, build_assigner, build_sampler
@DETECTORS.register_module
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
                       MaskTestMixin):
    """Detector assembled from per-component config dicts.

    Only ``backbone`` is mandatory. Every other component (neck, shared
    head, RPN head, bbox branch, mask branch) is built and attached as an
    attribute only when its config is given; otherwise the attribute is
    not set by this constructor.
    """
    def __init__(self,
                 backbone,
                 neck=None,
                 shared_head=None,
                 rpn_head=None,
                 bbox_roi_extractor=None,
                 bbox_head=None,
                 mask_roi_extractor=None,
                 mask_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        """Build each configured component through the ``builder`` module.

        Args:
            backbone: Config passed to ``builder.build_backbone``.
            neck: Optional config passed to ``builder.build_neck``.
            shared_head: Optional config for ``builder.build_shared_head``.
            rpn_head: Optional config passed to ``builder.build_head``.
            bbox_roi_extractor: Config for ``builder.build_roi_extractor``;
                only used when ``bbox_head`` is given.
            bbox_head: Optional config passed to ``builder.build_head``.
            mask_roi_extractor: Optional config for a dedicated mask RoI
                extractor; only used when ``mask_head`` is given.
            mask_head: Optional config passed to ``builder.build_head``.
            train_cfg: Stored as ``self.train_cfg``.
            test_cfg: Stored as ``self.test_cfg``.
            pretrained: Forwarded to ``self.init_weights``.
        """
        super(TwoStageDetector, self).__init__()
        self.backbone = builder.build_backbone(backbone)
        if neck is not None:
            self.neck = builder.build_neck(neck)
        if shared_head is not None:
            self.shared_head = builder.build_shared_head(shared_head)
        if rpn_head is not None:
            self.rpn_head = builder.build_head(rpn_head)
        if bbox_head is not None:
            # The bbox RoI extractor is built only together with a bbox head.
            self.bbox_roi_extractor = builder.build_roi_extractor(
                bbox_roi_extractor)
            self.bbox_head = builder.build_head(bbox_head)

        if mask_head is not None:
            if mask_roi_extractor is not None:
                # Dedicated RoI extractor for the mask branch.
                self.mask_roi_extractor = builder.build_roi_extractor(
                    mask_roi_extractor)
                self.share_roi_extractor = False
            else:
                # No dedicated extractor: reuse the bbox branch's one.
                # NOTE(review): self.bbox_roi_extractor is only assigned above
                # when bbox_head is given - confirm callers never configure a
                # mask head without a bbox head (or that a base class sets it).
                self.share_roi_extractor = True
                self.mask_roi_extractor = self.bbox_roi_extractor
            self.mask_head = builder.build_head(mask_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        # init_weights is not defined in this excerpt (presumably later in the
        # class or in a base class); it is called after all components exist.
        self.init_weights(pretrained=pretrained)
    @property
    def with_rpn(self):
        """True when an ``rpn_head`` attribute exists and is not None."""
        return hasattr(self, 'rpn_head') and self.rpn_head is not None

可以看到,在这里形成了整个模型。
码字太累,费时,下次不定时更新。

你可能感兴趣的:(mmdetection,mmdetection)