DETR3D代码阅读

前言

本文主要是自己在阅读DETR3D的源码时的一个记录,如有错误或者问题,欢迎指正

提取feature map

在projects\mmdet3d_plugin\models\detectors\detr3d.py的forward_train()中,首先通过res50和FPN来进行图片特征的提取

        img_feats = self.extract_feat(img=img, img_metas=img_metas)
        losses = dict()
        losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
                                            gt_labels_3d, img_metas,
                                            gt_bboxes_ignore)
        losses.update(losses_pts)

提取到的img_feats是一个长度为num_level的list,其中每一层特征的shape为 [bs,6,c,h,w](6为相机数量)
在这里插入图片描述
然后调用self.forward_pts_train,进入到self.forward_pts_train中,首先调用self.pts_bbox_head来计算前向过程的输出,代码由此进入到detr3d_head中

detr3d_head

进入transformer

  hs, init_reference, inter_references = self.transformer(
            mlvl_feats,  #经过resnet和FPN提取到的多尺度特征
            query_embeds, #[900,512]  [num_query,embed_dims*2]
            reg_branches=self.reg_branches if self.with_box_refine else None,  # noqa:E501 reg_banches是回归分支
            img_metas=img_metas,
        )

DETR3D这里的transformer只有decoder,没有encoder,整个transformer的代码如下:

		assert query_embed is not None
        bs = mlvl_feats[0].size(0)
        # 首先将query_embed分为query 和 query_pos
        query_pos, query = torch.split(query_embed, self.embed_dims , dim=1)
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)  # [1, 900, 256] 1是batch_size
        query = query.unsqueeze(0).expand(bs, -1, -1)   # [1, 900, 256] 1是batch_size
        # 通过linear和sigmoid从query_pos中获取到reference_points
        reference_points = self.reference_points(query_pos)
        reference_points = reference_points.sigmoid()   #[1, 900, 3] [bs,num_query,3]
        init_reference_out = reference_points

        # decoder
        query = query.permute(1, 0, 2)      #[1,900,256] --> [900, 1, 256] [num_query,bs,256]
        query_pos = query_pos.permute(1, 0, 2)  #[1,900,256] --> [900, 1, 256][num_query,bs,256]

      	# 进入到Detr3DTransformerDecoder
        inter_states, inter_references = self.decoder(
            query=query, # [num_query,bs,256]
            key=None,   
            value=mlvl_feats, #value就是提取出的图片特征 [num_level,bs,num_cam,c,h,w]
            query_pos=query_pos, # [num_query,bs,256]
            reference_points=reference_points, #[bs,num_query,3]
            reg_branches=reg_branches,
            **kwargs)

        inter_references_out = inter_references
        #inter_states是sample到的feature   inter_references_out是更新后的referencepoints
        return inter_states, init_reference_out, inter_references_out

decoder

decoder中先做self_attn,此时Q、K、V都是query,shape为[900,bs,256],然后做cross_attn,在cross_attn中

        if key is None:
            key = query   # key就等于query
        if value is None:
            value = key

        if residual is None:
            inp_residual = query
        if query_pos is not None:
            query = query + query_pos

key和query是一样的,value是多尺度的feature map。虽然这里有了key,但其实并没有用K乘Q去计算attention_weight,attention_weight依然是由query通过self.attention_weights直接计算出来的。

        query = query.permute(1, 0, 2)      # (1,900,256)

        bs, num_query, _ = query.size()      #bs=1, num_query=900

        # (1,1,900,12,1,4) num_cams=12 num_points=1 num_levels=4
        attention_weights = self.attention_weights(query).view(
            bs, 1, num_query, self.num_cams, self.num_points, self.num_levels)

        # 返回值reference_points_3d就是原来的3d坐标,output是sampled_feats
        # output=B, C, num_query, num_cam,  1, len(mlvl_feats)] reference_points_3d=[1,900,3]
        reference_points_3d, output, mask = feature_sampling(
            value, reference_points, self.pc_range, kwargs['img_metas'])
        output = torch.nan_to_num(output)
        mask = torch.nan_to_num(mask)

        attention_weights = attention_weights.sigmoid() * mask
        # 个人理解:这里的output就是attention中的value
        output = output * attention_weights
        # 连续三个sum(-1),将不同尺度和不同相机的feature求和,得到最终图像的特征
        output = output.sum(-1).sum(-1).sum(-1)
        output = output.permute(2, 0, 1)
        # 图像特征project到与query同维度 [bs,256,num_query,num_cam,1,num_level] --> [num_query, bs, 256]
        output = self.output_proj(output)     
        # (num_query, bs, embed_dims)
        pos_feat = self.position_encoder(inverse_sigmoid(reference_points_3d)).permute(1, 0, 2) #MLP
		
		# 输出 = sampled到的feature,原来的query,reference_points_3d的pos_feat
        return self.dropout(output) + inp_residual + pos_feat

