mmdetection v1.2 Source Code Reading Notes, Part 1: How Network Components Are Registered

1. Overview

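  • mmdetection is an object detection toolbox open-sourced by SenseTime. It implements many detection algorithms and is highly extensible, which makes its source code hard to avoid in object detection work. Reading the mmdetection source helps in understanding how each detection algorithm is implemented and how the algorithms are integrated into one framework.
  • This post covers how mmdetection builds a network model.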

2. Source Code Walkthrough

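  • train.py trains the network and contains two functions, parse_args() and main(). The former reads arguments from the command line; the latter performs four tasks in order: ① read the configuration from the config file and the command-line arguments; ② build the network model; ③ build the dataset; ④ train the model on the dataset. This post covers how the network model is built, using faster_rcnn_r50_fpn_1x as the example.
  • 2.1 In train.py, the following code creates the model: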
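from mmcv import Config

from mmdet.models import build_detector

# ... other code (parse_args() etc.) ...


def main():
    args = parse_args()
    # cfg: the config loaded from the config file, e.g. faster_rcnn_r50_fpn_1x.py
    cfg = Config.fromfile(args.config)
    # ... other code ...
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)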
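  • 2.2 Note the import statement above: it imports from a package, so Python first executes models/__init__.py, whose contents are shown below. Because __init__.py already imports build_detector, the code above can import build_detector directly from mmdet.models.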
from .anchor_heads import *  # noqa: F401,F403
from .backbones import *  # noqa: F401,F403
from .bbox_heads import *  # noqa: F401,F403
from .builder import (build_backbone, build_detector, build_head, build_loss,
                      build_neck, build_roi_extractor, build_shared_head)
from .detectors import *  # noqa: F401,F403
from .losses import *  # noqa: F401,F403
from .mask_heads import *  # noqa: F401,F403
from .necks import *  # noqa: F401,F403
from .registry import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
                       ROI_EXTRACTORS, SHARED_HEADS)
from .roi_extractors import *  # noqa: F401,F403
from .shared_heads import *  # noqa: F401,F403
__all__ = [
    'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'LOSSES',
    'DETECTORS', 'build_backbone', 'build_neck', 'build_roi_extractor',
    'build_shared_head', 'build_head', 'build_loss', 'build_detector'
]
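A minimal sketch of the package-import behaviour described in 2.2, runnable on its own. It writes a throwaway package to a temporary directory; the names demo_pkg and INIT_RAN are hypothetical stand-ins for mmdet.models, and nothing here is mmdetection code.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Hypothetical package: demo_pkg/__init__.py re-exports a function defined in
# demo_pkg/builder.py, mirroring how models/__init__.py re-exports build_detector.
root = Path(tempfile.mkdtemp())
pkg = root / 'demo_pkg'
pkg.mkdir()
(pkg / 'builder.py').write_text(
    "def build_detector(cfg):\n"
    "    return 'built ' + cfg['type']\n"
)
(pkg / '__init__.py').write_text(
    "INIT_RAN = True  # set as a side effect of running __init__.py\n"
    "from .builder import build_detector\n"
)

sys.path.insert(0, str(root))
importlib.invalidate_caches()  # make sure the freshly written files are found

# Python runs demo_pkg/__init__.py first, then reads the names from the
# package namespace, so build_detector is importable straight from demo_pkg.
from demo_pkg import build_detector, INIT_RAN

print(INIT_RAN)                                # True
print(build_detector({'type': 'FasterRCNN'}))  # built FasterRCNN
```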
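  • 2.3 All network components are registered while models/__init__.py runs. How does that registration actually happen? Take the anchor_heads components as an example. from .anchor_heads import * is also a package import, so models/anchor_heads/__init__.py runs first; its contents are shown below. A from ... import * statement imports every name listed in __all__.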
from .anchor_head import AnchorHead
from .atss_head import ATSSHead
from .fcos_head import FCOSHead
from .fovea_head import FoveaHead
from .free_anchor_retina_head import FreeAnchorRetinaHead
from .ga_retina_head import GARetinaHead
from .ga_rpn_head import GARPNHead
from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
from .reppoints_head import RepPointsHead
from .retina_head import RetinaHead
from .retina_sepbn_head import RetinaSepBNHead
from .rpn_head import RPNHead
from .ssd_head import SSDHead

__all__ = [
    'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead',
    'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 'SSDHead',
    'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead',
    'ATSSHead'
]
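The two effects of a star import in 2.3 can be separated with a small sketch: only the names listed in __all__ are copied into the importing namespace, but the imported module's top-level code always runs in full, and top-level code is where registration happens. The module demo_heads and everything in it are hypothetical.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Hypothetical module standing in for a file under models/anchor_heads/.
root = Path(tempfile.mkdtemp())
(root / 'demo_heads.py').write_text(
    "REGISTERED = []\n"
    "class AnchorHead: pass\n"
    "class _Helper: pass\n"
    "REGISTERED.append('AnchorHead')  # top-level code runs on import\n"
    "__all__ = ['AnchorHead']\n"
)
sys.path.insert(0, str(root))
importlib.invalidate_caches()

namespace = {}
exec('from demo_heads import *', namespace)  # a module-level star import

exported = sorted(k for k in namespace if k != '__builtins__')
print(exported)  # ['AnchorHead']

# REGISTERED and _Helper were not copied, yet the whole module body ran:
print(sys.modules['demo_heads'].REGISTERED)  # ['AnchorHead']
```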
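  • 2.4 Taking AnchorHead as the example, open models/anchor_heads/anchor_head.py; its key parts are shown below. The file first imports HEADS, an instance of the Registry class, from models/registry.py, and then uses a Python decorator to pass the AnchorHead class as the argument to HEADS's method register_module. If this is confusing at first, read on.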
import torch.nn as nn

from ..registry import HEADS


@HEADS.register_module
class AnchorHead(nn.Module):
    """
    code in class definition (omitted)
    """
  • 2.5 Below is models/registry.py. The HEADS used above is an instance of the Registry class. How is Registry defined? Read on.
from mmdet.utils import Registry

BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')
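One detail makes this work: each registry above is created once, at module level. Python executes a module only once and caches it in sys.modules, so a file that adds classes to HEADS (like anchor_head.py) and a file that reads from HEADS (like builder.py) hold the same object. A sketch with a hypothetical package demo_models, where a plain dict stands in for Registry('head'):

```python
import importlib
import sys
import tempfile
from pathlib import Path

root = Path(tempfile.mkdtemp())
pkg = root / 'demo_models'
pkg.mkdir()
(pkg / '__init__.py').write_text('')
(pkg / 'registry.py').write_text('HEADS = {}  # stands in for Registry("head")\n')
(pkg / 'anchor_head.py').write_text(
    'from .registry import HEADS\n'
    'class AnchorHead: pass\n'
    'HEADS["AnchorHead"] = AnchorHead  # "register" the class\n'
)
(pkg / 'builder.py').write_text(
    'from .registry import HEADS\n'
    'def build_head(name):\n'
    '    return HEADS[name]()\n'
)
sys.path.insert(0, str(root))
importlib.invalidate_caches()

import demo_models.anchor_head  # running this module performs the registration
from demo_models import builder

# registry.py ran only once, so both modules share one HEADS object.
print(demo_models.anchor_head.HEADS is builder.HEADS)   # True
print(type(builder.build_head('AnchorHead')).__name__)  # AnchorHead
```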
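  • 2.6 Below are the main parts of mmdet/utils/registry.py (some code removed to highlight the essentials). Registry's __init__ defines two attributes: _name and the dictionary _module_dict. As mentioned in 2.4, the AnchorHead class is passed as the argument to HEADS.register_module; the code below shows that every class passed in this way is added to HEADS's _module_dict, keyed by its class name.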
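from functools import partial


class Registry(object):

    def __init__(self, name):
        self._name = name
        self._module_dict = dict()  # maps class name -> class

    def get(self, key):
        # look up a registered class by name; returns None if it is not registered
        return self._module_dict.get(key, None)

    def _register_module(self, module_class, force=False):
        module_name = module_class.__name__
        if not force and module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self._name))
        self._module_dict[module_name] = module_class

    def register_module(self, cls=None, force=False):
        # @REG.register_module: cls is the decorated class
        # @REG.register_module(force=True): cls is None, so return a decorator
        if cls is None:
            return partial(self.register_module, force=force)
        self._register_module(cls, force=force)
        return cls  # return the class unchanged, so its name still refers to it


def build_from_cfg(cfg, registry, default_args=None):
    args = cfg.copy()
    obj_type = args.pop('type')  # e.g. 'FasterRCNN'
    obj_cls = registry.get(obj_type)
    if obj_cls is None:
        raise KeyError('{} is not in the {} registry'.format(
            obj_type, registry._name))
    if default_args is not None:
        # default_args only fill in keys that cfg does not already set
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_cls(**args)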
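  • 2.7 At this point the network component AnchorHead (grouped under anchor_heads) has been registered in the HEADS registry, and every other network component is registered the same way. Going back to 2.2: once models/__init__.py has run, every network component is registered in the registry it belongs to. A registry, then, essentially maintains a dictionary whose key-value pairs map each component's class name to its class.
  • 2.8 Back to 2.1: models/__init__.py runs when train.py imports build_detector, so all network components are already registered by the time build_detector is called. The function is defined in models/builder.py as follows.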
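A usage sketch of a simplified Registry modelled on the one in 2.6, showing the internal dictionary described in 2.7, both decorator forms, and what the force flag is for. The class is a trimmed copy written for this example, not the full mmdet.utils.Registry; it assumes a duplicate class name is rejected unless force=True.

```python
from functools import partial


class Registry:
    """Trimmed, hypothetical copy of the simplified Registry in 2.6."""

    def __init__(self, name):
        self._name = name
        self._module_dict = {}  # class name -> class

    def get(self, key):
        return self._module_dict.get(key, None)

    def register_module(self, cls=None, force=False):
        if cls is None:  # used as @REG.register_module(force=...)
            return partial(self.register_module, force=force)
        # assumption: a duplicate name is an error unless force=True
        if not force and cls.__name__ in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                cls.__name__, self._name))
        self._module_dict[cls.__name__] = cls
        return cls


HEADS = Registry('head')


@HEADS.register_module
class AnchorHead:
    version = 1


print(list(HEADS._module_dict))  # ['AnchorHead']

rejected = False
try:
    @HEADS.register_module
    class AnchorHead:  # noqa: F811  same name again, without force
        version = 99
except KeyError:
    rejected = True
print(rejected)  # True


@HEADS.register_module(force=True)
class AnchorHead:  # noqa: F811  force=True overwrites the old entry
    version = 2


print(HEADS.get('AnchorHead').version)  # 2
print(HEADS.get('RPNHead'))             # None
```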
from torch import nn

from mmdet.utils import build_from_cfg
from .registry import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
                       ROI_EXTRACTORS, SHARED_HEADS)  # by the time build_* is called, these registries hold every component


def build(cfg, registry, default_args=None):
    '''
    :param cfg: cfg.model
    :param registry: DETECTORS
    :param default_args: dict(train_cfg=train_cfg, test_cfg=test_cfg)
    :return:
    '''
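    if isinstance(cfg, list):  # a list of configs: build each one and chain them with nn.Sequential
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        # Here the registry acts like a warehouse: use cfg to pick the
        # component class out of it, then instantiate that class with the
        # remaining config entries as constructor arguments
        return build_from_cfg(cfg, registry, default_args)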



def build_backbone(cfg):
    return build(cfg, BACKBONES)


def build_neck(cfg):
    return build(cfg, NECKS)


def build_roi_extractor(cfg):
    return build(cfg, ROI_EXTRACTORS)


def build_shared_head(cfg):
    return build(cfg, SHARED_HEADS)


def build_head(cfg):
    return build(cfg, HEADS)


def build_loss(cfg):
    return build(cfg, LOSSES)


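# DETECTORS and the other registries come from the imports at the top of this
# file, which run when the module is first imported, before build_detector is called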
def build_detector(cfg, train_cfg=None, test_cfg=None):
    '''
    :param cfg: cfg.model
    :param train_cfg: cfg.train_cfg
    :param test_cfg: cfg.test_cfg
    :return:
    '''
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
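The two branches of build can be tried without PyTorch. In this sketch a plain Python list stands in for nn.Sequential, and a plain dict serves as the registry, which works because build_from_cfg only calls registry.get. build_from_cfg here is modelled on the one in 2.6, with default_args filled in via setdefault; Conv is a hypothetical component, not an mmdetection class.

```python
def build_from_cfg(cfg, registry, default_args=None):
    args = cfg.copy()
    obj_cls = registry.get(args.pop('type'))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)  # values in cfg win over defaults
    return obj_cls(**args)


def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        # mmdetection wraps these in nn.Sequential; a list stands in here
        return [build_from_cfg(c, registry, default_args) for c in cfg]
    return build_from_cfg(cfg, registry, default_args)


class Conv:  # hypothetical component
    def __init__(self, channels, train_cfg=None):
        self.channels = channels
        self.train_cfg = train_cfg


REGISTRY = {'Conv': Conv}  # a plain dict works: only .get is called on it

single = build(dict(type='Conv', channels=64), REGISTRY)
stack = build([dict(type='Conv', channels=c) for c in (64, 128)], REGISTRY,
              default_args=dict(train_cfg='rcnn'))
override = build(dict(type='Conv', channels=1, train_cfg='own'), REGISTRY,
                 default_args=dict(train_cfg='rcnn'))

print(single.channels)               # 64
print([m.channels for m in stack])   # [64, 128]
print([m.train_cfg for m in stack])  # ['rcnn', 'rcnn']
print(override.train_cfg)            # own
```

The setdefault call is what lets build_detector pass train_cfg and test_cfg as defaults without overriding anything the model config already sets.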
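  • 2.9 build_detector calls build, which in turn calls the build_from_cfg function from 2.6 with the config cfg (cfg.model) and the registry DETECTORS. Part of the config is shown below. build_from_cfg pops the value of the type key, here FasterRCNN, looks up the FasterRCNN class stored under that key in DETECTORS, unpacks the remaining key-value pairs of cfg (plus train_cfg and test_cfg from default_args) as keyword arguments to that class, and returns the resulting FasterRCNN instance. This is the step that creates the network. Some readers may notice that this process does not seem to build the BACKBONES, NECKS, and other components. Is that really the case?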
model = dict(
    type='FasterRCNN',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    # rpn_head=dict(type='RPNHead', ...) is also required by FasterRCNN; omitted here
    bbox_roi_extractor=dict(
        type='SingleRoIExtractor',
        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
        out_channels=256,
        featmap_strides=[4, 8, 16, 32]),
    bbox_head=dict(
        type='SharedFCBBoxHead',
        num_fcs=2,
        in_channels=256,
        fc_out_channels=1024,
        roi_feat_size=7,
        num_classes=81,
        target_means=[0., 0., 0., 0.],
        target_stds=[0.1, 0.1, 0.2, 0.2],
        reg_class_agnostic=False,
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))
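To make the unpacking in 2.9 concrete, the sketch below replays the first steps of build_from_cfg on an abbreviated copy of the config above (sub-configs shortened, so the values are illustrative only). It shows which keyword arguments the detector class receives, and that the nested entries are still plain dicts at that point.

```python
# Abbreviated, illustrative copy of the model config (sub-dicts shortened).
model = dict(
    type='FasterRCNN',
    pretrained='torchvision://resnet50',
    backbone=dict(type='ResNet', depth=50),
    neck=dict(type='FPN', out_channels=256),
    rpn_head=dict(type='RPNHead', in_channels=256),
    bbox_roi_extractor=dict(type='SingleRoIExtractor', out_channels=256),
    bbox_head=dict(type='SharedFCBBoxHead', num_classes=81))

args = model.copy()          # shallow copy, so `model` keeps its 'type' key
obj_type = args.pop('type')  # the key that is looked up in DETECTORS
# default_args as passed by build_detector (values left as None here)
for name, value in dict(train_cfg=None, test_cfg=None).items():
    args.setdefault(name, value)

print(obj_type)                         # FasterRCNN
print(sorted(args))                     # the keyword arguments FasterRCNN receives
print(type(args['backbone']).__name__)  # dict
```

Each nested value is still a dict here; turning those dicts into modules happens later, inside the detector's own constructor.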
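  • 2.10 As noted above, build_from_cfg returns a FasterRCNN instance. The FasterRCNN class is defined in models/detectors/faster_rcnn.py, shown below. Each of its constructor arguments (backbone, neck, rpn_head, and so on) is still a config dict at this point. FasterRCNN simply passes them on to the constructor of its parent class, TwoStageDetector.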
from ..registry import DETECTORS
from .two_stage import TwoStageDetector


@DETECTORS.register_module  # pass the FasterRCNN class as the argument to register_module
                            # of DETECTORS, an instance of Registry
class FasterRCNN(TwoStageDetector):

    def __init__(self,
                 backbone,
                 rpn_head,
                 bbox_roi_extractor,
                 bbox_head,
                 train_cfg,
                 test_cfg,
                 neck=None,
                 shared_head=None,
                 pretrained=None):
        # call the constructor of the parent class, TwoStageDetector
        super(FasterRCNN, self).__init__(
            backbone=backbone,
            neck=neck,
            shared_head=shared_head,
            rpn_head=rpn_head,
            bbox_roi_extractor=bbox_roi_extractor,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained)
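  • 2.11 The TwoStageDetector class is defined in models/detectors/two_stage.py, shown below. For each config dict passed to its constructor, it calls the corresponding build function from builder.py (build_backbone, build_neck, build_head, and so on), which looks the component class up in its registry and instantiates it. This answers the question in 2.9: the backbone, neck, and other components are built here, through the same registry mechanism.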
import torch
import torch.nn as nn

from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler
from .. import builder
from ..registry import DETECTORS
from .base import BaseDetector
from .test_mixins import BBoxTestMixin, MaskTestMixin, RPNTestMixin


@DETECTORS.register_module
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
                       MaskTestMixin):
    """Base class for two-stage detectors.

    Two-stage detectors typically consisting of a region proposal network and a
    task-specific regression head.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 shared_head=None,
                 rpn_head=None,
                 bbox_roi_extractor=None,
                 bbox_head=None,
                 mask_roi_extractor=None,
                 mask_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(TwoStageDetector, self).__init__()
        self.backbone = builder.build_backbone(backbone)

        if neck is not None:
            self.neck = builder.build_neck(neck)

        if shared_head is not None:
            self.shared_head = builder.build_shared_head(shared_head)

        if rpn_head is not None:
            self.rpn_head = builder.build_head(rpn_head)

        if bbox_head is not None:
            self.bbox_roi_extractor = builder.build_roi_extractor(
                bbox_roi_extractor)
            self.bbox_head = builder.build_head(bbox_head)

        if mask_head is not None:
            if mask_roi_extractor is not None:
                self.mask_roi_extractor = builder.build_roi_extractor(
                    mask_roi_extractor)
                self.share_roi_extractor = False
            else:
                self.share_roi_extractor = True
                self.mask_roi_extractor = self.bbox_roi_extractor
            self.mask_head = builder.build_head(mask_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.init_weights(pretrained=pretrained)
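The recursive construction in 2.11 can be reduced to a few toy classes. This is a sketch under simplifying assumptions: Registry here is a hypothetical dict subclass, ToyResNet, ToyFPN and ToyTwoStage are made-up stand-ins for ResNet, FPN and TwoStageDetector, and no PyTorch is involved. The outer build_from_cfg call creates the detector, and the detector's __init__ then builds its own sub-components from their registries.

```python
class Registry(dict):
    """Hypothetical stand-in: a registry is essentially a name -> class dict."""

    def register_module(self, cls):
        self[cls.__name__] = cls
        return cls


BACKBONES, NECKS, DETECTORS = Registry(), Registry(), Registry()


def build_from_cfg(cfg, registry, default_args=None):
    args = dict(cfg)
    obj_cls = registry[args.pop('type')]
    for name, value in (default_args or {}).items():
        args.setdefault(name, value)
    return obj_cls(**args)


@BACKBONES.register_module
class ToyResNet:
    def __init__(self, depth):
        self.depth = depth


@NECKS.register_module
class ToyFPN:
    def __init__(self, out_channels):
        self.out_channels = out_channels


@DETECTORS.register_module
class ToyTwoStage:
    # Mirrors TwoStageDetector.__init__: each argument arrives as a config
    # dict and is turned into a component via its own registry.
    def __init__(self, backbone, neck=None, train_cfg=None, test_cfg=None):
        self.backbone = build_from_cfg(backbone, BACKBONES)
        if neck is not None:
            self.neck = build_from_cfg(neck, NECKS)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg


cfg = dict(type='ToyTwoStage',
           backbone=dict(type='ToyResNet', depth=50),
           neck=dict(type='ToyFPN', out_channels=256))
detector = build_from_cfg(cfg, DETECTORS, dict(train_cfg={}, test_cfg={}))
print(type(detector.backbone).__name__, detector.backbone.depth)  # ToyResNet 50
print(type(detector.neck).__name__, detector.neck.out_channels)   # ToyFPN 256
```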

Summary

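  • To sum up, a network is created in two stages: first, every network component is registered in its registry (this happens when mmdet.models is imported); then, following the config file, each component is looked up in its registry and instantiated.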
