BaseModule实现在mmcv中,负责根据init_cfg对模块权值进行初始化。
init_cfg: dict(type='xxx', **kwargs)
初始化方法同样使用Registry进行管理,实现在mmcv/cnn/utils/weight_init.py中。
class BaseModule(nn.Module, metaclass=ABCMeta):
    """Base module for all modules in openmmlab."""
    # Extends nn.Module with config-driven weight initialization through
    # ``init_cfg``. Statements shown as ``...`` are elided in this excerpt.

    def __init__(self, init_cfg=None):
        """Store the weight-initialization config.

        Args:
            init_cfg (dict | None): Config of the form
                ``dict(type='xxx', **kwargs)``; presumably a list of such
                dicts is accepted as well (``initialize`` iterates over it).
                A falsy value means this module runs no initializer itself.
        """
        ...
        self.init_cfg = init_cfg

    def init_weights(self):
        """Initialize this module from ``init_cfg``, then its children.

        Nothing happens if ``self._is_init`` is already true.
        """
        ...
        if not self._is_init:
            if self.init_cfg:
                initialize(self, self.init_cfg)
                if isinstance(self.init_cfg, (dict, ConfigDict)):
                    # With a 'Pretrained' config, return before the loop
                    # below, so the children's init_weights never run and
                    # ``_is_init`` is not set (presumably so the loaded
                    # weights are not overwritten).
                    if self.init_cfg['type'] == 'Pretrained':
                        return
            # Recurse only into direct children that define init_weights.
            for m in self.children():
                if hasattr(m, 'init_weights'):
                    m.init_weights()
            self._is_init = True
        ...
def initialize(module, init_cfg):
    """Apply one or more initializer configs to ``module``.

    Args:
        module (nn.Module): Module whose weights are initialized.
        init_cfg (dict | list[dict]): A single config of the form
            ``dict(type='xxx', **kwargs)`` or a list of them, applied in
            order. An optional ``override`` key is removed from each config
            before its initializer is built.
    """
    ...
    # A single dict must be wrapped: iterating a dict directly would yield
    # its keys (e.g. 'type') instead of configs.
    if isinstance(init_cfg, dict):
        init_cfg = [init_cfg]
    for cfg in init_cfg:
        # Deep-copy so that popping keys never mutates the caller's config.
        cp_cfg = copy.deepcopy(cfg)
        # ``override`` is not passed on to the initializer; presumably it
        # describes per-submodule settings handled separately.
        override = cp_cfg.pop('override', None)
        _initialize(module, cp_cfg)
        # Handling of ``override`` is elided in this excerpt.
        ...
def _initialize(module, cfg, wholemodule=False):
    """Build one initializer from ``cfg`` and run it on ``module``.

    Args:
        module (nn.Module): Target module.
        cfg (dict): Initializer config; it is resolved against the
            ``INITIALIZERS`` registry.
        wholemodule (bool): Stored on the initializer before it runs. It is
            used by override mode: override configs have no ``layer`` key,
            so the initializer sets values for the whole named module.
    """
    init_fn = build_from_cfg(cfg, INITIALIZERS)
    setattr(init_fn, 'wholemodule', wholemodule)
    init_fn(module)
# mmcv/cnn/utils/weight_init.py
# Registry of weight initializers. ``_initialize`` resolves configs against
# it with ``build_from_cfg(cfg, INITIALIZERS)``, so ``init_cfg['type']``
# must be one of the names registered below.
INITIALIZERS = Registry('initializer')
...
@INITIALIZERS.register_module(name='Constant')
class ConstantInit(BaseInit):
    """Initializer registered in ``INITIALIZERS`` under the name 'Constant'.

    Selected with ``init_cfg=dict(type='Constant', ...)``; the
    implementation is elided in this excerpt.
    """
    ...
@INITIALIZERS.register_module(name='Xavier')
class XavierInit(BaseInit):
    """Initializer registered in ``INITIALIZERS`` under the name 'Xavier'.

    Selected with ``init_cfg=dict(type='Xavier', ...)``; the implementation
    is elided in this excerpt.
    """
    ...
@INITIALIZERS.register_module(name='Normal')
class NormalInit(BaseInit):
    """Initializer registered in ``INITIALIZERS`` under the name 'Normal'.

    Selected with ``init_cfg=dict(type='Normal', ...)``; the implementation
    is elided in this excerpt.
    """
    ...
@INITIALIZERS.register_module(name='TruncNormal')
class TruncNormalInit(BaseInit):
    """Initializer registered in ``INITIALIZERS`` as 'TruncNormal'.

    Selected with ``init_cfg=dict(type='TruncNormal', ...)``; the
    implementation is elided in this excerpt.
    """
    ...