最重要的部分就是feature_sampling这个函数

def feature_sampling(mlvl_feats, reference_points, pc_range, img_metas):
    """Project 3D reference points into every camera and sample image features.

    Args:
        mlvl_feats: Sequence of multi-level feature maps; each entry is
            unpacked as ``(B, num_cam, C, H, W)``.
        reference_points: Tensor of shape ``(B, num_query, 3)`` holding
            reference points normalised to the ``[0, 1]`` range. The input
            tensor is not modified.
        pc_range: Six values indexed as
            ``[x_min, y_min, z_min, x_max, y_max, z_max]`` used to map the
            normalised points back to LiDAR coordinates.
        img_metas: One dict per sample. ``img_meta['lidar2img']`` supplies the
            per-camera 4x4 LiDAR-to-image matrices. Only
            ``img_metas[0]['img_shape'][0]`` is read for the image size
            (index 0 divides the y coordinate, index 1 the x coordinate), so
            every image is assumed to share that size.

    Returns:
        A tuple ``(reference_points_3d, sampled_feats, mask)``:

        * ``reference_points_3d`` - copy of the normalised input points.
        * ``sampled_feats`` - ``(B, C, num_query, num_cam, 1, num_levels)``
          features sampled at the projected locations.
        * ``mask`` - boolean ``(B, 1, num_query, num_cam, 1, 1)``; True where
          a point lies in front of the camera and inside the image.
    """
    # Per-sample projection matrices stacked to (B, num_cam, 4, 4), created
    # with the dtype and device of reference_points.
    lidar2img = np.asarray([img_meta['lidar2img'] for img_meta in img_metas])
    lidar2img = reference_points.new_tensor(lidar2img)

    # Work on a copy so the caller's tensor is untouched, and keep a second
    # copy of the normalised points to return unchanged.
    reference_points = reference_points.clone()
    reference_points_3d = reference_points.clone()

    # 1. De-normalise from [0, 1] to LiDAR coordinates using pc_range.
    reference_points[..., 0:1] = (
        reference_points[..., 0:1] * (pc_range[3] - pc_range[0]) + pc_range[0])  # x
    reference_points[..., 1:2] = (
        reference_points[..., 1:2] * (pc_range[4] - pc_range[1]) + pc_range[1])  # y
    reference_points[..., 2:3] = (
        reference_points[..., 2:3] * (pc_range[5] - pc_range[2]) + pc_range[2])  # z

    # Append a column of ones -> homogeneous coordinates (B, num_query, 4).
    reference_points = torch.cat(
        (reference_points, torch.ones_like(reference_points[..., :1])), -1)

    # 2. Apply lidar2img to every (camera, query) pair. matmul broadcasts the
    #    singleton camera / query axes, so no explicit repeat() is needed:
    #    (B, num_cam, 1, 4, 4) @ (B, 1, num_query, 4, 1)
    #        -> (B, num_cam, num_query, 4, 1)
    B, num_query = reference_points.size()[:2]
    num_cam = lidar2img.size(1)
    reference_points = reference_points.view(B, 1, num_query, 4, 1)
    lidar2img = lidar2img.view(B, num_cam, 1, 4, 4)
    reference_points_cam = torch.matmul(lidar2img, reference_points).squeeze(-1)

    # 3. Perspective divide and normalisation to the grid_sample range.
    eps = 1e-5
    depth = reference_points_cam[..., 2:3]
    # Points at or behind the camera plane are invalid: (B, num_cam, num_query, 1).
    mask = depth > eps
    # Divide x, y by depth; the depth is floored at eps to avoid dividing by
    # zero or by a negative value. Result: (B, num_cam, num_query, 2).
    reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
        depth, torch.full_like(depth, eps))
    # Normalise by the image size taken from the first sample's metadata.
    img_shape = img_metas[0]['img_shape'][0]
    reference_points_cam[..., 0] /= img_shape[1]
    reference_points_cam[..., 1] /= img_shape[0]
    # Map [0, 1] to [-1, 1], the coordinate range grid_sample expects.
    reference_points_cam = (reference_points_cam - 0.5) * 2
    # Also invalidate points that project outside the image of a camera.
    mask = (mask
            & (reference_points_cam[..., 0:1] > -1.0)
            & (reference_points_cam[..., 0:1] < 1.0)
            & (reference_points_cam[..., 1:2] > -1.0)
            & (reference_points_cam[..., 1:2] < 1.0))
    # (B, num_cam, num_query, 1) -> (B, 1, num_query, num_cam, 1, 1).
    # The original additionally passed the mask through torch.nan_to_num; a
    # boolean tensor cannot hold NaN, so that call is omitted.
    mask = mask.view(B, num_cam, 1, num_query, 1, 1).permute(0, 2, 3, 1, 4, 5)

    # 4. Sample every feature level at the projected locations.
    sampled_feats = []
    for feat in mlvl_feats:
        B, N, C, H, W = feat.size()
        feat = feat.view(B * N, C, H, W)
        grid = reference_points_cam.view(B * N, num_query, 1, 2)
        # align_corners=False is what grid_sample falls back to when the
        # argument is omitted; passing it explicitly keeps the behaviour and
        # avoids the warning emitted for the implicit default.
        sampled_feat = F.grid_sample(feat, grid, align_corners=False)
        # (B*N, C, num_query, 1) -> (B, C, num_query, N, 1)
        sampled_feat = sampled_feat.view(B, N, C, num_query, 1).permute(0, 2, 3, 1, 4)
        sampled_feats.append(sampled_feat)
    # Stack the levels on a new trailing axis.
    sampled_feats = torch.stack(sampled_feats, -1)
    sampled_feats = sampled_feats.view(B, C, num_query, num_cam, 1, len(mlvl_feats))
    return reference_points_3d, sampled_feats, mask

