mmdetection代码阅读系列(四):RepPoint代码阅读(上)RepPointsDetector

RepPointsDetector
SingleStageDetector
BaseDetector
BaseModule
torch.nn.Module
BaseDetector
train_step
__call__/forward
_parse_losses
val_step
forward_train
forward_test
simple_test
aug_test
extract_feat
SingleStageDetector
forward_train
extract_feat
backbone
neck
simple_test
bbox_head.simple_test
aug_test
bbox_head.aug_test
bbox_head.forward_train
onnx_export

BaseModule

实现在mmcv中,负责根据init_cfg对模块权值进行初始化
init_cfg: dict(type='xxx', **kwargs)
初始化方法同样使用Registry进行管理,实现在mmcv/cnn/utils/weight_init.py。

class BaseModule(nn.Module, metaclass=ABCMeta):
    """Base module for all modules in openmmlab.

    Abridged excerpt of mmcv's ``BaseModule`` (``...`` marks elided code):
    an ``nn.Module`` that stores an ``init_cfg`` and applies it in
    ``init_weights``.

    Args:
        init_cfg (dict | list[dict] | None): Initialization config(s), e.g.
            ``dict(type='xxx', **kwargs)``. When falsy, ``init_weights``
            applies nothing to this module itself and only recurses into
            its children.
    """

    def __init__(self, init_cfg=None):
        ...  # elided in this excerpt
        self.init_cfg = init_cfg

    def init_weights(self):
        """Initialize this module from ``init_cfg``, then its children.

        Guarded by ``self._is_init`` so the work is done only once. If
        ``init_cfg`` is a single dict/ConfigDict of type ``'Pretrained'``,
        the method returns right after ``initialize`` and the children's
        ``init_weights`` are not called.
        """
        ...  # elided in this excerpt (presumably sets up ``self._is_init``)
        if not self._is_init:
            if self.init_cfg:
                initialize(self, self.init_cfg)
                if isinstance(self.init_cfg, (dict, ConfigDict)):
                    # A single 'Pretrained' config: stop here and skip the
                    # per-child initialization below.
                    if self.init_cfg['type'] == 'Pretrained':
                        return

            # Let every child that supports it initialize itself.
            for m in self.children():
                if hasattr(m, 'init_weights'):
                    m.init_weights()
            self._is_init = True
        ...  # elided in this excerpt
def initialize(module, init_cfg):
    """Apply one or more initialization configs to ``module``.

    Abridged excerpt (``...`` marks elided code). In the loop below
    ``init_cfg`` is iterated as a sequence of config dicts; presumably the
    elided prologue wraps a single dict into a list (TODO confirm against
    mmcv).

    Args:
        module (nn.Module): Module whose weights are initialized.
        init_cfg (dict | list[dict]): Initialization config(s).
    """
    ...  # elided in this excerpt
    for cfg in init_cfg:
        # Work on a copy so popping 'override' does not mutate the caller's
        # config object.
        cp_cfg = copy.deepcopy(cfg)
        # 'override' is removed before the initializer is built; presumably
        # it is consumed by the elided override handling.
        override = cp_cfg.pop('override', None)
        _initialize(module, cp_cfg)
    ...  # elided in this excerpt

def _initialize(module, cfg, wholemodule=False):
    """Build a single initializer from ``cfg`` and apply it to ``module``.

    Args:
        module (nn.Module): Target module.
        cfg (dict): Initializer config, built through the ``INITIALIZERS``
            registry with ``build_from_cfg``.
        wholemodule (bool): Stored on the initializer before it is called.
            It is meant for override mode: an override config has no
            ``layer`` key, so the initializer should act on the whole
            (named) module it receives.
    """
    initializer = build_from_cfg(cfg, INITIALIZERS)
    setattr(initializer, 'wholemodule', wholemodule)
    initializer(module)
# mmcv/cnn/utils/weight_init.py
# Registry that holds the weight-initializer classes. `_initialize` builds
# an initializer from a config through this registry; presumably a config
# such as dict(type='Constant', ...) selects the class registered under
# that name.
INITIALIZERS = Registry('initializer')
...  # elided in this excerpt


@INITIALIZERS.register_module(name='Constant')
class ConstantInit(BaseInit):
    ...  # body elided in this excerpt


@INITIALIZERS.register_module(name='Xavier')
class XavierInit(BaseInit):
    ...  # body elided in this excerpt


@INITIALIZERS.register_module(name='Normal')
class NormalInit(BaseInit):
    ...  # body elided in this excerpt


@INITIALIZERS.register_module(name='TruncNormal')
class TruncNormalInit(BaseInit):
    ...  # body elided in this excerpt