# Normalization config shared by the FPN neck and the RepPoints head (the
# entries marked "# add" below). It must be defined before ``model``
# because both sub-configs reference it.
# NOTE(review): GroupNorm with 32 groups is assumed here (the value used by
# mmdet's RepPoints "gn-neck+head" variant) -- confirm it is the intended
# setting.
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
    type='RepPointsDetector',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=1,
        add_extra_convs='on_input',
        num_outs=5,
        norm_cfg=norm_cfg,  # add
    ),
    bbox_head=dict(
        norm_cfg=norm_cfg,  # add
        type='RepPointsHead',
        num_classes=80,
        in_channels=256,
        feat_channels=256,
        point_feat_channels=256,
        stacked_convs=3,
        num_points=9,
        gradient_mul=0.1,
        point_strides=[8, 16, 32, 64, 128],
        point_base_scale=2,  # 4,
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox_init=dict(type='SmoothL1Loss', beta=0.11, loss_weight=0.5),
        loss_bbox_refine=dict(type='SmoothL1Loss', beta=0.11, loss_weight=1.0),
        transform_method='moment',
    ),
    # training and testing settings
    train_cfg=dict(
        init=dict(
            assigner=dict(type='PointAssigner', scale=2, pos_num=1),  # scale=4
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        refine=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.4,
                min_pos_iou=0,
                ignore_iof_thr=-1),
            allowed_border=-1,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        nms_pre=1000,
        min_bbox_size=0,
        score_thr=0.05,
        nms=dict(type='nms', iou_threshold=0.5),
        max_per_img=100))
# Only the learning rate is set here. This presumably overrides an optimizer
# defined in an inherited (_base_) config; a standalone config would also
# need to specify the optimizer ``type``.
optimizer = dict(lr=0.01)
RepPoints的继承链如下所示:RepPointsDetector → SingleStageDetector → BaseDetector → BaseModule → nn.Module。
由于RepPoints对Detector的修改都在Head上,因此Detector本身的代码结构并没有差异,只是给SingleStageDetector包了一层皮。
@DETECTORS.register_module()
class RepPointsDetector(SingleStageDetector):
    """RepPoints: Point Set Representation for Object Detection."""
    # No behavior of its own: every RepPoints-specific change lives in the
    # bbox head, so construction is delegated entirely to the parent.

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(
            backbone,
            neck,
            bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)
