Unlike the bottom-up approach of LSS and BEVDepth, which first estimates depth and designs a 2D-to-3D lifting module, DETR3D follows a top-down, 3D-to-2D idea. It first presets a series of object queries for the predicted boxes and uses them to generate 3D reference points. These 3D reference points are projected back into 2D image coordinates using the camera transformation matrices, and the image features at those locations are gathered; cross-attention between the image features and the object queries then repeatedly refines the queries. Finally, two MLP branches output the classification and regression predictions. Positive and negative samples are assigned with the same bipartite matching as DETR, i.e. according to the minimum matching cost, the N predictions that best match the GT boxes are selected from the 900 object queries. Since both the sample assignment and the query-based detection paradigm follow DETR, DETR3D can be viewed as an extension of DETR to 3D. The pipeline breaks down into the following steps (a minimal skeleton sketch follows the list).
①. Use ResNet-101 + FPN to extract features from the 6 surround-view images, producing outputs at 4 different scales: 1/4, 1/8, 1/16, 1/32 (note that the 6 images are fed through the network by merging the Batch dimension with N, the number of cameras).
②. Preset 900 object queries (analogous to the prior boxes used in 2D detection), split each object query into query and query_pos, and map query_pos from [900, 256] to [900, 3] with a fully connected layer, which yields the 3D reference points (x, y, z) in BEV space.
③. Enter the transformer decoder, which stacks 6 decoder layers. In each layer, set q = k = v = query, i.e. all object queries first run self-attention against each other to exchange global information and to keep multiple queries from converging on the same object.
④. Left-multiply the predicted 3D reference points by the transformation matrices and divide by the depth Zc to project them into 2D image coordinates, obtaining 2D reference points.
⑤. A projected 3D reference point may have no valid image location or may be invisible in a given camera, so a mask records whether each 3D reference point falls inside the current camera view.
⑥. Iterate over the four FPN feature levels and sample each of them with grid_sample (bilinear interpolation) at the 2D reference point locations, obtaining the image features that correspond to each 2D reference point.
⑦. Use the query to produce attention weights and run cross-attention with the sampled image features.
⑧. Use the sampled features to refine the 3D reference points; the refinement is deliberately simple: just add the predicted offsets.
⑨. Output the regression branch and the classification branch with fully connected layers.
⑩. Run the Hungarian algorithm for bipartite matching to obtain positive/negative samples, then compute the classification loss (focal loss) and the regression loss (L1 loss).
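To make the steps concrete, below is a minimal, shape-annotated skeleton of steps ② to ⑨ in plain PyTorch. It is purely illustrative, not the actual mmdetection3d modules; the number of classes (10) and the 10-dimensional box encoding are assumptions, and the projection/sampling/refinement steps are only indicated by comments pointing to the code later in this article.

import torch
import torch.nn as nn

num_query, embed_dims, num_layers, num_classes = 900, 256, 6, 10

query_embedding  = nn.Embedding(num_query, embed_dims * 2)   # step ②: object queries
reference_points = nn.Linear(embed_dims, 3)                  # step ②: query_pos -> (x, y, z)
self_attn        = nn.MultiheadAttention(embed_dims, num_heads=8)
cls_branch       = nn.Linear(embed_dims, num_classes)        # step ⑨ (10 classes assumed)
reg_branch       = nn.Linear(embed_dims, 10)                 # step ⑨ (10 box params assumed)

# step ②: split the 512-dim embedding into positional / content parts
query_pos, query = torch.split(query_embedding.weight, embed_dims, dim=1)  # [900, 256] each
ref_points = reference_points(query_pos).sigmoid()                          # [900, 3] in (0, 1)

query = query.unsqueeze(1)       # [900, 1, 256]  (sequence, batch, dim)
query_pos = query_pos.unsqueeze(1)

for _ in range(num_layers):                                  # step ③: 6 decoder layers
    q = query + query_pos
    query = query + self_attn(q, q, query)[0]                # step ③: self-attention among queries
    # steps ④-⑦: project ref_points into every camera, grid_sample image features
    #             and cross-attend (see Detr3DCrossAtten / feature_sampling below)
    # step ⑧:    refine ref_points with the regression output (see the sketch below)

cls_scores = cls_branch(query.squeeze(1))                    # step ⑨: [900, num_classes]
bbox_preds = reg_branch(query.squeeze(1))                    # step ⑨: [900, 10] box parameters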
Pros:
①. Only the features corresponding to the object queries are sampled, and the full BEV is never explicitly and densely represented, which saves memory and computation and makes inference faster.
Cons:
①. The 3D-to-2D projection uses only the object-center 3D reference point to look up features in the FPN maps, so when the receptive field is insufficient the gathered features may be incomplete; in practice, long objects such as buses tend to get predicted boxes that are too small.
②. Reference points that fall in the same BEV grid cell sample identical image features after projection to 2D and carry no depth information, so whether a reference point and its sampled image feature actually correspond has to be learned implicitly over repeated iterations.
Use ResNet-101 + FPN to extract features from the 6 surround-view images, producing outputs at 4 different scales (1/4, 1/8, 1/16, 1/32):
Note that the 6 images are fed through the network by merging the Batch dimension with N (the number of cameras).
def extract_img_feat(self, img, img_metas):
    """Extract features of images."""
    B = img.size(0)
    if img is not None:
        input_shape = img.shape[-2:]
        # update real input shape of each single img
        for img_meta in img_metas:
            img_meta.update(input_shape=input_shape)

        if img.dim() == 5 and img.size(0) == 1:
            img.squeeze_()
        elif img.dim() == 5 and img.size(0) > 1:
            B, N, C, H, W = img.size()
            # merge the batch and camera dimensions: [B, N, C, H, W] -> [B*N, C, H, W]
            img = img.view(B * N, C, H, W)
        if self.use_grid_mask:
            img = self.grid_mask(img)
        # ResNet outputs feature maps at 1/4, 1/8, 1/16, 1/32 scale
        img_feats = self.img_backbone(img)
        if isinstance(img_feats, dict):
            img_feats = list(img_feats.values())
    else:
        return None
    # FPN
    if self.with_img_neck:
        img_feats = self.img_neck(img_feats)
    img_feats_reshaped = []
    for img_feat in img_feats:
        BN, C, H, W = img_feat.size()
        # split the camera dimension back out: [B*N, C, H, W] -> [B, N, C, H, W]
        img_feats_reshaped.append(img_feat.view(B, int(BN / B), C, H, W))
    return img_feats_reshaped
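A minimal, self-contained sketch of the batch/camera merge used above. The 928x1600 input resolution is an assumption, chosen to be consistent with the 1/8-scale feature map [6, 256, 116, 200] that shows up later in feature_sampling; the backbone + FPN call is replaced by a random tensor.

import torch

B, N, C, H, W = 1, 6, 3, 928, 1600               # batch, cameras, channels, height, width
imgs = torch.randn(B, N, C, H, W)

# merge batch and camera dims so the backbone sees an ordinary 4D batch
imgs = imgs.view(B * N, C, H, W)                 # [6, 3, 928, 1600]

# stand-in for one backbone + FPN output level (stride 8, 256 channels)
feat = torch.randn(B * N, 256, H // 8, W // 8)   # [6, 256, 116, 200]

# split the camera dimension back out, exactly as extract_img_feat does
feat = feat.view(B, N, 256, H // 8, W // 8)      # [1, 6, 256, 116, 200]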
Split the object query into query and query_pos, and pass query_pos through a fully connected layer to obtain the 3D reference points.
class Detr3DTransformer(BaseModule):

    def forward(self,
                mlvl_feats,
                query_embed,
                reg_branches=None,
                **kwargs):
        """Forward function for `Detr3DTransformer`.
        Args:
            mlvl_feats: multi-level 2D image features from the FPN
            query_embed: object queries, [num_query, 2 * embed_dims] = [900, 512]
        """
        assert query_embed is not None
        bs = mlvl_feats[0].size(0)
        # split the object query embedding into its positional and content parts,
        # query_pos: [900, 256], query: [900, 256]
        query_pos, query = torch.split(query_embed, self.embed_dims, dim=1)
        # query_pos: [900, 256] ---> [1, 900, 256]
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
        # query: [900, 256] ---> [1, 900, 256]
        query = query.unsqueeze(0).expand(bs, -1, -1)
        # fully connected layer predicts the 3D reference point (x, y, z) in BEV space,
        # query_pos: [1, 900, 256] ---> [1, 900, 3]
        reference_points = self.reference_points(query_pos)
        # sigmoid constrains the coordinates to the (0, 1) range
        reference_points = reference_points.sigmoid()
        init_reference_out = reference_points
        # query: [1, 900, 256] ---> [900, 1, 256]
        query = query.permute(1, 0, 2)
        # query_pos: [1, 900, 256] ---> [900, 1, 256]
        query_pos = query_pos.permute(1, 0, 2)
        # decoder: inter_states holds the features sampled from the FPN at the 3D
        # reference points, inter_references holds the refined 3D reference points
        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=mlvl_feats,
            query_pos=query_pos,
            reference_points=reference_points,
            reg_branches=reg_branches,
            **kwargs)
        inter_references_out = inter_references
        return inter_states, init_reference_out, inter_references_out
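Besides the per-layer query features (inter_states), the decoder also returns inter_references, the 3D reference points refined layer by layer (step ⑧). The decoder code itself is not shown in this article, so the sketch below is a simplified illustration only: taking the first three channels of the regression output as the offset and adding it in inverse-sigmoid space are assumptions made for the sketch.

import torch

def inverse_sigmoid(x, eps=1e-5):
    x = x.clamp(min=eps, max=1 - eps)
    return torch.log(x / (1 - x))

reference_points = torch.rand(1, 900, 3)      # current (x, y, z), normalized to (0, 1)
reg_out = torch.randn(1, 900, 10)             # stand-in for one reg_branch output

# add the predicted offset in inverse-sigmoid space, then squash back to (0, 1);
# using the first three regression channels as the offset is an illustrative assumption
new_reference_points = (inverse_sigmoid(reference_points) + reg_out[..., :3]).sigmoid()
reference_points = new_reference_points.detach()   # detached before feeding the next layer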
There are 6 decoder layers in total. In each layer, all object queries first run self-attention against each other to exchange global information and to keep multiple queries from converging on the same object; each object query then runs cross-attention against the image features.
At this point the multi-head attention uses q = k = v = query, where query is the content part split from the object query in 2.2.
if layer == 'self_attn':
    temp_key = temp_value = query
    query = self.attentions[attn_index](
        query,
        temp_key,
        temp_value,
        identity if self.pre_norm else None,
        query_pos=query_pos,
        key_pos=query_pos,
        attn_mask=attn_masks[attn_index],
        key_padding_mask=query_key_padding_mask,
        **kwargs)
①. Left-multiply the predicted 3D reference points by the transformation matrices and divide by the depth Zc to convert them to 2D image coordinates, obtaining 2D reference points.
②. Filter out out-of-bounds points and obtain the corresponding validity mask.
③. Iterate over the four FPN feature levels and sample them with grid_sample at the 2D reference point locations, obtaining the image features that correspond to each 2D reference point.
④. Use the query to produce attention weights and run cross-attention with the image features.
elif layer == 'cross_attn':
    query = self.attentions[attn_index](
        query,
        key,
        value,
        identity if self.pre_norm else None,
        query_pos=query_pos,
        key_pos=key_pos,
        attn_mask=attn_masks[attn_index],
        key_padding_mask=key_padding_mask,
        **kwargs)
class Detr3DCrossAtten(BaseModule):

    def forward(self,
                query,
                key,
                value,
                residual=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):
        if key is None:
            key = query
        if value is None:
            value = key
        if residual is None:
            inp_residual = query
        if query_pos is not None:
            query = query + query_pos
        # query: [900, 1, 256] ---> [1, 900, 256]
        query = query.permute(1, 0, 2)
        # 1, 900, 256
        bs, num_query, _ = query.size()
        # fully connected layer, query: [1, 900, 256] ---> [1, 900, 24] ---> [1, 1, 900, 6, 1, 4]
        attention_weights = self.attention_weights(query).view(
            bs, 1, num_query, self.num_cams, self.num_points, self.num_levels)
        # project the 3D reference points into the 2D image planes with the transformation
        # matrices, then grid_sample the FPN feature maps at the 2D reference point locations
        reference_points_3d, output, mask = feature_sampling(
            value, reference_points, self.pc_range, kwargs['img_metas'])
        # replace nan values with 0
        output = torch.nan_to_num(output)
        mask = torch.nan_to_num(mask)
        # keep only the attention weights whose reference points satisfy the boundary conditions
        attention_weights = attention_weights.sigmoid() * mask
        # cross-attention between the query and the sampled image features
        output = output * attention_weights
        # output: [1, 256, 900, 6, 1, 4] ---> [1, 256, 900]
        output = output.sum(-1).sum(-1).sum(-1)
        # output: [900, 1, 256]
        output = output.permute(2, 0, 1)
        # fully connected layer, output: [900, 1, 256]
        output = self.output_proj(output)
        # pos_feat: [900, 1, 256]
        pos_feat = self.position_encoder(
            inverse_sigmoid(reference_points_3d)).permute(1, 0, 2)
        return self.dropout(output) + inp_residual + pos_feat
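The "cross-attention" in Detr3DCrossAtten therefore amounts to a per-query weighted sum of the sampled features over cameras, sampling points and FPN levels. A tiny shape demo of that reduction, using random tensors whose shapes match the comments above:

import torch

bs, C, num_query, num_cams, num_points, num_levels = 1, 256, 900, 6, 1, 4

output = torch.randn(bs, C, num_query, num_cams, num_points, num_levels)            # sampled features
attention_weights = torch.rand(bs, 1, num_query, num_cams, num_points, num_levels)  # after sigmoid * mask

# broadcast the weights over the channel dim, then reduce over (levels, points, cams)
weighted = output * attention_weights              # [1, 256, 900, 6, 1, 4]
weighted = weighted.sum(-1).sum(-1).sum(-1)        # [1, 256, 900]
print(weighted.shape)                              # torch.Size([1, 256, 900])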
3D-to-2D projection + sampling code
def feature_sampling(mlvl_feats, reference_points, pc_range, img_metas):
    lidar2img = []
    # collect the lidar-to-image transformation matrices
    for img_meta in img_metas:
        lidar2img.append(img_meta['lidar2img'])
    lidar2img = np.asarray(lidar2img)
    lidar2img = reference_points.new_tensor(lidar2img)  # (B, N, 4, 4)
    reference_points = reference_points.clone()
    reference_points_3d = reference_points.clone()
    # map the normalized (0, 1) reference points back to the real point cloud range
    # x: offset_x * (51.2 + 51.2) - 51.2  ->  [-51.2, 51.2]
    reference_points[..., 0:1] = reference_points[..., 0:1] * (pc_range[3] - pc_range[0]) + pc_range[0]
    # y: offset_y * (51.2 + 51.2) - 51.2  ->  [-51.2, 51.2]
    reference_points[..., 1:2] = reference_points[..., 1:2] * (pc_range[4] - pc_range[1]) + pc_range[1]
    # z: offset_z * (3 + 5) - 5  ->  [-5, 3]
    reference_points[..., 2:3] = reference_points[..., 2:3] * (pc_range[5] - pc_range[2]) + pc_range[2]
    # homogeneous coordinates, reference_points: [1, 900, 4]
    reference_points = torch.cat((reference_points, torch.ones_like(reference_points[..., :1])), -1)
    B, num_query = reference_points.size()[:2]
    num_cam = lidar2img.size(1)
    # reference_points: [1, 900, 4] ---> [1, 1, 900, 4] ---> [1, 6, 900, 4, 1]
    reference_points = reference_points.view(B, 1, num_query, 4).repeat(1, num_cam, 1, 1).unsqueeze(-1)
    # lidar2img: [1, 6, 4, 4] ---> [1, 6, 1, 4, 4] ---> [1, 6, 900, 4, 4]
    lidar2img = lidar2img.view(B, num_cam, 1, 4, 4).repeat(1, 1, num_query, 1, 1)
    # left-multiply by the transformation matrix to project into image coordinates,
    # reference_points_cam: [1, 6, 900, 4, 1] ---> [1, 6, 900, 4]
    reference_points_cam = torch.matmul(lidar2img, reference_points).squeeze(-1)
    eps = 1e-5
    # keep only points in front of the camera (depth Zc > 0)
    mask = (reference_points_cam[..., 2:3] > eps)
    # divide by the depth Zc to obtain the 2D pixel coordinates (x, y)
    reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
        reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3]) * eps)
    # normalize x by image width and y by image height
    reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]
    reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]
    # map from [0, 1] to the [-1, 1] range expected by F.grid_sample
    reference_points_cam = (reference_points_cam - 0.5) * 2
    # filter out-of-bounds points
    mask = (mask & (reference_points_cam[..., 0:1] > -1.0)
                 & (reference_points_cam[..., 0:1] < 1.0)
                 & (reference_points_cam[..., 1:2] > -1.0)
                 & (reference_points_cam[..., 1:2] < 1.0))
    # mask: [1, 6, 900, 1] ---> [1, 6, 1, 900, 1, 1] ---> [1, 1, 900, 6, 1, 1]
    mask = mask.view(B, num_cam, 1, num_query, 1, 1).permute(0, 2, 3, 1, 4, 5)
    mask = torch.nan_to_num(mask)
    sampled_feats = []
    for lvl, feat in enumerate(mlvl_feats):
        B, N, C, H, W = feat.size()
        feat = feat.view(B * N, C, H, W)
        # reference_points_cam_lvl: [1, 6, 900, 2] ---> [6, 900, 1, 2]
        reference_points_cam_lvl = reference_points_cam.view(B * N, num_query, 1, 2)
        # F.grid_sample samples feat at the given locations,
        # feat: [6, 256, 116, 200], reference_points_cam_lvl: [6, 900, 1, 2], sampled_feat: [6, 256, 900, 1]
        sampled_feat = F.grid_sample(feat, reference_points_cam_lvl)
        # sampled_feat: [6, 256, 900, 1] ---> [1, 6, 256, 900, 1] ---> [1, 256, 900, 6, 1]
        sampled_feat = sampled_feat.view(B, N, C, num_query, 1).permute(0, 2, 3, 1, 4)
        # store the sampled features of each FPN level
        sampled_feats.append(sampled_feat)
    # stack along a new last dimension, sampled_feats: [1, 256, 900, 6, 1, 4]
    sampled_feats = torch.stack(sampled_feats, -1)
    # sampled_feats: [1, 256, 900, 6, 1, 4]
    sampled_feats = sampled_feats.view(B, C, num_query, num_cam, 1, len(mlvl_feats))
    return reference_points_3d, sampled_feats, mask
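The key detail here is the coordinate convention of F.grid_sample: sampling locations must be normalized to [-1, 1], with (-1, -1) the top-left and (+1, +1) the bottom-right of the feature map, which is exactly what dividing by W or H and then applying (value - 0.5) * 2 produces. A small self-contained check (align_corners=True is used only so the corner values come out exact; the DETR3D call above uses the defaults):

import torch
import torch.nn.functional as F

# a 1x1x4x4 feature map whose values equal their column index
feat = torch.arange(4.0).repeat(4, 1).view(1, 1, 4, 4)

# one sampling location per grid entry, normalized to [-1, 1]:
# x = -1 is the left edge, x = +1 the right edge (same for y with top/bottom)
grid = torch.tensor([[[[-1.0, -1.0],    # top-left corner
                       [ 1.0, -1.0]]]]) # top-right corner

out = F.grid_sample(feat, grid, align_corners=True)
print(out)   # tensor([[[[0., 3.]]]]) -> column 0 and column 3, as expected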
As in DETR, the Hungarian algorithm performs bipartite matching between all boxes predicted by the object queries and all GT boxes, finding the optimal matching with the minimum total cost.
①. Compute the classification cost (focal loss) and the regression cost (L1 loss) separately; their sum forms the cost matrix.
②. Run bipartite matching to obtain the assignment with the minimum cost.
def assign(self,
           bbox_pred,
           cls_pred,
           gt_bboxes,
           gt_labels,
           gt_bboxes_ignore=None,
           eps=1e-7):
    num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)
    # 1. initialize the assigned GT indices and class labels to -1,
    #    assigned_gt_inds: [900], assigned_labels: [900]
    assigned_gt_inds = bbox_pred.new_full((num_bboxes, ), -1, dtype=torch.long)
    assigned_labels = bbox_pred.new_full((num_bboxes, ), -1, dtype=torch.long)
    # No ground truth or boxes, return empty assignment
    if num_gts == 0 or num_bboxes == 0:
        if num_gts == 0:
            # no GT: mark every prediction as background (0)
            assigned_gt_inds[:] = 0
        return AssignResult(num_gts, assigned_gt_inds, None, labels=assigned_labels)
    # 2. compute the classification cost and the regression cost
    # classification cost: focal loss
    cls_cost = self.cls_cost(cls_pred, gt_labels)
    # normalize the GT boxes
    normalized_gt_bboxes = normalize_bbox(gt_bboxes, self.pc_range)
    # regression cost: L1 cost
    reg_cost = self.reg_cost(bbox_pred[:, :8], normalized_gt_bboxes[:, :8])
    # weighted sum of the two costs above
    cost = cls_cost + reg_cost
    # 3. bipartite matching to obtain the minimum-cost assignment
    cost = cost.detach().cpu()
    if linear_sum_assignment is None:
        raise ImportError('Please run "pip install scipy" to install scipy first.')
    matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
    matched_row_inds = torch.from_numpy(matched_row_inds).to(bbox_pred.device)
    matched_col_inds = torch.from_numpy(matched_col_inds).to(bbox_pred.device)
    # 4. assign backgrounds and foregrounds
    # assign all indices to backgrounds first
    assigned_gt_inds[:] = 0
    # assign foregrounds based on matching results
    assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
    assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
    return AssignResult(num_gts, assigned_gt_inds, None, labels=assigned_labels)
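For intuition, scipy.optimize.linear_sum_assignment solves exactly this assignment problem: given a cost matrix of shape [num_preds, num_gts], it returns the row/column index pairs with the minimum total cost. A toy example (in DETR3D the real matrix is 900 x num_gts):

import numpy as np
from scipy.optimize import linear_sum_assignment

# 4 predictions x 2 GT boxes
cost = np.array([[0.9, 0.2],
                 [0.1, 0.8],
                 [0.7, 0.7],
                 [0.3, 0.4]])

rows, cols = linear_sum_assignment(cost)
print(rows, cols)   # [0 1] [1 0]: prediction 0 -> GT 1, prediction 1 -> GT 0
# the remaining predictions (2 and 3) stay unmatched and become background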
Classification loss: focal loss
Regression loss: L1 loss
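A minimal sketch of the two loss functions themselves, written in plain PyTorch rather than the mmdet loss modules. The alpha/gamma values (0.25 / 2.0), the batch size and the 10-dimensional box encoding are assumptions; how the losses are aggregated over positive and negative queries follows the mmdet implementation and is omitted here.

import torch
import torch.nn.functional as F

num_pos, num_classes, box_dim = 32, 10, 10        # matched pairs (assumed numbers)
cls_logits = torch.randn(num_pos, num_classes)    # classification logits of matched queries
cls_target = torch.randint(0, num_classes, (num_pos,))
box_pred   = torch.randn(num_pos, box_dim)        # normalized box predictions
box_target = torch.randn(num_pos, box_dim)        # normalized GT boxes

# classification: sigmoid focal loss (alpha = 0.25, gamma = 2.0 assumed)
onehot = F.one_hot(cls_target, num_classes).float()
p = cls_logits.sigmoid()
ce = F.binary_cross_entropy_with_logits(cls_logits, onehot, reduction='none')
p_t = p * onehot + (1 - p) * (1 - onehot)
alpha_t = 0.25 * onehot + 0.75 * (1 - onehot)
cls_loss = (alpha_t * (1 - p_t) ** 2 * ce).sum() / num_pos

# regression: L1 loss between matched predictions and normalized GT boxes
reg_loss = F.l1_loss(box_pred, box_target)

total_loss = cls_loss + reg_loss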