3D Object Detection入门——PointRCNN代码学习

学习目录

  • 一. 预备知识
    • 1. KITTI数据格式
    • 2. Pointnet++/Pointnet2
  • 二. 整体流程
    • 1. 生成ground truth数据集
    • 2. RPN网络训练
    • 3.RCNN训练

一. 预备知识

1. KITTI数据格式

在PointRCNN训练过程中,需要用到KITTI中的四类数据。

  1. calib (calibration, 校准文件)
    是一个txt文件,每个文件有7行,用来描述相机的校准参数。七行的开头分别为:

    • P0,P1,P2,P3:分别表示左侧灰度相机,右侧灰度相机,左侧彩色相机,右侧彩色相机。
      每个相机(行)共有12个数字,表示一个3x4的矩阵。不过最后需要用的是4x4的参数矩阵,转换方法是右下角添1,其余位置添0。下面的参数也需用同样方法扩展到4x4。
    • R0_rect:校正旋转参数,可根据该参数将多个摄像机拍摄的照片位于同一个平面上。9个数字,表示3x3的矩阵。
    • Tr_velo_to_cam:3x4,顾名思义,可根据该矩阵的参数将lidar点云数据投影到未校正的相机照片上。
    • Tr_imu_to_velo:3x4,从imu坐标转换到velo的坐标。
  2. velodyne
    激光雷达点云文件的bin格式,可见其可视化方法。
    组织格式大概是(nx4),n为点云中点的个数 ,4中的前3个数是点的坐标,最后一个数代表激光反射强度

  3. lable_2
    是对对应图片/点云的标注文件,共15列。
    分别表示目标类别、截断率、遮挡程度、观察角度(-pi~pi)、左、上、右、下的2D边界框坐标、3维的高、宽、长(米)、照相机坐标下的物体位置3D坐标、相对y轴的旋转角度。

    Van 0.00 1 1.60 801.70 162.11 837.52 202.62 2.52 2.10 5.64 13.98 1.89 47.96 1.88
    Car 0.00 2 1.61 786.59 180.70 820.40 208.38 1.44 1.65 2.96 10.61 1.89 39.54 1.87
    DontCare -1 -1 -10 829.35 159.40 867.93 195.90 -1 -1 -1 -1000 -1000 -1000 -10

  4. image_2
    由车载高分辨率彩色摄像机行驶中拍摄的照片,png格式。

2. Pointnet++/Pointnet2

在PointRCNN的stage 1中,采用了Pointnet++作为主干网络,它是Pointnet的改进版。
Pointnet的核心思想是点云中点的置换不变性旋转不变性
其网络结构大概是将3通道的点云向高维度映射,使用卷积进行3->64->128->1024编码,再将其分别送入分割网络和分类网络。这样直接将NxD的点云矩阵作为输入相比输入体素等方法可以获得很好的实时性。但由于直接提取全局特征,导致局部特征并没有被有效的提取,导致PointNet在分割任务中表现较差。
Pointnet++则改进了这个问题,类似传统特征提取的方法,Pointnet++先从每个小范围提取局部特征,再将每个小范围进行abstraction,通过逐层的提取局部特征、再将其合成全局特征,从而进行分类。
两篇Pointnet++很细致的博客,学习的榜样。
https://zhuanlan.zhihu.com/p/88238420
https://blog.csdn.net/weixin_39373480/article/details/88878629

二. 整体流程

1. 生成ground truth数据集

从kitti中读取数据,为后续训练调用数据提供了丰富的接口。具体工作如下:

  • 从ImageSets中读取需要的数据名,然后从kitti中找到对应的calib、velo、image以及lable。(generate_gt_database.py:53-59)
  • 根据lable过滤对象。如,目标为car时,仅挑选class属于background和car的对象出来。
  • 从上面的对象中提取3d box,然后为该box中的每个点生成一个用于分割前景点(如:车)的掩码
    具体的掩码生成过程应该是在roipool3d_utils.py的41行,即roipool3d_cuda.pts_in_boxes3d_cpu(pts_flag, pts, boxes3d),论文中指出可通过lable的边界框,将3D框之内的点都认为是前景点。
  • 最后通过此掩码来提取点云中的点,然后把相关信息(sample_id、cls_type、gt_box3d等)保存至./gt_database/train_gt_database_3level_Car.pkl

