小白科研笔记:深入理解mmDetection框架——网络架构

1. 前言

这篇博客图解mmDetection搭建检测器网络架构的代码流程。在讲解的过程中,还是会以3D目标检测框架SA-SSD做例子讨论。

2. 网络架构

2.1 理论上的说明

这段说明以单阶段3D目标检测框架SA-SSD为例。它的输入是 [ W 0 , L 0 , H 0 , C 0 ] [W_0,L_0,H_0,C_0] [W0,L0,H0,C0]的稀疏三维特征张量。SA-SSDbackboneneck,和detector三部分组成。它前向计算的图解如下所示:

小白科研笔记:深入理解mmDetection框架——网络架构_第1张图片
图1:以SA-SSD为例的单阶段和多阶段检测图解

SA-SSD是单阶段目标检测算法。如果这个算法扩展为多阶段,可以参考PointRCNN网络,添加一个新的环节,把3d检测框内的点云特征“收集”起来(这个收集过程记为ROI Pooling),用来回归一个更加精细的3d检测框。

2.2 搭建网络

mmDetection框架搭建网络的过程很好理解,主要是根据配置参数cfg调用函数build_detector

    model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

其中build_detector是根据配置参数cfg迭代拼接网络模块的。迭代依据是for cfg_ in cfg

def build_detector(cfg, train_cfg=None, test_cfg=None):
    from . import detectors
    return build(cfg, detectors, dict(train_cfg=train_cfg, test_cfg=test_cfg))

def build(cfg, parrent=None, default_args=None):
    if isinstance(cfg, list):
        modules = [_build_module(cfg_, parrent, default_args) for cfg_ in cfg]
        return nn.Sequential(*modules)
    else:
        return _build_module(cfg, parrent, default_args)

其中cfg模型部分的配置文件是一个字典型变量:

model = dict(
    type='SingleStageDetector',
    backbone=dict(
        type='SimpleVoxel',
        num_input_features=4,
        use_norm=True,
        num_filters=[32, 64],
        with_distance=False
    ),
    neck=dict(
        type='SpMiddleFHD',
        output_shape=[40, 1600, 1408],
        num_input_features=4,
        num_hidden_features=64 * 5,
    ),
    bbox_head=dict(
        type='SSDRotateHead',
        num_class=1,
        num_output_filters=256,
        num_anchor_per_loc=2,
        use_sigmoid_cls=True,
        encode_rad_error_by_sin=True,
        use_direction_classifier=True,
        box_code_size=7,
    ),
    extra_head=dict(
        type='PSWarpHead',
        grid_offsets = (0., 40.),
        featmap_stride=.4,
        in_channels=256,
        num_class=1,
        num_parts=28,
    )
)

3. 结束语

mmDetection框架搭建网络的过程还是挺容易的。

你可能感兴趣的:(computer,vision论文代码分析)