BEV (3) --- DETR3D

1 Algorithm Overview

1.1 Core Idea

Unlike the bottom-up approach of LSS and BEVDepth, which first estimate depth and design a dedicated 2D-to-3D lifting module, DETR3D takes a top-down, 3D-to-2D route. It first defines a set of object queries representing candidate boxes and uses them to generate 3D reference points. These 3D reference points are projected back into 2D image coordinates with the camera transformation matrices; the image features at the projected locations are then gathered, and cross-attention between these image features and the object queries iteratively refines the queries. Finally, two MLP branches output the classification and regression predictions. Positive and negative samples are assigned with the same bipartite matching as in DETR: based on the minimum cost, the N predictions that best match the GT boxes are selected from the 900 object queries. Since both the sample assignment and the query-based detection paradigm mirror DETR, DETR3D can be viewed as an extension of DETR to 3D.
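At the heart of this top-down design is a simple pinhole projection. Assuming the lidar2img matrix used later in the code (Section 2.3.2) already folds the camera extrinsics and intrinsics into a single 4x4 transform, a 3D reference point (x, y, z) maps to pixel coordinates (u, v) as

$$
\begin{bmatrix} x_c \\ y_c \\ z_c \\ 1 \end{bmatrix}
= T_{\text{lidar2img}}
\begin{bmatrix} x \\ y \\ z \\ 1 \end{bmatrix},
\qquad
(u, v) = \left(\frac{x_c}{z_c},\ \frac{y_c}{z_c}\right),
$$

and only points with $z_c > 0$ that fall inside the image are kept by the visibility mask.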

1.2 Pipeline

①. Use ResNet-101 + FPN to extract features from the 6 surround-view images, producing outputs at 4 scales: 1/4, 1/8, 1/16, and 1/32 (note that the 6 images are fed in by concatenating the Batch and N (camera_nums) dimensions).
②. Predefine 900 object queries (analogous to prior boxes in 2D detection). Split each object query into query and query_pos, and use a fully connected layer to map query_pos from [900, 256] to [900, 3], which gives the (x, y, z) coordinates of the 3D reference points in BEV space.
③. Enter the transformer decoder, which has 6 decoder layers. Within each layer, set q = k = v = query, i.e., all object queries first perform self-attention with one another to exchange global information and keep multiple queries from converging on the same object.
④. Left-multiply the predicted 3D reference points by the transformation matrix and divide by the depth Zc to project them into the 2D image coordinate system, giving 2D reference points.
⑤. A 3D reference point projected back to 2D may have no corresponding pixel or be invisible in a given camera, so a mask is used to indicate whether each 3D reference point lies within the current camera view.
⑥. Iterate over the four FPN feature levels and, using the 2D reference point locations, sample the corresponding image features with grid_sample (bilinear interpolation).
⑦. Use the query to form attention weights and perform cross-attention with the sampled image features.
⑧. Use the sampled features to refine the 3D reference points; the refinement is as simple as it gets, a direct addition.
⑨. Use fully connected layers to produce the regression branch and the classification branch outputs.
⑩. Use the Hungarian algorithm for bipartite matching to obtain positive/negative samples, then compute the classification loss (focal loss) and regression loss (L1 loss). The full flow is sketched in pseudocode below.
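The ten steps above can be condensed into a short pseudo-forward of the detection head. This is only a sketch under assumed names (detr3d_head_forward, extract_img_feat, reference_point_fc, decoder, cls_branches and reg_branches are hypothetical stand-ins for the mmdet3d modules walked through in Section 2); batch handling is omitted for brevity.

import torch

# Sketch only: hypothetical glue code tying the ten steps together.
def detr3d_head_forward(imgs, img_metas, query_embedding, reference_point_fc,
                        decoder, cls_branches, reg_branches, extract_img_feat):
    # ① backbone + FPN; batch and camera dims are merged internally
    mlvl_feats = extract_img_feat(imgs, img_metas)                # list of [B, N, C, H, W]

    # ② split the 900 learnable embeddings into query_pos / query and
    #    predict 3D reference points (x, y, z) in (0, 1) from query_pos
    query_pos, query = torch.split(query_embedding.weight, 256, dim=1)
    reference_points = reference_point_fc(query_pos).sigmoid()    # [900, 3]

    # ③-⑧ six decoder layers: self-attention among queries, projection of the
    #    reference points into each camera (④⑤), grid_sample feature lookup (⑥),
    #    cross-attention (⑦), and additive refinement of the reference points (⑧)
    inter_states, inter_references = decoder(
        query=query, value=mlvl_feats, query_pos=query_pos,
        reference_points=reference_points, reg_branches=reg_branches,
        img_metas=img_metas)

    # ⑨ per-layer classification / regression heads
    cls_scores = [cls_branches[i](inter_states[i]) for i in range(len(inter_states))]
    bbox_preds = [reg_branches[i](inter_states[i]) for i in range(len(inter_states))]

    # ⑩ during training, Hungarian matching and the focal / L1 losses are applied to these outputs
    return cls_scores, bbox_preds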

1.3 Pros and Cons

Pros:
①. Only the features corresponding to the object queries are sampled; the full BEV representation is never explicitly built, which saves memory and computation and makes the model faster.
Cons:
①. The 3D-to-2D projection uses the 3D reference point (an object center) to look up features on the FPN feature maps, so when the receptive field is insufficient the sampled features may not cover the whole object; in practice, elongated objects such as buses may get predicted boxes that are too small.
②. Reference points in the same BEV grid cell sample identical image features after being projected back to 2D, and these features carry no depth information, so whether a reference point and its sampled image feature actually match has to be learned implicitly through repeated refinement.

2 Code Walkthrough

2.1 Backbone + neck feature extraction

Use ResNet-101 + FPN to extract features from the 6 surround-view images, producing outputs at 4 scales: 1/4, 1/8, 1/16, and 1/32.
Note that the 6 images are fed in by concatenating the Batch and N (camera_nums) dimensions.

    def extract_img_feat(self, img, img_metas):
        """Extract features of images."""
        B = img.size(0)
        if img is not None:
            input_shape = img.shape[-2:]
            # update real input shape of each single img
            for img_meta in img_metas:
                img_meta.update(input_shape=input_shape)
            if img.dim() == 5 and img.size(0) == 1:
                img.squeeze_()
            elif img.dim() == 5 and img.size(0) > 1:
                B, N, C, H, W = img.size()
                # merge the batch and N (camera) dimensions
                img = img.view(B * N, C, H, W)
            if self.use_grid_mask:
                img = self.grid_mask(img)
            # ResNet outputs feature maps at 1/4, 1/8, 1/16, 1/32 scales
            img_feats = self.img_backbone(img)
            if isinstance(img_feats, dict):
                img_feats = list(img_feats.values())
        else:
            return None
        # FPN
        if self.with_img_neck:
            img_feats = self.img_neck(img_feats)
        img_feats_reshaped = []
        for img_feat in img_feats:
            BN, C, H, W = img_feat.size()
            img_feats_reshaped.append(img_feat.view(B, int(BN / B), C, H, W))
        return img_feats_reshaped

2.2 Predicting the BEV-space 3D reference point coordinates (x, y, z) with a fully connected layer

Split the object query into query and query_pos, and feed query_pos through a fully connected layer to obtain the 3D reference points.

class Detr3DTransformer(BaseModule):
    def forward(self,
                mlvl_feats,
                query_embed,
                reg_branches=None,
                **kwargs):
        """Forward function for `Detr3DTransformer`.
        Args:
            mlvl_feats: multi-level 2D image features from the FPN
            query_embed: object queries, [num_query, 2 * embed_dims] = [900, 512]
        """
        assert query_embed is not None
        bs = mlvl_feats[0].size(0)
        # split the object query into its positional embedding and content part; query_pos: [900, 256], query: [900, 256]
        query_pos, query = torch.split(query_embed, self.embed_dims , dim=1)
        # query_pos: [900, 256] ---> [1, 900, 256]
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
        # query: [900, 256] ---> [1, 900, 256]
        query = query.unsqueeze(0).expand(bs, -1, -1)
        # FC layer predicts the 3D reference point offset coordinates (x, y, z) in BEV space; query_pos: [1, 900, 256] ---> [1, 900, 3]
        reference_points = self.reference_points(query_pos)
        # sigmoid constrains the offsets to the (0, 1) range
        reference_points = reference_points.sigmoid()
        init_reference_out = reference_points

        # query: [1, 900, 256] ---> [900, 1, 256]
        query = query.permute(1, 0, 2)
        # query_pos: [1, 900, 256] ---> [900, 1, 256]
        query_pos = query_pos.permute(1, 0, 2)
        # decoder; inter_states: features sampled from the FPN at the 3D reference points, inter_references: the refined 3D reference points
        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=mlvl_feats,
            query_pos=query_pos,
            reference_points=reference_points,
            reg_branches=reg_branches,
            **kwargs)

        inter_references_out = inter_references
        return inter_states, init_reference_out, inter_references_out

2.3 Transformer decoder layer

There are 6 decoder layers in total. Within each layer, all object queries first perform self-attention with one another to exchange global information and keep multiple queries from converging on the same object; the object queries then perform cross-attention with the image features. A sketch of the per-layer loop is given below.
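The refinement from step ⑧ of Section 1.2 happens inside this decoder loop (Detr3DTransformerDecoder in the DETR3D codebase), which the snippets below do not show. The following is a simplified sketch rather than the verbatim implementation: the signature and the offsets[..., :3] slicing are illustrative (the real reg branch predicts the full box code and only its x, y, z components shift the points), but the core idea is the same direct addition, done in inverse-sigmoid space so the refined points stay in (0, 1).

import torch

def inverse_sigmoid(x, eps=1e-5):
    # inverse of the sigmoid, so offsets can be added in logit space
    x = x.clamp(min=eps, max=1 - eps)
    return torch.log(x / (1 - x))

def decoder_forward(layers, query, reference_points, reg_branches, **kwargs):
    """Sketch of the 6-layer decoder loop (hypothetical signature).

    query:            [num_query, bs, 256]
    reference_points: [bs, num_query, 3], constrained to (0, 1) by sigmoid
    """
    intermediate, intermediate_refs = [], []
    for lid, layer in enumerate(layers):
        # one decoder layer: self-attention among the queries, then cross-attention
        # with the image features sampled at the current reference points
        query = layer(query, reference_points=reference_points, **kwargs)

        # step ⑧: refine the reference points by directly adding the predicted offsets
        offsets = reg_branches[lid](query.permute(1, 0, 2))      # [bs, num_query, code_size]
        reference_points = (inverse_sigmoid(reference_points)
                            + offsets[..., :3]).sigmoid().detach()

        intermediate.append(query)
        intermediate_refs.append(reference_points)

    # per-layer query states and refined reference points (inter_states, inter_references in 2.2)
    return torch.stack(intermediate), torch.stack(intermediate_refs)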

2.3.1 Self-attention among all object queries

Here the MultiheadAttention inputs are q = k = v = query, where query is the content part of the object query from Section 2.2.

            if layer == 'self_attn':
                temp_key = temp_value = query
                query = self.attentions[attn_index](
                    query,
                    temp_key,
                    temp_value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=query_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    **kwargs)
2.3.2 Cross-attention between object queries and image features

①. Left-multiply the predicted 3D reference points by the transformation matrix and divide by the depth Zc to project them into the 2D image coordinate system, giving 2D reference points.
②. Filter out out-of-bounds points and build the corresponding mask.
③. Iterate over the four FPN feature levels and, using the 2D reference point locations, sample the corresponding image features with grid_sample.
④. Use the query to form attention weights and perform cross-attention with the sampled image features.

            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    **kwargs)
class Detr3DCrossAtten(BaseModule):
    def forward(self,
                query,
                key,
                value,
                residual=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):

        if key is None:
            key = query
        if value is None:
            value = key

        if residual is None:
            inp_residual = query
        if query_pos is not None:
            query = query + query_pos

        # query: [900, 1, 256]--->[1, 900, 256]
        query = query.permute(1, 0, 2)
        # 1, 900, 256
        bs, num_query, _ = query.size()
        # FC layer; query: [1, 900, 256] ---> [1, 900, 24] ---> [1, 1, 900, 6, 1, 4]
        attention_weights = self.attention_weights(query).view(bs, 1, num_query, self.num_cams, self.num_points, self.num_levels)
        # project the 3D reference points into the 2D image coordinate system with the transformation matrices, then sample the FPN feature maps at the 2D reference point locations with grid_sample
        reference_points_3d, output, mask = feature_sampling(value, reference_points, self.pc_range, kwargs['img_metas'])
        # replace nan values with 0
        output = torch.nan_to_num(output)
        mask = torch.nan_to_num(mask)
        # keep only the attention weights whose reference points satisfy the boundary conditions
        attention_weights = attention_weights.sigmoid() * mask
        # cross-attention between the query and the image features
        output = output * attention_weights
        # output: [1, 256, 900, 6, 1, 4]--->[1, 256, 900]
        output = output.sum(-1).sum(-1).sum(-1)
        # output: [900, 1, 256]
        output = output.permute(2, 0, 1)
        # FC layer; output: [900, 1, 256]
        output = self.output_proj(output)
        # pos_feat:[900, 1, 256]
        pos_feat = self.position_encoder(inverse_sigmoid(reference_points_3d)).permute(1, 0, 2)

        return self.dropout(output) + inp_residual + pos_feat

The 3D-to-2D projection + sampling code

def feature_sampling(mlvl_feats, reference_points, pc_range, img_metas):
    lidar2img = []
    # collect the lidar-to-image transformation matrices
    for img_meta in img_metas:
        lidar2img.append(img_meta['lidar2img'])
    lidar2img = np.asarray(lidar2img)
    lidar2img = reference_points.new_tensor(lidar2img) # (B, N, 4, 4)
    reference_points = reference_points.clone()
    reference_points_3d = reference_points.clone()
    # map the normalized (0, 1) offsets back to the real BEV range, e.g. x in [-51.2, 51.2]
    # offset_x*(51.2+51.2)-51.2
    reference_points[..., 0:1] = reference_points[..., 0:1]*(pc_range[3] - pc_range[0]) + pc_range[0]
    # offset_y*(51.2+51.2)-51.2
    reference_points[..., 1:2] = reference_points[..., 1:2]*(pc_range[4] - pc_range[1]) + pc_range[1]
    # offset_z*(3+5)-5
    reference_points[..., 2:3] = reference_points[..., 2:3]*(pc_range[5] - pc_range[2]) + pc_range[2]
    # reference_points: [1, 900, 4]
    reference_points = torch.cat((reference_points, torch.ones_like(reference_points[..., :1])), -1)
    B, num_query = reference_points.size()[:2]
    num_cam = lidar2img.size(1)
    # reference_points: [1, 900, 4]--->[1, 1, 900, 4]--->[1, 6, 900, 4, 1]
    reference_points = reference_points.view(B, 1, num_query, 4).repeat(1, num_cam, 1, 1).unsqueeze(-1)
    # lidar2img: [1, 6, 4, 4]--->[1, 6, 1, 4, 4]--->[1, 6, 900, 4, 4]
    lidar2img = lidar2img.view(B, num_cam, 1, 4, 4).repeat(1, 1, num_query, 1, 1)
    # left-multiply by the transformation matrix to go to image coordinates; reference_points_cam: [1, 6, 900, 4, 1] ---> [1, 6, 900, 4]
    reference_points_cam = torch.matmul(lidar2img, reference_points).squeeze(-1)
    eps = 1e-5
    # keep only points in front of the camera (depth Zc > eps)
    mask = (reference_points_cam[..., 2:3] > eps)
    # divide by the depth Zc to get the 2D coordinates (x, y)
    reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3])*eps)
    # normalize to [0, 1] by the image width / height
    reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]

    reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]
    # map from [0, 1] to the [-1, 1] range expected by grid_sample
    reference_points_cam = (reference_points_cam - 0.5) * 2
    # filter out points that fall outside the image
    mask = (mask & (reference_points_cam[..., 0:1] > -1.0)
                 & (reference_points_cam[..., 0:1] < 1.0) 
                 & (reference_points_cam[..., 1:2] > -1.0) 
                 & (reference_points_cam[..., 1:2] < 1.0))
    # mask: [1, 6, 900, 1]---> [1, 6, 1, 900, 1, 1]--->[1, 1, 900, 6, 1, 1]
    mask = mask.view(B, num_cam, 1, num_query, 1, 1).permute(0, 2, 3, 1, 4, 5)

    mask = torch.nan_to_num(mask)
    sampled_feats = []

    for lvl, feat in enumerate(mlvl_feats):
        B, N, C, H, W = feat.size()
        feat = feat.view(B*N, C, H, W)
        # reference_points_cam_lvl: [1, 6, 900, 2] --->[6, 900, 1, 2]
        reference_points_cam_lvl = reference_points_cam.view(B*N, num_query, 1, 2)
        # use F.grid_sample to sample feat at the given locations; feat: [6, 256, 116, 200], reference_points_cam_lvl: [6, 900, 1, 2], sampled_feat: [6, 256, 900, 1]
        sampled_feat = F.grid_sample(feat, reference_points_cam_lvl)
        # sampled_feat: [6, 256, 900, 1]--->[1, 6, 256, 900, 1]--->[1, 256, 900, 6, 1]
        sampled_feat = sampled_feat.view(B, N, C, num_query, 1).permute(0, 2, 3, 1, 4)
        # store the sampling result of each feature level
        sampled_feats.append(sampled_feat)
    # stack along a new last dim; sampled_feats: [1, 256, 900, 6, 1, 4]
    sampled_feats = torch.stack(sampled_feats, -1)
    # sampled_feats: [1, 256, 900, 6, 1, 4]--->[1, 256, 900, 6, 1, 4]
    sampled_feats = sampled_feats.view(B, C, num_query, num_cam,  1, len(mlvl_feats))
    return reference_points_3d, sampled_feats, mask

3 Positive/Negative Sample Matching

As in DETR, the Hungarian algorithm performs bipartite matching between the boxes predicted from all object queries and all GT boxes, finding the optimal matching with the minimum cost.
①. Compute the classification cost (focal loss) and the regression cost (L1 loss) separately; their sum forms the cost matrix.
②. Use bipartite matching to obtain the matching result with the minimum cost.

    def assign(self,
               bbox_pred,
               cls_pred,
               gt_bboxes, 
               gt_labels,
               gt_bboxes_ignore=None,
               eps=1e-7):
        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)

        # 1. initialize the assigned GT indices and class labels to -1; assigned_gt_inds: [900], assigned_labels: [900]
        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ), -1, dtype=torch.long)
        assigned_labels = bbox_pred.new_full((num_bboxes, ), -1, dtype=torch.long)
        # No ground truth or boxes, return empty assignment
        if num_gts == 0 or num_bboxes == 0:
            if num_gts == 0:
                # no GT: mark all predictions as background (0)
                assigned_gt_inds[:] = 0
            return AssignResult(num_gts, assigned_gt_inds, None, labels=assigned_labels)

        # 2. compute the classification cost and the regression cost separately
        # classification cost: focal loss
        cls_cost = self.cls_cost(cls_pred, gt_labels)
        # normalize the GT boxes (convert to the regression target format)
        normalized_gt_bboxes = normalize_bbox(gt_bboxes, self.pc_range)
        # regression cost: L1 cost
        reg_cost = self.reg_cost(bbox_pred[:, :8], normalized_gt_bboxes[:, :8])
      
        # weighted sum of above two costs
        cost = cls_cost + reg_cost
        
        # 3. bipartite matching to obtain the minimum-cost matching result
        cost = cost.detach().cpu()

        if linear_sum_assignment is None:
            raise ImportError('Please run "pip install scipy" ''to install scipy first.')

        matched_row_inds, matched_col_inds = linear_sum_assignment(cost)

        matched_row_inds = torch.from_numpy(matched_row_inds).to(bbox_pred.device)

        matched_col_inds = torch.from_numpy(matched_col_inds).to(bbox_pred.device)

        # 4. assign backgrounds and foregrounds
        # assign all indices to backgrounds first
        assigned_gt_inds[:] = 0
        # assign foregrounds based on matching results
        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1

        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
        return AssignResult(num_gts, assigned_gt_inds, None, labels=assigned_labels)

4 Loss Functions

Classification loss: focal loss
Regression loss: L1 loss
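For reference, a minimal plain-PyTorch sketch of the two loss terms is given below; it is not the mmdet loss module the actual code calls, and detr3d_losses, its inputs, num_classes=10 and the 2.0 / 0.25 loss weights are illustrative assumptions. labels, bbox_targets and bbox_weights are assumed to come from the assignment in Section 3.

import torch
import torch.nn.functional as F

def detr3d_losses(cls_scores, bbox_preds, labels, bbox_targets, bbox_weights,
                  num_classes=10, alpha=0.25, gamma=2.0,
                  cls_weight=2.0, bbox_weight=0.25):
    """Sketch: sigmoid focal loss for classification + L1 loss for regression.

    cls_scores:   [num_query, num_classes]  raw logits
    bbox_preds:   [num_query, code_size]    normalized box predictions
    labels:       [num_query]               class indices; label == num_classes means background
    bbox_targets: [num_query, code_size]    normalized GT boxes for matched queries
    bbox_weights: [num_query, code_size]    1 for matched (foreground) queries, 0 otherwise
    """
    num_pos = (labels < num_classes).sum().clamp(min=1).float()

    # classification: sigmoid focal loss over all queries
    target = F.one_hot(labels, num_classes + 1)[:, :num_classes].float()  # background rows are all zero
    prob = cls_scores.sigmoid()
    ce = F.binary_cross_entropy_with_logits(cls_scores, target, reduction='none')
    p_t = prob * target + (1 - prob) * (1 - target)
    alpha_t = alpha * target + (1 - alpha) * (1 - target)
    loss_cls = (alpha_t * (1 - p_t) ** gamma * ce).sum() / num_pos

    # regression: L1 loss counted only on matched (foreground) queries
    loss_bbox = (torch.abs(bbox_preds - bbox_targets) * bbox_weights).sum() / num_pos

    return cls_weight * loss_cls, bbox_weight * loss_bbox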
