CenterPoint 在mmdetection3d中的实现

CenterPoint 在mmdetection3d中的实现

模型以如下配置文件为例:
configs/centerpoint/centerpoint_02pillar_second_secfpn_4x8_cyclic_20e_nus.py
MMDetection3d官方模型:
CenterPoint (继承自MVXTwoStageDetector)
该博客主要分析关键代码:
CenterHead

写在前面

# mmdet3d/models/detectors/centerpoint.py
class CenterPoint(MVXTwoStageDetector):
    """Base class of Multi-modality VoxelNet."""
    ...
    def forward_pts_train(self,
                          pts_feats,
                          gt_bboxes_3d,
                          gt_labels_3d,
                          img_metas,
                          gt_bboxes_ignore=None):
       
        outs = self.pts_bbox_head(pts_feats)
        loss_inputs = [gt_bboxes_3d, gt_labels_3d, outs]
        losses = self.pts_bbox_head.loss(*loss_inputs)
        return losses

此处的self.pts_bbox_head,在配置文件中设置为CenterHead
因此,主要分析CenterHead中的forward函数和loss函数。

一、CenterHeadforward函数

【待补充】

二、CenterHeadloss函数

Step0: 参数说明

# mmdet3d/models/dense_heads/centerpoint_head.py
 def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
参数 gt_bboxes_3d gt_labels_3d preds_dicts
说明 保存真值:框的参数 保存真值:框的类别 forward函数的输出
类型 list[:obj:LiDARInstance3DBoxes] list[torch.Tensor] dict
备注 列表长度表示 batch_size 列表长度表示 batch_size 包含6个元素,分别是6个task的预测结果
举例 元素举例:tensor([0, 9, 0 …], device=‘cuda:0’) 将在后续详细说明

Step1: 对真值进行处理

1. 简述:

根据gt_bboxes_3dgt_labels_3d,生成各task的热图、框尺寸等信息。

2. loss函数中的实现:

# mmdet3d/models/dense_heads/centerpoint_head.py
# loss function
heatmaps, anno_boxes, inds, masks = self.get_targets(gt_bboxes_3d, gt_labels_3d)

3. 关键函数:

# mmdet3d/models/dense_heads/centerpoint_head.py
 def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):
3-1 参数说明
参数 gt_bboxes_3d gt_labels_3d
说明 保存真值:框的参数 保存真值:框的类别
类型 obj:LiDARInstance3DBoxes torch.Tensor
取值举例 经处理后,可得到实际框的参数 tensor([0, 9, 0 …], device=‘cuda:0’)
尺寸举例 经处理后,torch.Size([76, 9]) torch.Size([76])
3-2 中间变量说明

根据配置文件中的大类,一共有6个task

  • task_masks
    按照类别划分,记录各类别目标在gt_bboxes_3d/gt_labels_3d中的坐标ID:
    CenterPoint 在mmdetection3d中的实现_第1张图片 CenterPoint 在mmdetection3d中的实现_第2张图片
  • task_boxes
    记录各task中的真实框参数。
  • task_classes
    重新排序各task中的真实框类别。(0是背景)
    CenterPoint 在mmdetection3d中的实现_第3张图片 CenterPoint 在mmdetection3d中的实现_第4张图片
3-3 针对每个Task,生成heatmapanno_boxindmask
参数 heatmap anno_box ind mask
说明 中心点热图 框的参数 框的中心点在heatmap中的位置 前obj_num个元素置1,obj_num表示框的个数
尺寸 [class_num, 128, 128] [500, 10] [500] [500]
取值举例 每个class有一张热图 10维参数的含义,见下 ind[idx] = x*128 + y mask[idx] = 1

遍历该Task内的所有目标,更新上述四个变量。

3-3-1 heatmap的更新
draw_gaussian(heatmap[cls_id], center_int, radius)

参数说明

  • cls_id 决定在哪一张热图上更新
  • center_int 记录中心点在热图上的位置(x, y)
  • radius 决定高斯核大小

