Customizing models and modifying submodules in MMDetection 2.17

MMDetection 2.17 divides model components into five types. This article covers each of the five in turn and shows how to customize it:

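  • backbone: usually an FCN (fully convolutional network) that extracts feature maps, e.g. ResNet, MobileNet.
  • neck: the component between backbones and heads, e.g. FPN, HRFPN, NASFCOSFPN.
  • head: the component for a specific task, e.g. bbox prediction and mask prediction.
  • roi extractor: the component that extracts RoI features from feature maps, e.g. RoI Pooling, RoI Align.
  • loss: the component in the head that computes losses, e.g. FocalLoss, L1Loss, and GHMLoss.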

Adding a new backbone, using MobileNet as an example:

  • Create a new file: mmdet/models/backbones/mobilenet.py.
import torch.nn as nn
from ..builder import BACKBONES
@BACKBONES.register_module()
class MobileNet(nn.Module):
    def __init__(self, arg1, arg2):
        pass
    def forward(self, x):  
        pass
  • Import the new module by adding the following line to mmdet/models/backbones/__init__.py:
from .mobilenet import MobileNet

Alternatively, add the following to the config file, which avoids modifying the original code:

custom_imports = dict(
    imports=['mmdet.models.backbones.mobilenet'],
    allow_failed_imports=False)
  • Use the new backbone in the config file:
model = dict(
    ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    ...
)
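
The registration step above relies on a registry that maps the string given in `type` to a class. The following is a minimal pure-Python sketch of that idea under simplifying assumptions; it is not mmcv's actual implementation, and `ToyRegistry`, `TOY_BACKBONES`, and `ToyBackbone` are hypothetical names introduced only for illustration.

```python
class ToyRegistry:
    """A toy stand-in for a module registry (hypothetical, not mmcv's Registry)."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self):
        # Used as @registry.register_module(); stores the class under its name.
        def decorator(cls):
            if cls.__name__ in self._modules:
                raise KeyError(
                    f'{cls.__name__} is already registered in {self.name}')
            self._modules[cls.__name__] = cls
            return cls
        return decorator

    def build(self, cfg):
        # Copy the config so the caller's dict is left untouched.
        args = dict(cfg)
        obj_type = args.pop('type')
        if obj_type not in self._modules:
            raise KeyError(f'{obj_type} is not registered in {self.name}')
        # The remaining keys become constructor arguments.
        return self._modules[obj_type](**args)


TOY_BACKBONES = ToyRegistry('backbone')


@TOY_BACKBONES.register_module()
class ToyBackbone:
    def __init__(self, arg1, arg2):
        self.arg1 = arg1
        self.arg2 = arg2


backbone = TOY_BACKBONES.build(dict(type='ToyBackbone', arg1=1, arg2=2))
print(type(backbone).__name__, backbone.arg1, backbone.arg2)
```

The decorator only runs when the file that defines the class is imported, which is why the new module has to be imported in `__init__.py` or listed in `custom_imports` before its `type` name can be resolved.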

Adding a new neck

  • Taking PAFPN as an example, first create the file pafpn.py in mmdet/models/necks:
import torch.nn as nn
from ..builder import NECKS
@NECKS.register_module()
class PAFPN(nn.Module):
    def __init__(self,
                in_channels,
                out_channels,
                num_outs,
                start_level=0,
                end_level=-1,
                add_extra_convs=False):
        pass
    def forward(self, inputs):
        pass
  • Import the module by adding the following line to mmdet/models/necks/__init__.py:
from .pafpn import PAFPN

Alternatively, add the following to the config file, which avoids modifying the original code:

custom_imports = dict(
    imports=['mmdet.models.necks.pafpn'],
    allow_failed_imports=False)
  • Modify the config file:
neck=dict(
    type='PAFPN',
    in_channels=[256, 512, 1024, 2048],
    out_channels=256,
    num_outs=5)
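
The `custom_imports` entries used in the sections above are dotted module names, not file names. A rough sketch of what importing such a list involves is shown below; `import_listed_modules` is a hypothetical helper written for this illustration (it only mirrors the `imports` and `allow_failed_imports` fields from the config), and standard-library modules stand in for project modules.

```python
import importlib
import warnings


def import_listed_modules(imports, allow_failed_imports=False):
    """Import each dotted module name (hypothetical helper for illustration)."""
    imported = []
    for name in imports:
        try:
            imported.append(importlib.import_module(name))
        except ImportError:
            if not allow_failed_imports:
                raise
            # Tolerate the failure and record a placeholder instead.
            warnings.warn(f'{name} failed to import and is ignored.')
            imported.append(None)
    return imported


# Standard-library modules stand in for project modules here.
modules = import_listed_modules(['json', 'collections.abc'])
print([m.__name__ for m in modules])

# A file name is not a module name: the trailing ".py" is treated as a
# submodule called "py", so the import fails.
try:
    import_listed_modules(['json.decoder.py'])
except ImportError as err:
    print('import failed:', err)
```

Importing the module is what triggers the `register_module()` decorator inside it, so a wrong path means the new class never reaches the registry.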

Adding a new head, using Double Head R-CNN as an example

  • Create a new file: mmdet/models/roi_heads/bbox_heads/double_bbox_head.py.
  • To implement the bbox head, the new module needs to implement the following functions:
from mmdet.models.builder import HEADS
from .bbox_head import BBoxHead
@HEADS.register_module()
class DoubleConvFCBBoxHead(BBoxHead):
    r"""Bbox head used in Double-Head R-CNN

                                      /-> cls
                  /-> shared convs ->
                                      \-> reg
    roi features
                                      /-> cls
                  \-> shared fc    ->
                                      \-> reg
    """  # noqa: W605
    def __init__(self,
                 num_convs=0,
                 num_fcs=0,
                 conv_out_channels=1024,
                 fc_out_channels=1024,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 **kwargs):
        kwargs.setdefault('with_avg_pool', True)
        super(DoubleConvFCBBoxHead, self).__init__(**kwargs)
    def forward(self, x_cls, x_reg):
        pass
  • Second, build a new RoI head if necessary. The new DoubleHeadRoIHead inherits from StandardRoIHead, which already implements the following functions (bodies omitted):
import torch

from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler
from ..builder import HEADS, build_head, build_roi_extractor
from .base_roi_head import BaseRoIHead
from .test_mixins import BBoxTestMixin, MaskTestMixin
@HEADS.register_module()
class StandardRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
    """Simplest base roi head including one bbox head and one mask head.
    """
    def init_assigner_sampler(self):
    def init_bbox_head(self, bbox_roi_extractor, bbox_head):
    def init_mask_head(self, mask_roi_extractor, mask_head):
    def forward_dummy(self, x, proposals):
    def forward_train(self,
                      x,
                      img_metas,
                      proposal_list,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None):
    def _bbox_forward(self, x, rois):
    def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
                            img_metas):
    def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks,
                            img_metas):
    def _mask_forward(self, x, rois=None, pos_inds=None, bbox_feats=None):
    def simple_test(self,
                    x,
                    proposal_list,
                    img_metas,
                    proposals=None,
                    rescale=False):
        """Test without augmentation."""
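  • The new RoI head goes in mmdet/models/roi_heads/double_roi_head.py. It overrides only _bbox_forward and inherits everything else from StandardRoIHead: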
from ..builder import HEADS
from .standard_roi_head import StandardRoIHead
@HEADS.register_module()
class DoubleHeadRoIHead(StandardRoIHead):
    """RoI head for Double Head RCNN
    https://arxiv.org/abs/1904.06493
    """
    def __init__(self, reg_roi_scale_factor, **kwargs):
        super(DoubleHeadRoIHead, self).__init__(**kwargs)
        self.reg_roi_scale_factor = reg_roi_scale_factor
    def _bbox_forward(self, x, rois):
        bbox_cls_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs], rois)
        bbox_reg_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs],
            rois,
            roi_scale_factor=self.reg_roi_scale_factor)
        if self.with_shared_head:
            bbox_cls_feats = self.shared_head(bbox_cls_feats)
            bbox_reg_feats = self.shared_head(bbox_reg_feats)
        cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats)
        bbox_results = dict(
            cls_score=cls_score,
            bbox_pred=bbox_pred,
            bbox_feats=bbox_cls_feats)
        return bbox_results
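  • Finally, add the new modules to mmdet/models/roi_heads/bbox_heads/__init__.py and mmdet/models/roi_heads/__init__.py so that the corresponding registry can find and load them.
    Alternatively, add the following to the config file:
custom_imports=dict(
    imports=['mmdet.models.roi_heads.double_roi_head', 'mmdet.models.roi_heads.bbox_heads.double_bbox_head'])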
  • The config file for Double Head R-CNN is as follows:
_base_ = '../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
model = dict(
    roi_head=dict(
        type='DoubleHeadRoIHead',
        reg_roi_scale_factor=1.3,
        bbox_head=dict(
            _delete_=True,
            type='DoubleConvFCBBoxHead',
            num_convs=4,
            num_fcs=2,
            in_channels=256,
            conv_out_channels=1024,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=2.0),
            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=2.0))))
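
The `_delete_=True` key in the config above tells the config system to replace the inherited `bbox_head` dict instead of merging into it. The sketch below illustrates that merge rule in plain Python under simplifying assumptions; `merge_cfg` is a hypothetical function, not the config system's real code, and the base values (`BaseBBoxHead`, `base_only_arg`) are made-up placeholders rather than the contents of the actual base file.

```python
def merge_cfg(child, base):
    """Recursively merge `child` into a shallow copy of `base` (simplified sketch)."""
    merged = dict(base)
    for key, value in child.items():
        if isinstance(value, dict):
            value = dict(value)
            delete_base = value.pop('_delete_', False)
            if not delete_base and isinstance(merged.get(key), dict):
                # Default behaviour: keep inherited keys, override matching ones.
                merged[key] = merge_cfg(value, merged[key])
            else:
                # _delete_=True (or nothing to merge into): drop inherited keys.
                merged[key] = merge_cfg(value, {})
        else:
            merged[key] = value
    return merged


# Placeholder base config; the names below are illustrative only.
base = dict(
    roi_head=dict(
        type='StandardRoIHead',
        bbox_head=dict(type='BaseBBoxHead', base_only_arg=1024)))

child = dict(
    roi_head=dict(
        type='DoubleHeadRoIHead',
        reg_roi_scale_factor=1.3,
        bbox_head=dict(
            _delete_=True, type='DoubleConvFCBBoxHead', num_convs=4)))

merged = merge_cfg(child, base)
print(merged['roi_head']['bbox_head'])

# Without _delete_, the inherited key would survive next to the new ones.
child['roi_head']['bbox_head'].pop('_delete_')
print(merge_cfg(child, base)['roi_head']['bbox_head'])
```

Replacing the dict matters here because the new head takes a different set of constructor arguments; any leftover key from the inherited head would be passed to `DoubleConvFCBBoxHead` as an argument it may not accept.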

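Adding a new loss: suppose you want to add a new loss, MyLoss, for bounding box regression. To add a new loss function, implement it in mmdet/models/losses/my_loss.py. The weighted_loss decorator allows the loss to be weighted for each element.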

import torch
import torch.nn as nn
from ..builder import LOSSES
from .utils import weighted_loss
@weighted_loss
def my_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss
@LOSSES.register_module()
class MyLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight
    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox
  • Then add the following line to mmdet/models/losses/__init__.py:
from .my_loss import MyLoss, my_loss
  • Modify the config file by setting the loss_bbox field in the head:
loss_bbox=dict(type='MyLoss', loss_weight=1.0)
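
To make the role of the weighted_loss decorator more concrete, here is a simplified sketch of the pattern on plain Python lists: the wrapped function computes one loss value per element, and the wrapper applies the optional per-element weight and then the reduction. `toy_weighted_loss` and `toy_l1_loss` are hypothetical names, and the real decorator works on tensors, so treat this as an approximation of the behaviour rather than the library's code.

```python
import functools


def toy_weighted_loss(loss_func):
    """Wrap an element-wise loss with weighting and reduction (simplified sketch)."""

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', avg_factor=None):
        loss = loss_func(pred, target)  # one value per element
        if weight is not None:
            loss = [item * w for item, w in zip(loss, weight)]
        if reduction == 'none':
            return loss
        if reduction == 'sum':
            if avg_factor is not None:
                raise ValueError('avg_factor cannot be used with reduction="sum"')
            return sum(loss)
        if reduction == 'mean':
            # avg_factor, when given, replaces the element count as the divisor.
            denom = len(loss) if avg_factor is None else avg_factor
            return sum(loss) / denom
        raise ValueError(f'unknown reduction: {reduction}')

    return wrapper


@toy_weighted_loss
def toy_l1_loss(pred, target):
    return [abs(p - t) for p, t in zip(pred, target)]


pred = [1.0, 2.0, 4.0]
target = [1.0, 0.0, 1.0]
print(toy_l1_loss(pred, target, reduction='none'))
print(toy_l1_loss(pred, target, weight=[1.0, 0.5, 0.0], reduction='sum'))
print(toy_l1_loss(pred, target, reduction='mean', avg_factor=2))
```

With this split, the function body only has to describe the element-wise loss, as `my_loss` does with `torch.abs(pred - target)`, while weighting and reduction are handled in one shared place.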