我们先从最基本的父类说起。BaseModule可以参考前文,它是OpenMMLab中所有module的基类,在nn.Module的基础上实现了根据配置对参数进行初始化的操作,这里不再赘述。
BaseDetector的调用关系如下图所示:
class BaseDetector(BaseModule, metaclass=ABCMeta):
    """Base class for detectors."""
    # Shared train/test dispatch for all detectors. Subclasses must provide
    # extract_feat, simple_test and aug_test (abstract below). Statements
    # and members shown as ``...`` are elided in this excerpt.

    def __init__(self, init_cfg=None):
        """Forward ``init_cfg`` to BaseModule and start with fp16 disabled."""
        super(BaseDetector, self).__init__(init_cfg)
        # NOTE(review): presumably consulted by the @auto_fp16 decorator on
        # forward() -- confirm against mmcv's fp16 utilities.
        self.fp16_enabled = False

    ...

    @abstractmethod
    def extract_feat(self, imgs):
        """Extract features from images."""
        pass

    ...

    async def async_simple_test(self, img, img_metas, **kwargs):
        """Async single-view test; not implemented in the base class."""
        raise NotImplementedError

    async def aforward_test(self, *, img, img_metas, **kwargs):
        """Async counterpart of forward_test (keyword-only arguments).

        ``num_augs`` is computed in the elided part of this method. When it
        is 1, the first (only) view is passed to async_simple_test.
        """
        ...
        if num_augs == 1:
            return await self.async_simple_test(img[0], img_metas[0], **kwargs)
        ...

    @abstractmethod
    def simple_test(self, img, img_metas, **kwargs):
        """Test one view; forward_test calls it when there is one view."""
        pass

    @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        """Test function with test time augmentation."""
        pass

    def forward_test(self, imgs, img_metas, **kwargs):
        """Test entry point: dispatch on the number of test-time views.

        Args:
            imgs (list): One entry per augmented view; ``len(imgs)`` is the
                number of views. Presumably each entry is a batched image
                tensor -- TODO confirm.
            img_metas (list): Per-view meta info, indexed in parallel with
                ``imgs``.
        """
        ...
        num_augs = len(imgs)
        ...
        if num_augs == 1:
            ...
            # Single view: unwrap the outer list and run the plain test.
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        else:
            ...
            return self.aug_test(imgs, img_metas, **kwargs)

    def forward_train(self, imgs, img_metas, **kwargs):
        """Record the batch input shape in every image's meta dict.

        Subclasses call this through super() before their own training
        logic; it mutates the dicts in ``img_metas`` in place and returns
        None.
        """
        # Last two dims of the first entry of ``imgs``; presumably the
        # (H, W) of the padded batch -- TODO confirm.
        batch_input_shape = tuple(imgs[0].size()[-2:])
        for img_meta in img_metas:
            img_meta['batch_input_shape'] = batch_input_shape

    @auto_fp16(apply_to=('img', ))
    def forward(self, img, img_metas, return_loss=True, **kwargs):
        """Dispatch to forward_train if ``return_loss`` else forward_test."""
        ...
        if return_loss:
            return self.forward_train(img, img_metas, **kwargs)
        else:
            return self.forward_test(img, img_metas, **kwargs)

    def _parse_losses(self, losses):
        """Reduce raw losses to ``(loss, log_vars)`` (body elided here)."""
        ...

    def train_step(self, data, optimizer):
        """Run one training forward pass and package its outputs.

        ``data`` is unpacked as keyword arguments to ``self(...)``, i.e. to
        forward(), whose ``return_loss`` defaults to True. ``optimizer`` is
        not used in this body.

        Returns:
            dict: ``loss`` and ``log_vars`` from _parse_losses, plus
            ``num_samples`` = ``len(data['img_metas'])``.
        """
        losses = self(**data)
        loss, log_vars = self._parse_losses(losses)
        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))
        return outputs

    def val_step(self, data, optimizer):
        """Validation counterpart of train_step (body elided here)."""
        ...

    # NOTE(review): the parameter list and body of show_result are elided
    # in this excerpt; the signature is left unclosed as in the original.
    def show_result(
        ...

    def onnx_export(self, img, img_metas):
        """Not supported by the base class; always raises."""
        raise NotImplementedError(f'{self.__class__.__name__} does '
                                  f'not support ONNX EXPORT')
SingleStageDetector继承自BaseDetector,实现了上面说的需要实现的四个函数以及onnx_export。由此可见,RepPointsDetector的训练与测试流程完全复用SingleStageDetector的实现,差异只体现在bbox_head上。
@DETECTORS.register_module()
class SingleStageDetector(BaseDetector):
    """Base class for single-stage detectors."""
    # Pipeline: backbone -> optional neck -> bbox_head. Statements shown as
    # ``...`` are elided in this excerpt.

    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        """Build backbone, optional neck and head from their configs.

        Args:
            backbone: Backbone config; ``pretrained`` is attached to it as
                an attribute before it is built.
            neck: Neck config; no ``self.neck`` is created when None.
            bbox_head: Head config. Despite the None default it is
                required, because ``bbox_head.update`` is called
                unconditionally.
            train_cfg, test_cfg: Merged into the head config and also
                stored on the detector.
            pretrained: Forwarded to the backbone config.
            init_cfg: Passed up to BaseDetector / BaseModule.
        """
        super(SingleStageDetector, self).__init__(init_cfg)
        # Attribute assignment raises AttributeError on a plain dict, so
        # ``backbone`` is presumably a ConfigDict-like object -- confirm.
        backbone.pretrained = pretrained
        self.backbone = build_backbone(backbone)
        if neck is not None:
            self.neck = build_neck(neck)
        # NOTE: update() mutates the caller's bbox_head config in place.
        bbox_head.update(train_cfg=train_cfg)
        bbox_head.update(test_cfg=test_cfg)
        self.bbox_head = build_head(bbox_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def extract_feat(self, img):
        """Directly extract features from the backbone+neck."""
        x = self.backbone(img)
        # ``with_neck`` is defined outside this excerpt; presumably true
        # when a neck was built in __init__.
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_dummy(self, img):
        """Run backbone, neck and head; return the raw head outputs.

        No loss computation or post-processing is applied.
        """
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
        return outs

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None):
        """Compute training losses.

        BaseDetector.forward_train runs first and writes
        ``batch_input_shape`` into each dict of ``img_metas``. The head's
        forward_train then produces the losses, which are returned
        unchanged (presumably the dict later passed to _parse_losses).
        """
        super(SingleStageDetector, self).forward_train(img, img_metas)
        x = self.extract_feat(img)
        losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
                                              gt_labels, gt_bboxes_ignore)
        return losses

    def simple_test(self, img, img_metas, rescale=False):
        """Test a single view without augmentation.

        Args:
            rescale (bool): Forwarded to the head's simple_test.

        Returns:
            list: One ``bbox2result(...)`` entry per ``(det_bboxes,
            det_labels)`` pair produced by the head (presumably one per
            image).
        """
        feat = self.extract_feat(img)
        results_list = self.bbox_head.simple_test(
            feat, img_metas, rescale=rescale)
        bbox_results = [
            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
            for det_bboxes, det_labels in results_list
        ]
        return bbox_results

    def aug_test(self, imgs, img_metas, rescale=False):
        """Test with test-time augmentation.

        Features for every view come from ``extract_feats`` (defined
        outside this excerpt). They are passed together to the head's
        aug_test, and each resulting pair is converted with bbox2result.
        """
        ...
        feats = self.extract_feats(imgs)
        results_list = self.bbox_head.aug_test(
            feats, img_metas, rescale=rescale)
        bbox_results = [
            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
            for det_bboxes, det_labels in results_list
        ]
        return bbox_results

    def onnx_export(self, img, img_metas):
        """Export-oriented test path without test time augmentation.

        The computation of ``det_bboxes`` and ``det_labels`` is elided in
        this excerpt; the pair is returned as a tuple.
        """
        ...
        return det_bboxes, det_labels