# Normalization config shared by the neck and the head (the "# add" lines
# below). It must be defined before `model`, otherwise the config raises
# NameError. Value taken from mmdetection's GN RepPoints config
# (reppoints_moment_r50_fpn_gn-neck+head_1x_coco.py) -- confirm if a
# different norm layer is intended.
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
    type='RepPointsDetector',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=1,
        add_extra_convs='on_input',
        num_outs=5,
        norm_cfg=norm_cfg,  # add
    ),
    bbox_head=dict(
        norm_cfg=norm_cfg,  # add
        type='RepPointsHead',
        num_classes=80,
        in_channels=256,
        feat_channels=256,
        point_feat_channels=256,
        stacked_convs=3,
        num_points=9,
        gradient_mul=0.1,
        point_strides=[8, 16, 32, 64, 128],
        point_base_scale=2,  # 4,
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox_init=dict(type='SmoothL1Loss', beta=0.11, loss_weight=0.5),
        loss_bbox_refine=dict(type='SmoothL1Loss', beta=0.11, loss_weight=1.0),
        transform_method='moment',
    ),
    # training and testing settings
    # `init` configures the first (point-initialization) stage and
    # `refine` the refinement stage, each with its own assigner.
    train_cfg=dict(
        init=dict(
            assigner=dict(type='PointAssigner', scale=2, pos_num=1),  # scale=4
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        refine=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.4,
                min_pos_iou=0,
                ignore_iof_thr=-1),
            allowed_border=-1,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        nms_pre=1000,
        min_bbox_size=0,
        score_thr=0.05,
        nms=dict(type='nms', iou_threshold=0.5),
        max_per_img=100))
optimizer = dict(lr=0.01)

Detector

RepPointsDetector

RepPoint的继承链如下所示:
由于RepPoint对于Detector的修改都在Head上,因此Detector本身的代码结构并没有差异,只是给SingleStageDetector包了一层皮。

