MMDetection CenterNet 源码解析

文章目录

    • 0. 前言
    • 1. 模型构建
    • 2. `BaseDetector`
    • 3. `SingleStageDetector`
    • 4. `CenterNetHead`

0. 前言

  • CenterNet是我很喜欢的一篇论文,直观、好懂。然而,官方的 CenterNet 源码质量真的一般,看过的人应该都有这种感觉。

  • 好消息是,MMDetection 中复现了 CenterNet,可以参考这里

  • 此外,我想要复现时空行为检测中的 MOC-Detector,这篇文章也是基于 CenterNet 的,所以要捋一捋 CenterNet 源码。

  • MMDetection 的源码工程化非常好,但结构过于复杂,新手非常困难,老手如果长期不用估计也要忘。

  • 本文只关注总体结构,不关注一些具体细节(比如gaussian heatmap gt如何实现等)

1. 模型构建

  • MMDetection 中模型的构建主要包括:

    • 模型总体结构,管理模型的各个组件以及定义模型训练与测试是的前向流程。
      • CenterNet 中就是 CenterNet
      • 管理的组件包括 backbone/neck/bbox_head 三个部分
    • 模型细节,包括 backbone/neck/bbox_head 的前向细节
    • 模型训练相关,包括损失函数与GT的构建。
  • 模型总体结构是 CenterNet 类,其继承结构是

    • CenterNet -> SingleStageDetector -> BaseDetector -> BaseModule -> torch.nn.Module
    • BaseModule 源码位于 mmcv 中,与 torch.nn.Module 的区别在于,该类实现了 init_weights 功能。
    • CenterNet 主要就是重写了TTA测试,目前不关心,所以不写了。源码
    • 剩下几个类后面单独介绍
  • 除了总体结构外,就是CenterNet的几个基本组件

    • backbone,neck,head,分别是 ResNet/CTResNetNeck/CenterNetHead
    • ResNet 就是普通的残差网络,没啥好说的
    • CTResNetNeck 就是在ResNet后添加了若干 deconv 层,在Deconv前加了DCNv2。

2. BaseDetector

  • 源码在mmdet.models.detectors.base.py中,主要作用就是定义一些所有检测器都会用到的功能。

  • 抽象类,其他所有检测模型都会继承该对象。换句话说,检测的主要功能这个函数就全部定义完了,子类要做的就是实现里面的方法。

  • 功能主要可以划分为:

    • 判断组件是否存在,即 with_xxx 函数,其中,xxx 可以的取值有 neck/shared_head/bbox/mask
    • 定义模型前向的基本流程,根据 train/test 分别定义,后面会详细介绍这部分。
    • 定义模型训练、验证时的流程,如获取模型结果、计算损失函数,即 train_step/val_step,后面详细介绍这部分。
    • 展示模型结果,即在 img 上画 bbox 和 labels,show_result 函数
    • 模型导出为 ONNX 格式,即 onnx_export
  • 模型前向推理:

    • 入口函数是 forward

      • 为什么是入口?这个是 nn.Module 中定义的,__call__ 函数会调用 forward 函数。
      • 该函数根据模式分别调用 forward_trainforward_test 函数。
    • 训练时函数前向流程入口函数 forward_train

      • 这个函数一般会在子类中重写。
      • 该函数的结果一般就是各种损失函数
    • 验证/测试时前向流程入口 forward_test

      • 有TTA则调用 aug_test

      • 没有TTA则调用 simple_test

      • 上面两个函数都是抽象函数,子类继承实现

      • 该函数的结果一般是模型预测结果(经过后处理),而不包括损失函数

    • 定义了特征提取抽象函数 extract_featextract_feats,一般会在 forward_train/forward_test 的具体实现中引用特征提取这两个函数。

  • 模型训练、验证时的流程

    • 主要就是 train_stepval_step
    • 这一部分与前面 模型前向推理 的区别在于,在调用了模型前向推理函数(即 model(**data) )后,还会对模型结果进行一些封装。换句话说,就是对 forward_train 的结果进行一些封装。
    • 所谓封装,一般也就是调用 _parse_losses 函数,就是解析各种loss,封装成 dict 并累加求和
    • 每次训练、验证的时候就需要调用者两个函数,主要问题就在于,什么时候调用。
    • openmmlab 中,训练和测试都会使用 Runner 实现,而在Runner中就对调用者两个函数,如源码所示

3. SingleStageDetector

  • 源码在 mmdet.models.detectors.single_stage

  • 定义了所有单阶段目标检测器的基本功能与流程。

  • 细节上看,就是重写了 simple_test/aug_test/extract_feat/forward_train/forward_dummy/onnx_export 几个函数。

  • 特征提取流程:就是 backbone + neck,没有啥好说的。

  • 单阶段目标检测训练时流程:

    • 特征提取+head前向
    • 计算损失函数都是在 head 中定义的
  • 无TTA测试流程

    • 特征提取+head前向
    • head前向中就是获取bbox,没有其他损失函数相关
  • 有TTA 测试流程

    • 特征提取+head前向
    • head前向中实现tta
  • 从上面可以看到,MMDetection 中的 head 实现了很多功能,包括

    • 普通前向,获取预测结果
    • 训练时,GT 构建,与预测结构匹配,并计算损失函数
    • 测试时,对预测结果后处理(如NMS),获取最终结果
    • 处理 TTA 的细节
    • 从源码上看,head 需要有 forward_train 获取损失函数,get_bboxes 实现后处理获取检测框,aug_test 实现TTA
  • 其实 SingleStageDetector 类已经比较完善了,只要导入各种backbone、neck、head 就能实现单目标检测功能了。

4. CenterNetHead

  • CenterNet 实现的关键,主要功能包括:
    • 普通前向,获取模型预测结果。
    • 前向+后处理,获取过滤后的模型结果。
    • 前向+损失函数/获取GT等。
    • 前向+TTA
  • CenterNetHead 继承了 BaseDenseHeadBBoxTestMixin
    • BaseDenseHead 的主要功能就是定义了一个 head 应该做哪些工作
      • losses:计算损失函数
      • get_bboxes:根据模型结果获取 bboxes,包括后处理以及模型结果解析
      • forward_train:调用前面的 loss 函数,管理计算损失函数的过程
      • simple_test:定义基本前向过程
    • BBoxTestMixin 是个 Mixin 函数,不太懂,感觉就是包含一堆可复用的作为对象让别人来集成?
      • 主要就是 TTA 相关以及rpn相关
  • CenterNetHead 主要就是实现了 loss/forward/get_bboxes 两个函数
    • loss 的功能主要包括:
      • 根据 gt_bboxes/gt_labels 获取与预测结果一一对应的 GT。比如 heatmap 对应的就是符合高斯分布的圆。
      • 分别计算几个分支的损失函数。
    • forward 的主要功能包括:
      • 根据几个 head,以特征提取结果作为输入,获取模型最终预测结果
    • get_bboxes 的主要功能就是将模型预测结果转换为 bboxes
      • 大概就是解析 heatmap、进行NMS 等。

你可能感兴趣的:(CV,mmdetection,目标检测,CenterNet)