本文主要是自己在阅读DETR3D的源码时的一个记录,如有错误或者问题,欢迎指正
在projects\mmdet3d_plugin\models\detectors\detr3d.py的forward_train()中,首先通过res50和FPN来进行图片特征的提取
img_feats = self.extract_feat(img=img, img_metas=img_metas)
losses = dict()
losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
gt_labels_3d, img_metas,
gt_bboxes_ignore)
losses.update(losses_pts)
提取到的img_feats为 [num_level,bs,6,c,h,w]
然后调用self.forward_pts_train,进入到self.forward_pts_train中,首先调用self.pts_bbox_head来计算前向过程的输出,代码由此进入到detr3d_head中
hs, init_reference, inter_references = self.transformer(
mlvl_feats, #经过resnet和FPN提取到的多尺度特征
query_embeds, #[900,512] [num_query,embed_dims*2]
reg_branches=self.reg_branches if self.with_box_refine else None, # noqa:E501 reg_banches是回归分支
img_metas=img_metas,
)
DETR3D这里的transformer只有decoder,没有encoder,整个transformer的代码如下:
assert query_embed is not None
bs = mlvl_feats[0].size(0)
# 首先将query_embed分为query 和 query_pos
query_pos, query = torch.split(query_embed, self.embed_dims , dim=1)
query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1) # [1, 900, 256] 1是batch_size
query = query.unsqueeze(0).expand(bs, -1, -1) # [1, 900, 256] 1是batch_size
# 通过linear和sigmoid从query_pos中获取到reference_points
reference_points = self.reference_points(query_pos)
reference_points = reference_points.sigmoid() #[1, 900, 3] [bs,num_query,3]
init_reference_out = reference_points
# decoder
query = query.permute(1, 0, 2) #[1,900,256] --> [900, 1, 256] [num_query,bs,256]
query_pos = query_pos.permute(1, 0, 2) #[1,900,256] --> [900, 1, 256][num_query,bs,256]
# 进入到Detr3DTransformerDecoder
inter_states, inter_references = self.decoder(
query=query, # [num_query,bs,256]
key=None,
value=mlvl_feats, #value就是提取出的图片特征 [num_level,bs,num_cam,c,h,w]
query_pos=query_pos, # [num_query,bs,256]
reference_points=reference_points, #[bs,num_query,3]
reg_branches=reg_branches,
**kwargs)
inter_references_out = inter_references
#inter_states是sample到的feature inter_references_out是更新后的referencepoints
return inter_states, init_reference_out, inter_references_out
decoder中先做self_attn,此时QKV都是query,shape为[900,bs,265],然后做cross_atten,在cross_attn中
if key is None:
key = query # key就等于query
if value is None:
value = key
if residual is None:
inp_residual = query
if query_pos is not None:
query = query + query_pos
key和query是一样的,value是多尺度的feature map,虽然这里有了key,但是其实也没有用K乘Q去计算attention_weight,其attention_weight依然是通过query出的。
query = query.permute(1, 0, 2) # (1,900,256)
bs, num_query, _ = query.size() #bs=1, num_query=900
# (1,1,900,12,1,4) num_cams=12 num_points=1 num_levels=4
attention_weights = self.attention_weights(query).view(
bs, 1, num_query, self.num_cams, self.num_points, self.num_levels)
# 返回值reference_points_3d就是原来的3d坐标,output是sampled_feats
# output=B, C, num_query, num_cam, 1, len(mlvl_feats)] reference_points_3d=[1,900,3]
reference_points_3d, output, mask = feature_sampling(
value, reference_points, self.pc_range, kwargs['img_metas'])
output = torch.nan_to_num(output)
mask = torch.nan_to_num(mask)
attention_weights = attention_weights.sigmoid() * mask
# 个人理解:这里的output就是attention中的value
output = output * attention_weights
# 连续三个sum(-1),将不同尺度和不同相机的feature求和,得到最终图像的特征
output = output.sum(-1).sum(-1).sum(-1)
output = output.permute(2, 0, 1)
# 图像特征project到与query同维度 [bs,256,num_query,num_cam,1,num_level] --> [num_query, bs, 256]
output = self.output_proj(output)
# (num_query, bs, embed_dims)
pos_feat = self.position_encoder(inverse_sigmoid(reference_points_3d)).permute(1, 0, 2) #MLP
# 输出 = sampled到的feature,原来的query,reference_points_3d的pos_feat
return self.dropout(output) + inp_residual + pos_feat
最重要的部分就是feature_sampling这个函数
def feature_sampling(mlvl_feats, reference_points, pc_range, img_metas):
lidar2img = []
for img_meta in img_metas:
lidar2img.append(img_meta['lidar2img'])
lidar2img = np.asarray(lidar2img)
lidar2img = reference_points.new_tensor(lidar2img) # (B, N, 4, 4) (1,6,4,4)
reference_points = reference_points.clone()
reference_points_3d = reference_points.clone()
# 归一化坐标 pc_range =[-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
# 将坐标从0-1尺度转到lidar坐标系下
reference_points[..., 0:1] = reference_points[..., 0:1]*(pc_range[3] - pc_range[0]) + pc_range[0] #x轴
reference_points[..., 1:2] = reference_points[..., 1:2]*(pc_range[4] - pc_range[1]) + pc_range[1] #y轴
reference_points[..., 2:3] = reference_points[..., 2:3]*(pc_range[5] - pc_range[2]) + pc_range[2] #z轴
# reference_points (B, num_queries, 4) 在最后一列全加上1 变成(1,900,4)
reference_points = torch.cat((reference_points, torch.ones_like(reference_points[..., :1])), -1)
###############################################
# 2.由lidar系转化为camera系
###############################################
B, num_query = reference_points.size()[:2] #B=1,num_query=900
num_cam = lidar2img.size(1) # num_cam=6
# reference_points[1,900,4] --> reference_points.view(B, 1, num_query, 4) [1,1,900,4] --> repeat [1,12,900,4] --> [1,12,900,4,1]
reference_points = reference_points.view(B, 1, num_query, 4).repeat(1, num_cam, 1, 1).unsqueeze(-1)
lidar2img = lidar2img.view(B, num_cam, 1, 4, 4).repeat(1, 1, num_query, 1, 1) #[1, 12, 900, 4, 4]
# reference_points_cam.size() = [1,6,900,4]
reference_points_cam = torch.matmul(lidar2img, reference_points).squeeze(-1)
###############################################
# 3.由camera系转到图像系并归一化
###############################################
eps = 1e-5
# mask.size() = [1,6,900,1]
mask = (reference_points_cam[..., 2:3] > eps)
# 这一步是将坐标由camera系转到图像系 (x,y) = (xc,yc) / zc *f 这里的f是相机焦距,在前面lidar2img已经成过了,这里只用除以zc就行了
reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3])*eps) # 深度归一化 (1, 6, 900, 2)
# 在img平面上进行长宽归一化
reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1] # 长宽归一化
reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]
#将坐标由[0,1] 转到[-1,1]之间
reference_points_cam = (reference_points_cam - 0.5) * 2
# 对所有不在grid内的点,也就是投影在某个cam之外的点进行mask
mask = (mask & (reference_points_cam[..., 0:1] > -1.0)
& (reference_points_cam[..., 0:1] < 1.0)
& (reference_points_cam[..., 1:2] > -1.0)
& (reference_points_cam[..., 1:2] < 1.0))
# mask.size()=[1,1,900,6,1,1]
mask = mask.view(B, num_cam, 1, num_query, 1, 1).permute(0, 2, 3, 1, 4, 5)
mask = torch.nan_to_num(mask)
sampled_feats = []
# 逐特征层sample feature
for lvl, feat in enumerate(mlvl_feats):
B, N, C, H, W = feat.size() #B=1,N=6 C=256 H=16 W=28
feat = feat.view(B*N, C, H, W)
reference_points_cam_lvl = reference_points_cam.view(B*N, num_query, 1, 2)
sampled_feat = F.grid_sample(feat, reference_points_cam_lvl)
sampled_feat = sampled_feat.view(B, N, C, num_query, 1).permute(0, 2, 3, 1, 4)
sampled_feats.append(sampled_feat)
sampled_feats = torch.stack(sampled_feats, -1)
sampled_feats = sampled_feats.view(B, C, num_query, num_cam, 1, len(mlvl_feats))
return reference_points_3d, sampled_feats, mask
在通过feature_sampling提取特征后,将得到的output首先和attention_weigth相乘,然后连续三个sum(-1),将不同相机,不同尺度的feature直接相加,将这些特征都融合再一起,再通过一个out_proj的线性层,将这些特征转换到与query相同的维度,最后的输出就是提取到的特征和原始的query以及reference_points3d的位置编码的和。
在经过每一层的decoder layer之后,会有一个回归分支来预测bbox,会根据预测出的bbox来更新reference point
整个Detr3DTransformerDecoder的代码如下:
output = query
intermediate = []
intermediate_reference_points = []
for lid, layer in enumerate(self.layers):
reference_points_input = reference_points
# 由此进入DetrTransformerDecoderLayer
# 返回的output为[900,1,256]
output = layer(
output,
*args,
reference_points=reference_points_input,
**kwargs)
output = output.permute(1, 0, 2)
if reg_branches is not None:
tmp = reg_branches[lid](output)
assert reference_points.shape[-1] == 3
new_reference_points = torch.zeros_like(reference_points)
# x y
new_reference_points[..., :2] = tmp[
..., :2] + inverse_sigmoid(reference_points[..., :2])
# z
new_reference_points[..., 2:3] = tmp[
..., 4:5] + inverse_sigmoid(reference_points[..., 2:3])
new_reference_points = new_reference_points.sigmoid()
reference_points = new_reference_points.detach()
output = output.permute(1, 0, 2)
if self.return_intermediate:
intermediate.append(output)
intermediate_reference_points.append(reference_points)
if self.return_intermediate:
return torch.stack(intermediate), torch.stack(
intermediate_reference_points)
return output, reference_points
整个decoder部分返回的是每一个decoder layer输出的query和reference point
走完整个transformer之后,后面就是通过输出的每一层的feature和referencepoint来进行预测。