3DIoUMatch-PVRCNN 模型部分

论文地址:https://arxiv.org/abs/2012.04355v3
项目地址:https://github.com/THU17cyz/3DIoUMatch-PVRCNN

本篇论文的关键代码:

  • Dataset:KittiDatasetSSL
  • Model:PVRCNN_SSL_3DIOU(本篇文章介绍)

一、原始PVRCNN

  • PVRCNN的所有模块:
    MeanVFE, VoxelBackBone8xm, HeightCompressionm, VoxelSetAbstraction, BaseBEVBackbone, AnchorHeadSingle, PointHeadSimple, PVRCNNHead
  • 说明:
    后三个模块涉及损失计算,因此在半监督实验中,teacher和student操作略有不同。(teacher不会调用损失计算函数)
  • 说明2:
    首先记录一下PVRCNN网络。

1. 前向传播forward

AnchorHeadSingle【检测头网络】:
  • self.forward_ret_dict更新:用于损失计算。
    • Prediction:
      • cls_preds: [2, 200, 176, 18]
      • box_preds: [2, 200, 176, 42]
      • dir_cls_preds: [2, 200, 176, 12]
    • Target:
      • box_cls_labels: [2, 211200] // 211200=200*176*18/3
      • box_reg_targets: [2, 211200, 7]
      • reg_weights: [2, 211200]
  • data_dict更新:用于保存预测框,为获取Proposal做准备。
    • batch_cls_preds: [2, 211200, 3]
    • batch_box_preds: [2, 211200, 7]
PointHeadSimple【点云特征】:
  • self.forward_ret_dict更新:用于损失计算。
    • point_cls_preds: [4096, 1]
    • point_cls_labels: [4096]
  • data_dict更新:用于保存点云特征,为Refinement作准备。
    • point_cls_scores:[4096]
PVRCNNHead【检测头网络】:
  • self.forward_ret_dict更新:用于损失计算。
    • Prediction:
      • rcnn_cls, rcnn_reg
    • 其他信息:
      • rois, gt_of_rois, gt_iou_of_rois, roi_scores, roi_labels, reg_valid_mask, rcnn_cls_labels, gt_of_rois_src
  • data_dict更新:
    • rois, roi_scores, roi_labels: 根据batch_cls_preds, batch_box_preds和ground truth获得roi区域,用于计算最终框。(在nolabel数据中,需要剔除从gt获得roi的部分)
    • batch_cls_preds: [2, 128, 1]
    • batch_box_preds: [2, 128, 7]

2. 损失计算

三个模块的损失和

def get_training_loss(self):
    disp_dict = {}
    loss_rpn, tb_dict = self.dense_head.get_loss()
    loss_point, tb_dict = self.point_head.get_loss(tb_dict)
    loss_rcnn, tb_dict = self.roi_head.get_loss(tb_dict)

    loss = loss_rpn + loss_point + loss_rcnn
    return loss, tb_dict, disp_dict

二、半监督PVRCNN (Training Mode)

1. 模型构建和加载

构建

__init__中实现。包括两个结构完全一样的模型

  • self.pv_rcnn:参数通过后向传播优化。
  • self.pv_rcnn_ema:参数从计算图中分离,通过指数滑动平均优化,在update_global_step中更新。
加载

load_params_from_file中实现。两个模块pv_rcnn和pv_rcnn_ema加载相同参数。

2. 主要流程

  • 网络输入:data_dict包含Labeled/Unlabeled的强/弱数据增强结果。
  • Teacher前向传播:网络输入为Unlabeled的弱数据增强数据,得到预测结果batch_cls_preds: [2, 100, 1],batch_box_preds [2, 100, 7]
  • Teacher后处理:根据上面两个预测结果,得到预测框pred_boxes [22, 7], pred_scores [22], pred_labels [22]作为pseudo box。
  • pseudo box处理:filter筛选,数据增强对齐,作为Unlabeled数据的真值。即batch_dict['gt_boxes'][1]更新为pseudo box。
  • pseudo box处理2:论文的LHS模型,应该与NMS起到类似作用。未细看。
  • Student前向传播
  • 损失计算:半监督损失计算有scalar参数,默认返回标量损失值(全监督设置),该训练中返回各batch分别的损失值。目的是对label/unlabel损失设置权重。

附: Teacher前向传播forward

disable_gt_roi_when_pseudo_labeling开关控制Teacher中独特的计算。
Student前向传播与原始PVRCNN一致。

AnchorHeadSingle【检测头网络】:
  • self.forward_ret_dict更新:用于损失计算。
    • Prediction:
      • cls_preds: [2, 200, 176, 18]
      • box_preds: [2, 200, 176, 42]
      • dir_cls_preds: [2, 200, 176, 12]
    • Target: 因为无需计算损失,所以这部分未存储。
  • data_dict更新:用于保存预测框,为获取Proposal做准备。
    • batch_cls_preds: [2, 211200, 3]
    • batch_box_preds: [2, 211200, 7]
PointHeadSimple【点云特征】:
  • self.forward_ret_dict更新:用于损失计算。
    • point_cls_preds: [4096, 1]
    • point_cls_labels: 因为无需计算损失,所以这部分未存储。
  • data_dict更新:用于保存点云特征,为Refinement作准备。
    • point_cls_scores:[4096]
PVRCNNHead【检测头网络】:
  • self.forward_ret_dict更新:用于损失计算。
    • Prediction:
      • rcnn_cls, rcnn_reg
    • 其他信息:
      • rois, roi_scores, roi_labels 因为无需计算损失,其他信息未存储。
  • data_dict更新:
    • rois, roi_scores, roi_labels: 根据batch_cls_preds, batch_box_preds获得roi区域,用于计算最终框。(剔除从gt获得roi的部分)
    • batch_cls_preds: [2, 128, 1]
    • batch_box_preds: [2, 128, 7]

你可能感兴趣的:(3D目标检测,1024程序员节)