结果举例
CenterPoint 在mmdetection3d中的实现_第5张图片

3-3-2 anno_box的更新
anno_box[new_idx] = torch.cat([
                        center - torch.tensor([x, y], device=device),
                        z.unsqueeze(0), box_dim,
                        torch.sin(rot).unsqueeze(0),
                        torch.cos(rot).unsqueeze(0),
                        vx.unsqueeze(0),
                        vy.unsqueeze(0)
                    ])
  1. 第1-2维表示中心点的偏移量offset_x offset_y
    热图上的坐标(x, y)是离散整型,实际的中心点有精确到小数的偏移。
  2. 第3维表示中心点的高度z
  3. 第4-6维表示目标框的长宽高box_dim
  4. 第7-8维表示旋转角度sin(α) cos(α)
  5. 第9-10维表示速度vx vy
    nuScenes数据集有速度数据,如需使用KITTI数据集,需要更改部分代码。
3-4 返回值说明

heatmaps, anno_boxes, inds, masks 均是长度为6的数组,保存6个task的内容。

至此,“Step1: 对真值进行处理“ 已经完成。

Step2: 损失值计算

1. 简述:

根据preds_dicts和Step1得到的heatmaps, anno_boxes, inds, masks ,分别计算每一个task的loss_heatmaploss_bbox

2. preds_dicts说明:

CenterPoint 在mmdetection3d中的实现_第6张图片
preds_dicts包含6个元素,分别是6个task的预测结果。下表以preds_dicts[0]举例:

preds_dicts[0] KEY dim heatmap height reg rot vel
尺寸 [batch_size, 3, 128, 128] [batch_size, 1, 128, 128] [batch_size, 1, 128, 128] [batch_size, 2, 128, 128] [batch_size, 2, 128, 128] [batch_size, 2, 128, 128]
说明 表示目标框的长宽高 热图 表示中心点的高度 表示中心点的偏移量 表示旋转角度 表示速度
对应Step1得到的真值 anno_box第4-6维box_dim heatmap anno_box第3维z anno_box第1-2维offset_x offset_y anno_box第7-8维sin(α) cos(α) anno_box第9-10维vx vy

3. 计算loss_heatmap:GaussianFocalLoss

loss_heatmap = self.loss_cls(
                preds_dict[0]['heatmap'],	# 预测得到的热图 [BS, cls_num, 128, 128]
                heatmaps[task_id],			# 实际的热图 [BS, cls_num, 128, 128]
                avg_factor=max(num_pos, 1))	# num_pos表示实际目标的数量

此处self.loss_cls是GaussianFocalLoss,该损失函数的实现见:
mmdetection/mmdet/models/losses/gaussian_focal_loss.py

4. 计算loss_bbox

loss_bbox = self.loss_bbox(
                pred,					# 预测 torch.Size([BS, 500, 10])
                target_box,				# 真值 torch.Size([BS, 500, 10])
                bbox_weights,			# torch.Size([BS, 500, 10]) 第2维表示mask 第3维表示权重
                avg_factor=(num + 1e-4))

此处self.loss_bbox是L1Loss,该损失函数的实现见:
mmdetection/mmdet/models/losses/smooth_l1_loss.py

Step3: 返回值说明【总结】

最终得到的loss_dict举例:

task0.loss_heatmap: 1.0833, task0.loss_bbox: 0.5410, 
task1.loss_heatmap: 1.2952, task1.loss_bbox: 0.5907, 
task2.loss_heatmap: 1.1385, task2.loss_bbox: 0.5840, 
task3.loss_heatmap: 1.0866, task3.loss_bbox: 0.4793, 
task4.loss_heatmap: 1.1322, task4.loss_bbox: 0.5697, 
task5.loss_heatmap: 1.2827, task5.loss_bbox: 0.6070

你可能感兴趣的:(3D目标检测)