1. Overview
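- mmdetection is an object detection toolbox open-sourced by SenseTime. It implements a large number of detection algorithms and is highly extensible, which makes its source code hard to avoid for anyone working on object detection. Reading the mmdetection source helps you understand how each detection algorithm is implemented and how all of them are integrated into one framework.
- This post covers how mmdetection builds a network model.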
2. Source Code Walkthrough
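- `train.py` handles training and contains two functions, `parse_args()` and `main()`. The former reads arguments from the command line. The latter performs four steps in order: ① read the configuration from the config file and the command-line arguments; ② build the network model; ③ build the dataset; ④ train the model on the dataset. This post focuses on step ②, using `faster_rcnn_r50_fpn_1x` as the example.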
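- 2.1 In `train.py`, the following code creates the model:
```python
from mmcv import Config
from mmdet.models import build_detector

# ... other code ...

def main():
    args = parse_args()
    # cfg: the parsed config file
    cfg = Config.fromfile(args.config)
    # ... other code ...
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
```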
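The call above depends on the config file exposing top-level names `model`, `train_cfg` and `test_cfg` that can be read as attributes. Below is a minimal sketch of that idea, assuming a Python-syntax config file. `AttrDict` and `load_py_config` are hypothetical helpers written for illustration; they are not mmcv's `Config`, which does considerably more.

```python
import os
import runpy
import tempfile


class AttrDict(dict):
    """A dict whose keys can also be read as attributes, e.g. cfg.model."""

    def __getattr__(self, name):
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name) from None
        # Wrap nested dicts so that cfg.model.backbone also works
        return AttrDict(value) if isinstance(value, dict) else value


def load_py_config(path):
    """Execute a Python config file and keep its public top-level variables."""
    namespace = runpy.run_path(path)
    return AttrDict({k: v for k, v in namespace.items()
                     if not k.startswith('_')})


# A tiny hypothetical config file with the three names train.py reads
CONFIG_TEXT = """
model = dict(type='FasterRCNN', backbone=dict(type='ResNet', depth=50))
train_cfg = dict()
test_cfg = dict()
"""

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'toy_config.py')
    with open(path, 'w') as f:
        f.write(CONFIG_TEXT)
    cfg = load_py_config(path)

print(cfg.model.type)            # FasterRCNN
print(cfg.model.backbone.depth)  # 50
```

Because the config is ordinary Python, `dict(...)` calls in it are evaluated while the file is executed. What reaches `build_detector` is therefore plain nested dictionaries.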
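- 2.2 Note the `import` statement above. It imports from a package, so Python first executes `models/__init__.py`, whose contents are shown below. Because `__init__.py` already imports `build_detector`, the code above can import `build_detector` directly from `models`.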
```python
from .anchor_heads import *
from .backbones import *
from .bbox_heads import *
from .builder import (build_backbone, build_detector, build_head, build_loss,
                      build_neck, build_roi_extractor, build_shared_head)
from .detectors import *
from .losses import *
from .mask_heads import *
from .necks import *
from .registry import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
                       ROI_EXTRACTORS, SHARED_HEADS)
from .roi_extractors import *
from .shared_heads import *

__all__ = [
    'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'LOSSES',
    'DETECTORS', 'build_backbone', 'build_neck', 'build_roi_extractor',
    'build_shared_head', 'build_head', 'build_loss', 'build_detector'
]
```
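The package-import behavior described in 2.2 can be checked in isolation. The sketch below writes a throwaway package to a temporary directory and imports a function from it. The package name `toy_models` and its contents are hypothetical and chosen only to mirror the layout of `mmdet/models`.

```python
import importlib
import os
import sys
import tempfile
import textwrap


def write(path, text):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'w') as f:
        f.write(textwrap.dedent(text))


root = tempfile.mkdtemp()
# toy_models/builder.py defines the function ...
write(os.path.join(root, 'toy_models', 'builder.py'), '''
    def build_detector(cfg):
        return 'built ' + cfg['type']
''')
# ... and toy_models/__init__.py re-exports it, like models/__init__.py
write(os.path.join(root, 'toy_models', '__init__.py'), '''
    from .builder import build_detector
    INIT_RAN = True  # proves this file ran when the package was imported
    __all__ = ['build_detector']
''')

sys.path.insert(0, root)
importlib.invalidate_caches()

# Importing from the package executes toy_models/__init__.py first
from toy_models import build_detector  # noqa: E402

print(sys.modules['toy_models'].INIT_RAN)      # True
print(build_detector({'type': 'FasterRCNN'}))  # built FasterRCNN
```

Because `__init__.py` runs before the requested name is looked up, every side effect in it happens at import time. In mmdetection, those side effects include the component registrations described next.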
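- 2.3 Importing `models/__init__.py` is what registers all network components. How does the registration work? Take the `anchor_heads` components as an example. `from .anchor_heads import *` is also a package import, so Python first executes `models/anchor_heads/__init__.py`, shown below. Because that file defines `__all__`, `from ... import *` imports exactly the names listed in `__all__`.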
```python
from .anchor_head import AnchorHead
from .atss_head import ATSSHead
from .fcos_head import FCOSHead
from .fovea_head import FoveaHead
from .free_anchor_retina_head import FreeAnchorRetinaHead
from .ga_retina_head import GARetinaHead
from .ga_rpn_head import GARPNHead
from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
from .reppoints_head import RepPointsHead
from .retina_head import RetinaHead
from .retina_sepbn_head import RetinaSepBNHead
from .rpn_head import RPNHead
from .ssd_head import SSDHead

__all__ = [
    'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead',
    'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 'SSDHead',
    'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead',
    'ATSSHead'
]
```
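The effect of `__all__` on `import *` is easy to confirm with a throwaway package. In this sketch, `toy_heads` and its classes are hypothetical. `NotExported` is a public name that is deliberately left out of `__all__`.

```python
import importlib
import os
import sys
import tempfile

root = tempfile.mkdtemp()
pkg = os.path.join(root, 'toy_heads')
os.makedirs(pkg)
with open(os.path.join(pkg, '__init__.py'), 'w') as f:
    f.write(
        "class AnchorHead: pass\n"
        "class RPNHead: pass\n"
        "class NotExported: pass\n"  # public, but not listed in __all__
        "__all__ = ['AnchorHead', 'RPNHead']\n")

sys.path.insert(0, root)
importlib.invalidate_caches()

# Run "import *" in a fresh namespace and see which names arrive
namespace = {}
exec('from toy_heads import *', namespace)
imported = sorted(k for k in namespace if k != '__builtins__')
print(imported)  # ['AnchorHead', 'RPNHead']
```

Even so, `NotExported` is still defined, and any decorator on it would still run, because the whole module executes on import. `__all__` only controls which names are copied into the importer.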
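- 2.4 Take `AnchorHead` as an example and open `models/anchor_heads/anchor_head.py`. Its main content is shown below. The file first imports `HEADS`, an instance of the `Registry` class, from `models/registry.py`. It then uses a Python function decorator to pass the `AnchorHead` class as the argument to `HEADS`'s method `register_module`. If this seems confusing, keep reading.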
```python
import torch.nn as nn

from ..registry import HEADS


@HEADS.register_module
class AnchorHead(nn.Module):
    """
    code in class definition
    """
```
- 2.5 Below is the code in `models/registry.py`. The `HEADS` used above is an instance of the `Registry` class. How is `Registry` defined? Read on.
```python
from mmdet.utils import Registry

BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')
```
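- 2.6 Below is the core of `mmdet/utils/registry.py` (some code has been removed to highlight the key points). `Registry.__init__` defines two attributes: `_name` and the dictionary `_module_dict`. As described in 2.4, the `AnchorHead` class is passed as the argument to `HEADS.register_module`. The code below shows that each class passed in this way is added to the `HEADS._module_dict` dictionary, keyed by its class name.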
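```python
from functools import partial


class Registry(object):

    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    def get(self, key):
        return self._module_dict.get(key, None)

    def _register_module(self, module_class, force=False):
        module_name = module_class.__name__
        if not force and module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self._name))
        self._module_dict[module_name] = module_class

    def register_module(self, cls=None, force=False):
        # @REG.register_module(force=True) calls this with cls=None first
        if cls is None:
            return partial(self.register_module, force=force)
        self._register_module(cls, force=force)
        return cls


def build_from_cfg(cfg, registry, default_args=None):
    args = cfg.copy()
    obj_type = args.pop('type')  # e.g. 'FasterRCNN'
    obj_cls = registry.get(obj_type)
    if obj_cls is None:
        raise KeyError('{} is not registered'.format(obj_type))
    # default_args (e.g. train_cfg/test_cfg) fill in keys missing from cfg
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_cls(**args)
```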
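The `cls=None` branch with `functools.partial` is what lets `register_module` be used both bare and with arguments. The self-contained sketch below re-creates just that method to show both forms, plus the duplicate-name check that `force=True` bypasses. The `version` attribute exists only to tell the three classes apart.

```python
from functools import partial


class Registry:

    def __init__(self, name):
        self._name = name
        self._module_dict = {}

    def register_module(self, cls=None, force=False):
        if cls is None:
            # Called as @REG.register_module(force=True): return a decorator
            # that remembers `force` and waits for the class.
            return partial(self.register_module, force=force)
        if not force and cls.__name__ in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                cls.__name__, self._name))
        self._module_dict[cls.__name__] = cls
        return cls


HEADS = Registry('head')


@HEADS.register_module            # bare form: the class is passed directly
class AnchorHead:
    version = 1


duplicate_rejected = False
try:
    @HEADS.register_module        # same name again, without force
    class AnchorHead:  # noqa: F811
        version = 2
except KeyError as err:
    duplicate_rejected = True
    print('rejected:', err)


@HEADS.register_module(force=True)  # argument form: goes through partial
class AnchorHead:  # noqa: F811
    version = 3


print(HEADS._module_dict['AnchorHead'].version)  # 3
```

The argument form costs one extra call. `register_module(force=True)` returns a `partial`, and Python then applies that `partial` to the class.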
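- 2.7 At this point, the network component `AnchorHead` (one of the `anchor_heads`) has been registered in the `HEADS` registry. Every other network component is registered the same way. Going back to 2.2, once `models/__init__.py` has executed, every component has been registered in the registry it belongs to. A registry, then, simply maintains a dictionary whose key-value pairs map each component's name to its class.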
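- 2.8 Back to 2.1. When `train.py` executes `from mmdet.models import build_detector`, `models/__init__.py` runs and all network components are registered. `main()` then calls `build_detector`, which is defined in `models/builder.py` as follows.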
```python
from torch import nn

from mmdet.utils import build_from_cfg
from .registry import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
                       ROI_EXTRACTORS, SHARED_HEADS)


def build(cfg, registry, default_args=None):
    '''
    :param cfg: cfg.model
    :param registry: DETECTORS
    :param default_args: dict(train_cfg=train_cfg, test_cfg=test_cfg)
    :return:
    '''
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_backbone(cfg):
    return build(cfg, BACKBONES)


def build_neck(cfg):
    return build(cfg, NECKS)


def build_roi_extractor(cfg):
    return build(cfg, ROI_EXTRACTORS)


def build_shared_head(cfg):
    return build(cfg, SHARED_HEADS)


def build_head(cfg):
    return build(cfg, HEADS)


def build_loss(cfg):
    return build(cfg, LOSSES)


def build_detector(cfg, train_cfg=None, test_cfg=None):
    '''
    :param cfg: cfg.model
    :param train_cfg: cfg.train_cfg
    :param test_cfg: cfg.test_cfg
    :return:
    '''
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
```
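The two paths through `build` (a single config dict versus a list of configs) and the role of `default_args` can be exercised without torch. This is a minimal sketch under two assumptions: a plain Python list stands in for `nn.Sequential`, and keys already present in the config take precedence over `default_args` (via `setdefault`). `ToyDetector` is a hypothetical class.

```python
class Registry:
    """Minimal registry: a name -> class dictionary."""

    def __init__(self, name):
        self._name = name
        self._module_dict = {}

    def register_module(self, cls):
        self._module_dict[cls.__name__] = cls
        return cls

    def get(self, key):
        return self._module_dict.get(key)


def build_from_cfg(cfg, registry, default_args=None):
    args = dict(cfg)  # shallow copy, so the caller's config is left untouched
    obj_type = args.pop('type')
    obj_cls = registry.get(obj_type)
    if obj_cls is None:
        raise KeyError('{} is not registered'.format(obj_type))
    for name, value in (default_args or {}).items():
        args.setdefault(name, value)  # keys already in cfg take precedence
    return obj_cls(**args)


def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        # mmdetection wraps these in nn.Sequential; a plain list stands in here
        return [build_from_cfg(c, registry, default_args) for c in cfg]
    return build_from_cfg(cfg, registry, default_args)


DETECTORS = Registry('detector')


@DETECTORS.register_module
class ToyDetector:  # hypothetical stand-in for a real detector class
    def __init__(self, depth, train_cfg=None, test_cfg=None):
        self.depth = depth
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg


def build_detector(cfg, train_cfg=None, test_cfg=None):
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))


model_cfg = dict(type='ToyDetector', depth=50)
det = build_detector(model_cfg, train_cfg={'note': 'toy'})
print(type(det).__name__, det.depth, det.train_cfg)  # ToyDetector 50 {'note': 'toy'}
print(model_cfg)  # {'type': 'ToyDetector', 'depth': 50}

dets = build([dict(type='ToyDetector', depth=18),
              dict(type='ToyDetector', depth=34)], DETECTORS)
print([d.depth for d in dets])  # [18, 34]
```

Copying the config before popping `type` matters because the same config dict may be reused, for example when training is restarted. Without the copy, a second call would fail with a missing `type` key.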
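- 2.9 `build_detector` calls `build`, which in turn calls the `build_from_cfg` function from 2.6. The arguments are the config `cfg`, the registry `DETECTORS`, and `dict(train_cfg=..., test_cfg=...)` as default arguments. Part of the config file is shown below. `build_from_cfg` first takes the value of the key `type`, here `FasterRCNN`. It looks up the key `FasterRCNN` in `DETECTORS` to get the `FasterRCNN` class, unpacks the remaining key-value pairs of `cfg` as keyword arguments to that class, and returns the resulting instance. Readers may notice that this process does not appear to build the components registered in `NECKS`, `BACKBONES`, and so on. Is that really the case?
```python
model = dict(
    type='FasterRCNN',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    # rpn_head=dict(...) is omitted from this excerpt
    bbox_roi_extractor=dict(
        type='SingleRoIExtractor',
        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
        out_channels=256,
        featmap_strides=[4, 8, 16, 32]),
    bbox_head=dict(
        type='SharedFCBBoxHead',
        num_fcs=2,
        in_channels=256,
        fc_out_channels=1024,
        roi_feat_size=7,
        num_classes=81,
        target_means=[0., 0., 0., 0.],
        target_stds=[0.1, 0.1, 0.2, 0.2],
        reg_class_agnostic=False,
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))
```
- 2.10 As mentioned above, `build_from_cfg` returns a `FasterRCNN` instance. The `FasterRCNN` class is defined in `models/detectors/faster_rcnn.py`, shown below. At this point, each argument its constructor receives (`backbone`, `neck`, `rpn_head`, and so on) is still a plain config dictionary. `FasterRCNN` simply forwards them all to the constructor of its parent class `TwoStageDetector`.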
```python
from ..registry import DETECTORS
from .two_stage import TwoStageDetector


@DETECTORS.register_module
class FasterRCNN(TwoStageDetector):

    def __init__(self,
                 backbone,
                 rpn_head,
                 bbox_roi_extractor,
                 bbox_head,
                 train_cfg,
                 test_cfg,
                 neck=None,
                 shared_head=None,
                 pretrained=None):
        super(FasterRCNN, self).__init__(
            backbone=backbone,
            neck=neck,
            shared_head=shared_head,
            rpn_head=rpn_head,
            bbox_roi_extractor=bbox_roi_extractor,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained)
```
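To see concretely what `FasterRCNN.__init__` receives, the sketch below mimics only the argument-preparation part of `build_from_cfg`: pop `type`, then fill in the default arguments. It runs on an abbreviated copy of the model config. `resolve_call` is a hypothetical helper that stops before instantiating anything, and most config fields are dropped for brevity.

```python
# Abbreviated copy of the Faster R-CNN model config (most fields dropped)
model = dict(
    type='FasterRCNN',
    pretrained='torchvision://resnet50',
    backbone=dict(type='ResNet', depth=50),
    neck=dict(type='FPN', out_channels=256),
    bbox_roi_extractor=dict(type='SingleRoIExtractor', out_channels=256),
    bbox_head=dict(type='SharedFCBBoxHead', num_fcs=2),
)


def resolve_call(cfg, default_args=None):
    """Return the class name and kwargs build_from_cfg would use, without building."""
    args = dict(cfg)
    obj_type = args.pop('type')
    for key, value in (default_args or {}).items():
        args.setdefault(key, value)
    return obj_type, args


name, kwargs = resolve_call(model, dict(train_cfg=None, test_cfg=None))
print(name)  # FasterRCNN
for key, value in kwargs.items():
    if isinstance(value, dict):
        kind = 'config dict, type=' + value['type']
    else:
        kind = repr(value)
    print('{:20s} -> {}'.format(key, kind))
```

Every sub-component arrives as a dictionary that still carries its own `type` key, so the detector's constructor can hand each one to the matching `build_*` function.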
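- 2.11 The `TwoStageDetector` class is defined in `models/detectors/two_stage.py`, shown below. For each config dictionary it receives, it calls the corresponding build function (`build_backbone`, `build_neck`, `build_head`, `build_roi_extractor`, ...). That function looks the component up in its registry and instantiates it. This is where the registered backbone, neck, and heads are actually built, which answers the question at the end of 2.9.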
```python
import torch
import torch.nn as nn

from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler
from .. import builder
from ..registry import DETECTORS
from .base import BaseDetector
from .test_mixins import BBoxTestMixin, MaskTestMixin, RPNTestMixin


@DETECTORS.register_module
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
                       MaskTestMixin):
    """Base class for two-stage detectors.

    Two-stage detectors typically consisting of a region proposal network and a
    task-specific regression head.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 shared_head=None,
                 rpn_head=None,
                 bbox_roi_extractor=None,
                 bbox_head=None,
                 mask_roi_extractor=None,
                 mask_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(TwoStageDetector, self).__init__()
        self.backbone = builder.build_backbone(backbone)

        if neck is not None:
            self.neck = builder.build_neck(neck)

        if shared_head is not None:
            self.shared_head = builder.build_shared_head(shared_head)

        if rpn_head is not None:
            self.rpn_head = builder.build_head(rpn_head)

        if bbox_head is not None:
            self.bbox_roi_extractor = builder.build_roi_extractor(
                bbox_roi_extractor)
            self.bbox_head = builder.build_head(bbox_head)

        if mask_head is not None:
            if mask_roi_extractor is not None:
                self.mask_roi_extractor = builder.build_roi_extractor(
                    mask_roi_extractor)
                self.share_roi_extractor = False
            else:
                self.share_roi_extractor = True
                self.mask_roi_extractor = self.bbox_roi_extractor
            self.mask_head = builder.build_head(mask_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.init_weights(pretrained=pretrained)
```
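Putting 2.9 through 2.11 together, building is recursive: one `build_from_cfg` call on the top-level config creates the detector, and the detector's constructor builds its own children from their sub-configs. The end-to-end sketch below assumes simplified registries and uses hypothetical toy classes (`ToyResNet`, `ToyFPN`, `ToyTwoStage`) in place of the real components.

```python
class Registry:

    def __init__(self, name):
        self._name = name
        self._module_dict = {}

    def register_module(self, cls):
        self._module_dict[cls.__name__] = cls
        return cls

    def get(self, key):
        return self._module_dict[key]  # KeyError if the name was never registered


def build_from_cfg(cfg, registry, default_args=None):
    args = dict(cfg)
    obj_cls = registry.get(args.pop('type'))
    for key, value in (default_args or {}).items():
        args.setdefault(key, value)
    return obj_cls(**args)


BACKBONES = Registry('backbone')
NECKS = Registry('neck')
DETECTORS = Registry('detector')


# Phase 1: registration happens as soon as each class statement runs.
@BACKBONES.register_module
class ToyResNet:
    def __init__(self, depth):
        self.depth = depth


@NECKS.register_module
class ToyFPN:
    def __init__(self, out_channels):
        self.out_channels = out_channels


@DETECTORS.register_module
class ToyTwoStage:
    def __init__(self, backbone, neck=None, train_cfg=None, test_cfg=None):
        # Like TwoStageDetector: children arrive as dicts and are built here.
        self.backbone = build_from_cfg(backbone, BACKBONES)
        if neck is not None:
            self.neck = build_from_cfg(neck, NECKS)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg


# Phase 2: one call on the top-level config builds the whole tree.
cfg = dict(type='ToyTwoStage',
           backbone=dict(type='ToyResNet', depth=50),
           neck=dict(type='ToyFPN', out_channels=256))
model = build_from_cfg(cfg, DETECTORS, dict(train_cfg=None, test_cfg=None))
print(type(model.backbone).__name__, model.backbone.depth)  # ToyResNet 50
print(type(model.neck).__name__, model.neck.out_channels)   # ToyFPN 256
```

Because each level only looks up its own children by name, the config file alone decides which registered classes are assembled. Swapping `ToyFPN` for another registered neck needs no code change.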
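3. Summary
- To summarize, building a network takes two phases. First, all network components are registered, as a side effect of importing `mmdet.models`. Then, following the config file, each component is looked up in its registry and built.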