关键代码如下

    def generate_gt_database(self):
        gt_database = []
        for idx, sample_id in enumerate(self.image_idx_list):
            sample_id = int(sample_id)
            print('process gt sample (id=%06d)' % sample_id)

            pts_lidar = self.get_lidar(sample_id)
            calib = self.get_calib(sample_id)
            pts_rect = calib.lidar_to_rect(pts_lidar[:, 0:3])
            pts_intensity = pts_lidar[:, 3]

            obj_list = self.filtrate_objects(self.get_label(sample_id))# 如:对于background和car的classes(22行),过滤掉行人。找出有效类

            gt_boxes3d = np.zeros((obj_list.__len__(), 7), dtype=np.float32)
            for k, obj in enumerate(obj_list):
                gt_boxes3d[k, 0:3], gt_boxes3d[k, 3], gt_boxes3d[k, 4], gt_boxes3d[k, 5], gt_boxes3d[k, 6] \
                    = obj.pos, obj.h, obj.w, obj.l, obj.ry

            if gt_boxes3d.__len__() == 0:
                print('No gt object')
                continue

            boxes_pts_mask_list = roipool3d_utils.pts_in_boxes3d_cpu(torch.from_numpy(pts_rect), torch.from_numpy(gt_boxes3d))

            for k in range(boxes_pts_mask_list.__len__()):
                pt_mask_flag = (boxes_pts_mask_list[k].numpy() == 1)
                cur_pts = pts_rect[pt_mask_flag].astype(np.float32)
                cur_pts_intensity = pts_intensity[pt_mask_flag].astype(np.float32)
                sample_dict = {'sample_id': sample_id,
                               'cls_type': obj_list[k].cls_type,
                               'gt_box3d': gt_boxes3d[k],
                               'points': cur_pts,
                               'intensity': cur_pts_intensity,
                               'obj': obj_list[k]}
                gt_database.append(sample_dict)

        save_file_name = os.path.join(args.save_dir, '%s_gt_database_3level_%s.pkl' % (args.split, self.classes[-1]))
        with open(save_file_name, 'wb') as f:
            pickle.dump(gt_database, f)

        self.gt_database = gt_database
        print('Save refine training sample info file to %s' % save_file_name)

2. RPN网络训练

类似于FastRCNN等二维特征提取方法,PointRCNN同样采用two-stage进行Object Detection。流程如下

  • 如前所述,PointRCNN采用Pointnet++作为stage 1的骨干(特征提取)部分。
    加载上面的pkl等数据之后,将点云矩阵作为输入,获得xyz坐标以及特征,即backbone_xyzbackbone_features。Pointnet++的代码主要位于poingnet2_msg.py。
  • 论文提出在Pointnet++输出的特征之后添加一个用于输出前景点掩码的分类分割头,和一个用于生成3D box的回归头,分别送入这两个子网络,二者并行进行。获得rpn_clsrpn_reg,主要代码如下。(rpn.py)
    def forward(self, input_data):
        """
        :param input_data: dict (point_cloud)
        :return:
        """
        pts_input = input_data['pts_input']
        backbone_xyz, backbone_features = self.backbone_net(pts_input)  # (B, N, 3), (B, C, N)  ## 首先由pointnet++进行处理

        rpn_cls = self.rpn_cls_layer(backbone_features).transpose(1, 2).contiguous()  # (B, N, 1)
        rpn_reg = self.rpn_reg_layer(backbone_features).transpose(1, 2).contiguous()  # (B, N, C)

        ret_dict = {'rpn_cls': rpn_cls, 'rpn_reg': rpn_reg,
                    'backbone_xyz': backbone_xyz, 'backbone_features': backbone_features}

        return ret_dict
  • rpn_cls可以认为是前景点的概率,当该概率大于某阈值时设对应掩码值为1,由此获得seg_mask。另外由于前景点的数量往往远小于背景点,使用focal损失来解决类不平衡的问题。
  • 输出rpn_reg的网络相对复杂一些,在论文中提出的是一种基于bin的候选框回归方法,但是对应的代码部分(rpn.py:31-59)还没有看懂。最后每batch输出512个region of interest。

3.RCNN训练

stage 2主要用于微调stage 1生成的3D框位置。

  • 首先将box中的池化点转换到相应的正则坐标系。(kitti_rcnn_dataset.py:246-)不过这一步之前应该也执行过。目的是为了优化训练。
  • 然后就是proposal_target_layer,输入数据包括stage 1 backbone网络输出的rpn_xyz、rpn_features;掩码seg_mask,回归网络输出的3维roi、深度信息pts_depth(正则化后会丢失)、还有ground truth box。(proposal_target_layer.py)
  • 文章采用类似的基于bin的回归损失。如果一个ground-trurh box和一个3D box proposals的IoU大于0.55,那么就把这个gt box分给3D box proposals来学习box微调。(proposal_target_layer.py:17)
  • 以上输入数据(input_data)会被送入proposal_target_layer,输出一个target_dict。两个dict的属性列表如下图 (rcnn_net.py:123)
    3D Object Detection入门——PointRCNN代码学习_第1张图片
  • 最后在point_rcnn.py中获得rcnn的输出。

你可能感兴趣的:(语义分割)