RepPointsDetector
SingleStageDetector
BaseDetector
BaseModule
torch.nn.Module
@DETECTORS.register_module()
class RepPointsDetector(SingleStageDetector):
    """RepPoints: Point Set Representation for Object Detection.

    A thin wrapper that adds no behavior of its own. Every constructor
    argument is forwarded unchanged, and in the same order, to
    ``SingleStageDetector``. All RepPoints-specific logic lives in the head
    configured by ``bbox_head``.
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        parent_args = (backbone, neck, bbox_head, train_cfg, test_cfg,
                       pretrained, init_cfg)
        super().__init__(*parent_args)

我们先从基本的父类开始说起。BaseModule可参考前文,它是mmdetection中所有module的基类,就是在nn.Module的基础上实现了根据配置(init_cfg)对参数进行初始化的操作,这里不再赘述。

BaseDetector

BaseDetector的调用关系如下图所示:

  • train_step / val_step在Runner中被调用(参考此文 Runner章节),可以视为运行时程序的入口。
  • 子类一般需要实现extract_feat、simple_test、aug_test这三个抽象方法,并扩展forward_train,共四个函数。
  • 该类还实现了show_result的可视化功能;并定义了onnx_export接口,但并未实现(直接抛出NotImplementedError)。
BaseDetector
train_step
__call__/forward
_parse_losses
val_step
forward_train
forward_test
simple_test
aug_test
extract_feat
class BaseDetector(BaseModule, metaclass=ABCMeta):
    """Base class for detectors.

    Abridged excerpt (``...`` marks elided code). Subclasses implement the
    abstract ``extract_feat``, ``simple_test`` and ``aug_test`` and extend
    ``forward_train``. ``forward`` dispatches to ``forward_train`` or
    ``forward_test``. ``train_step`` calls ``forward`` and then
    ``_parse_losses``.
    """

    def __init__(self, init_cfg=None):
        super(BaseDetector, self).__init__(init_cfg)
        self.fp16_enabled = False
        ...  # elided in this excerpt

    @abstractmethod
    def extract_feat(self, imgs):
        """Extract features from images."""
        pass

    ...  # elided in this excerpt

    async def async_simple_test(self, img, img_metas, **kwargs):
        raise NotImplementedError

    async def aforward_test(self, *, img, img_metas, **kwargs):
        """Async test entry: single-augmentation input goes to
        ``async_simple_test``."""
        ...  # elided in this excerpt (presumably computes ``num_augs``)
        if num_augs == 1:
            return await self.async_simple_test(img[0], img_metas[0], **kwargs)
        ...  # elided in this excerpt

    @abstractmethod
    def simple_test(self, img, img_metas, **kwargs):
        pass

    @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        """Test function with test time augmentation."""
        pass

    def forward_test(self, imgs, img_metas, **kwargs):
        """Dispatch to ``simple_test`` or ``aug_test``.

        Args:
            imgs (list): Outer list indexed by test-time augmentation; its
                length is taken as the number of augmentations.
            img_metas (list): Image meta info, indexed the same way.
        """
        ...  # elided in this excerpt
        num_augs = len(imgs)
        ...  # elided in this excerpt
        if num_augs == 1:
            ...  # elided in this excerpt
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        else:
            ...  # elided in this excerpt
            return self.aug_test(imgs, img_metas, **kwargs)

    def forward_train(self, imgs, img_metas, **kwargs):
        """Record the batch input shape in every image meta dict.

        Subclasses call this via ``super()`` before computing their losses.
        The meta dicts are modified in place.
        """
        # The last two dims of the first image tensor are used as the
        # shape for the whole batch.
        batch_input_shape = tuple(imgs[0].size()[-2:])
        for img_meta in img_metas:
            img_meta['batch_input_shape'] = batch_input_shape

    @auto_fp16(apply_to=('img', ))
    def forward(self, img, img_metas, return_loss=True, **kwargs):
        """Call ``forward_train`` if ``return_loss`` else ``forward_test``."""
        ...  # elided in this excerpt
        if return_loss:
            return self.forward_train(img, img_metas, **kwargs)
        else:
            return self.forward_test(img, img_metas, **kwargs)

    def _parse_losses(self, losses):
        ...  # elided in this excerpt

    def train_step(self, data, optimizer):
        """Run one training iteration.

        ``optimizer`` is not used in this method body.

        Returns:
            dict: ``loss`` and ``log_vars`` as returned by ``_parse_losses``,
            plus ``num_samples`` = ``len(data['img_metas'])``.
        """
        losses = self(**data)
        loss, log_vars = self._parse_losses(losses)

        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))

        return outputs

    def val_step(self, data, optimizer):
        ...  # elided in this excerpt

    def show_result(self, img, result, **kwargs):
        # Signature abbreviated: the visualization options and body are
        # elided in this excerpt.
        ...

    def onnx_export(self, img, img_metas):
        raise NotImplementedError(f'{self.__class__.__name__} does '
                                  f'not support ONNX EXPORT')

SingleStageDetector

SingleStageDetector继承自BaseDetector,实现了上面说的需要实现的四个函数以及onnx_export。由此可见:

  • SingleStageDetector只是实现了backbone、neck、head之间的连接关系;并没有backbone、neck、head的具体实现,只是利用Registry工厂模式根据cfg创建(build)出它们;
  • 配置文件中的train_cfg和test_cfg会直接传递给head;
  • 对于forward_train、simple_test、aug_test,Detector中也没有进行核心实现,而是交给了head(bbox_head)来实现。
SingleStageDetector
forward_train
extract_feat
backbone
neck
simple_test
bbox_head.simple_test
aug_test
bbox_head.aug_test
bbox_head.forward_train
onnx_export
@DETECTORS.register_module()
class SingleStageDetector(BaseDetector):
    """Base class for single-stage detectors.

    Wires backbone -> optional neck -> bbox_head. Each component is built
    from its config dict by ``build_backbone`` / ``build_neck`` /
    ``build_head``. Loss computation and result decoding are delegated to
    ``bbox_head``. Abridged excerpt: ``...`` marks elided code.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super(SingleStageDetector, self).__init__(init_cfg)
        # The legacy `pretrained` argument is stored on the backbone config
        # before it is built (assumes an attribute-assignable config such
        # as ConfigDict).
        backbone.pretrained = pretrained
        self.backbone = build_backbone(backbone)
        if neck is not None:
            self.neck = build_neck(neck)
        # train_cfg / test_cfg are injected into the head config, so the
        # head receives them when it is built.
        bbox_head.update(train_cfg=train_cfg)
        bbox_head.update(test_cfg=test_cfg)
        self.bbox_head = build_head(bbox_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def extract_feat(self, img):
        """Directly extract features from the backbone+neck."""
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_dummy(self, img):
        """Return the raw ``bbox_head`` outputs for ``img``, without
        post-processing."""
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
        return outs

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None):
        """Compute training losses; the loss itself comes from
        ``bbox_head.forward_train``.

        Returns:
            Whatever ``bbox_head.forward_train`` returns (the losses).
        """
        # Base implementation stores the batch input shape in img_metas.
        super(SingleStageDetector, self).forward_train(img, img_metas)
        x = self.extract_feat(img)
        losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
                                              gt_labels, gt_bboxes_ignore)
        return losses

    def simple_test(self, img, img_metas, rescale=False):
        """Test without augmentation.

        Returns:
            list: One ``bbox2result(det_bboxes, det_labels, num_classes)``
            entry per element of ``bbox_head.simple_test``'s result.
        """
        feat = self.extract_feat(img)
        results_list = self.bbox_head.simple_test(
            feat, img_metas, rescale=rescale)
        bbox_results = [
            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
            for det_bboxes, det_labels in results_list
        ]
        return bbox_results

    def aug_test(self, imgs, img_metas, rescale=False):
        """Test with test-time augmentation.

        The augmented results are combined by ``bbox_head.aug_test``.
        """
        ...  # elided in this excerpt
        # NOTE(review): `extract_feats` is not defined in this excerpt;
        # presumably it is inherited from BaseDetector.
        feats = self.extract_feats(imgs)
        results_list = self.bbox_head.aug_test(
            feats, img_metas, rescale=rescale)
        bbox_results = [
            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
            for det_bboxes, det_labels in results_list
        ]
        return bbox_results

    def onnx_export(self, img, img_metas):
        """Test function without test time augmentation."""
        ...  # elided in this excerpt (presumably computes det_bboxes/det_labels)
        return det_bboxes, det_labels

你可能感兴趣的:(mmdetection源码阅读)