在通过feature_sampling提取特征后,将得到的output首先和attention_weights相乘,然后连续三个sum(-1),将不同相机、不同尺度的feature直接相加,把这些特征融合在一起,再通过output_proj这个线性层,将这些特征转换到与query相同的维度,最后的输出就是提取到的特征、原始的query以及reference_points_3d的位置编码三者之和。

在经过每一层的decoder layer之后,会有一个回归分支来预测bbox,会根据预测出的bbox来更新reference point

整个Detr3DTransformerDecoder的代码如下:

  		output = query
        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):
            reference_points_input = reference_points
            # 由此进入DetrTransformerDecoderLayer
            # 返回的output为[900,1,256]
            output = layer(
                output,
                *args,
                reference_points=reference_points_input,
                **kwargs)
            output = output.permute(1, 0, 2)
			
            if reg_branches is not None:
                tmp = reg_branches[lid](output)
                
                assert reference_points.shape[-1] == 3

                new_reference_points = torch.zeros_like(reference_points)
				# x y
                new_reference_points[..., :2] = tmp[
                    ..., :2] + inverse_sigmoid(reference_points[..., :2])
                 # z
                new_reference_points[..., 2:3] = tmp[
                    ..., 4:5] + inverse_sigmoid(reference_points[..., 2:3])
                
                new_reference_points = new_reference_points.sigmoid()

                reference_points = new_reference_points.detach()

            output = output.permute(1, 0, 2)
            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)

        return output, reference_points

整个decoder部分返回的是每一个decoder layer输出的query和reference point

走完整个transformer之后,后面就是通过输出的每一层的feature和referencepoint来进行预测。

你可能感兴趣的:(源码阅读,3d,深度学习,python)