模型以如下配置文件为例:
configs/centerpoint/centerpoint_02pillar_second_secfpn_4x8_cyclic_20e_nus.py
MMDetection3d官方模型:
CenterPoint (继承自MVXTwoStageDetector)
该博客主要分析关键代码:
CenterHead
# mmdet3d/models/detectors/centerpoint.py
class CenterPoint(MVXTwoStageDetector):
"""Base class of Multi-modality VoxelNet."""
...
def forward_pts_train(self,
pts_feats,
gt_bboxes_3d,
gt_labels_3d,
img_metas,
gt_bboxes_ignore=None):
outs = self.pts_bbox_head(pts_feats)
loss_inputs = [gt_bboxes_3d, gt_labels_3d, outs]
losses = self.pts_bbox_head.loss(*loss_inputs)
return losses
此处的self.pts_bbox_head
,在配置文件中设置为CenterHead
。
因此,主要分析CenterHead
中的forward
函数和loss
函数。
CenterHead
的forward
函数【待补充】
CenterHead
的loss
函数# mmdet3d/models/dense_heads/centerpoint_head.py
def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
参数 | gt_bboxes_3d | gt_labels_3d | preds_dicts |
---|---|---|---|
说明 | 保存真值:框的参数 | 保存真值:框的类别 | forward 函数的输出 |
类型 | list[:obj:LiDARInstance3DBoxes ] |
list[torch.Tensor] | dict |
备注 | 列表长度表示 batch_size | 列表长度表示 batch_size | 包含6个元素,分别是6个task的预测结果 |
举例 | 元素举例:tensor([0, 9, 0 …], device=‘cuda:0’) | 将在后续详细说明 |
根据gt_bboxes_3d
和gt_labels_3d
,生成各task的热图、框尺寸等信息。
loss
函数中的实现:# mmdet3d/models/dense_heads/centerpoint_head.py
# loss function
heatmaps, anno_boxes, inds, masks = self.get_targets(gt_bboxes_3d, gt_labels_3d)
# mmdet3d/models/dense_heads/centerpoint_head.py
def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):
参数 | gt_bboxes_3d | gt_labels_3d |
---|---|---|
说明 | 保存真值:框的参数 | 保存真值:框的类别 |
类型 | obj:LiDARInstance3DBoxes |
torch.Tensor |
取值举例 | 经处理后,可得到实际框的参数 | tensor([0, 9, 0 …], device=‘cuda:0’) |
尺寸举例 | 经处理后,torch.Size([76, 9]) | torch.Size([76]) |
根据配置文件中的大类,一共有6个task
gt_bboxes_3d
/gt_labels_3d
中的坐标ID:heatmap
、anno_box
、ind
、mask
参数 | heatmap | anno_box | ind | mask |
---|---|---|---|---|
说明 | 中心点热图 | 框的参数 | 框的中心点在heatmap 中的位置 |
前obj_num个元素置1,obj_num表示框的个数 |
尺寸 | [class_num, 128, 128] | [500, 10] | [500] | [500] |
取值举例 | 每个class有一张热图 | 10维参数的含义,见下 | ind[idx] = x*128 + y | mask[idx] = 1 |
遍历该Task内的所有目标,更新上述四个变量。
heatmap
的更新draw_gaussian(heatmap[cls_id], center_int, radius)
参数说明
anno_box
的更新anno_box[new_idx] = torch.cat([
center - torch.tensor([x, y], device=device),
z.unsqueeze(0), box_dim,
torch.sin(rot).unsqueeze(0),
torch.cos(rot).unsqueeze(0),
vx.unsqueeze(0),
vy.unsqueeze(0)
])
offset_x
offset_y
z
box_dim
sin(α)
cos(α)
vx
vy
heatmaps
, anno_boxes
, inds
, masks
均是长度为6的数组,保存6个task的内容。
至此,“Step1: 对真值进行处理“ 已经完成。
根据preds_dicts
和Step1得到的heatmaps
, anno_boxes
, inds
, masks
,分别计算每一个task的loss_heatmap
和loss_bbox
。
preds_dicts
说明:
preds_dicts
包含6个元素,分别是6个task的预测结果。下表以preds_dicts[0]
举例:
preds_dicts[0] KEY | dim | heatmap | height | reg | rot | vel |
---|---|---|---|---|---|---|
尺寸 | [batch_size, 3, 128, 128] | [batch_size, 1, 128, 128] | [batch_size, 1, 128, 128] | [batch_size, 2, 128, 128] | [batch_size, 2, 128, 128] | [batch_size, 2, 128, 128] |
说明 | 表示目标框的长宽高 | 热图 | 表示中心点的高度 | 表示中心点的偏移量 | 表示旋转角度 | 表示速度 |
对应Step1得到的真值 | anno_box 第4-6维box_dim |
heatmap |
anno_box 第3维z |
anno_box 第1-2维offset_x offset_y |
anno_box 第7-8维sin(α) cos(α) |
anno_box 第9-10维vx vy |
loss_heatmap
:GaussianFocalLossloss_heatmap = self.loss_cls(
preds_dict[0]['heatmap'], # 预测得到的热图 [BS, cls_num, 128, 128]
heatmaps[task_id], # 实际的热图 [BS, cls_num, 128, 128]
avg_factor=max(num_pos, 1)) # num_pos表示实际目标的数量
此处self.loss_cls是GaussianFocalLoss,该损失函数的实现见:
mmdetection/mmdet/models/losses/gaussian_focal_loss.py
loss_bbox
:loss_bbox = self.loss_bbox(
pred, # 预测 torch.Size([BS, 500, 10])
target_box, # 真值 torch.Size([BS, 500, 10])
bbox_weights, # torch.Size([BS, 500, 10]) 第2维表示mask 第3维表示权重
avg_factor=(num + 1e-4))
此处self.loss_bbox是L1Loss,该损失函数的实现见:
mmdetection/mmdet/models/losses/smooth_l1_loss.py
最终得到的loss_dict
举例:
task0.loss_heatmap: 1.0833, task0.loss_bbox: 0.5410,
task1.loss_heatmap: 1.2952, task1.loss_bbox: 0.5907,
task2.loss_heatmap: 1.1385, task2.loss_bbox: 0.5840,
task3.loss_heatmap: 1.0866, task3.loss_bbox: 0.4793,
task4.loss_heatmap: 1.1322, task4.loss_bbox: 0.5697,
task5.loss_heatmap: 1.2827, task5.loss_bbox: 0.6070