小白科研笔记:简析SA-SSD中的数据增强机制

1. 前言

当前,目标检测的深度网络一般训练至少50个Epoch。如果每个Epoch喂入的数据都是一样的,网络的误差函数可能会快速下降,但是会有很大的几率出现过拟合。这就像高中的笔者一直死记硬背地做一套习题,掌握不了知识变通一样。为了融汇贯通知识点,就需要做各式各样的练习题。在深度学习里,这就是数据增强(Data Augmentation),把输入的数据进行变形(Ground Truth也随之变化)。在点云处理中,常见的数据增强方式有:

  1. 点云上的部分点加高斯噪声
  2. 点云翻转(称之为flip),关于xyz轴做镜像翻转
  3. 点云旋转,关于xyz轴做旋转
  4. 点云放大/缩小

当点云做了上述的变换后,Ground Truth也会随着变动。比如3D目标检测来说,点云做旋转或翻转后,目标的真值3D框也需要做同样的旋转或者翻转。使用数据增强后,每一个Epoch数据都是不太一样的,目标检测网络也会学会了随机应变检测目标的能力,进而提高了它的泛化能力。

2. SA-SSD中的数据增强

2.1 四种点云数据增强的方式

直接查看SA-SSD中数据增强的代码(位于kitti.py中):

    # 使用数据增强后,输入数据和真值标签同时做变换
    # 下面是常见的四种数据增强方式:
    self.augmentor.noise_per_object_(gt_bboxes, points, num_try=100) 
    gt_bboxes, points = self.augmentor.random_flip(gt_bboxes, points) 
    gt_bboxes, points = self.augmentor.global_rotation(gt_bboxes, points) 
    gt_bboxes, points = self.augmentor.global_scaling(gt_bboxes, points) 

上述代码中的self.augmentor对应的是PointAugmentor类(位于point_augmentor.py)。接下来会分析这四种数据增强的具体操作方式。

2.2 方式一:加噪声

这一节分析函数noise_per_object_。代码如下所示:

    # 函数自变量都比较直观,不去做解释
    def noise_per_object_(self,
                          gt_boxes,
                          points=None,
                          valid_mask=None,
                          num_try=100):
        """random rotate or remove each groundtrutn independently.
        use kitti viewer to test this function points_transform_

        Args:
            gt_boxes: [N, 7], gt box in lidar.points_transform_
            points: [M, 4], point cloud in lidar.
        """
        num_boxes = gt_boxes.shape[0]

		# valid_mask 表示哪些点需要加噪声;哪些点不需要加;一般默认是全体点云做增强
        if valid_mask is None:
            valid_mask = np.ones((num_boxes,), dtype=np.bool_)
        center_noise_std = np.array(self._center_noise_std, dtype=gt_boxes.dtype)
        
        # 位置噪声,高斯分布噪声
        loc_noises = np.random.normal(
            scale=center_noise_std, size=[num_boxes, num_try, 3])

		# 旋转噪声,均匀分布噪声
        rot_noises = np.random.uniform(
            self._global_rot_range[0], self._global_rot_range[1], size=[num_boxes, num_try])

        origin = [0.5, 0.5, 0]
        gt_box_corners = center_to_corner_box3d(gt_boxes, origin=origin, axis=2)

		# 位置噪声加在与大地所在的平面上,只有 xy 和 wl 和 theta 会变化
        selected_noise = noise_per_box(gt_boxes[:, [0, 1, 3, 4, 6]],
                                           valid_mask, loc_noises, rot_noises)

        loc_transforms = select_transform(loc_noises, selected_noise)
        rot_transforms = select_transform(rot_noises, selected_noise)
        surfaces = corner_to_surfaces_3d_jit(gt_box_corners)

		# 位姿噪声作用在点云和真值框上
        if points is not None:
            point_masks = points_in_convex_polygon_3d_jit(points[:, :3], surfaces)
            points_transform_(points, gt_boxes[:, :3], point_masks, loc_transforms,
                              rot_transforms, valid_mask)

        box3d_transform_(gt_boxes, loc_transforms, rot_transforms, valid_mask)

其实这一段代码还是比较复杂的,咱暂先弄懂它的大致意图哈。如果后续需要深入理解,再去做更详细的分析吧。

2.3 方式二:做翻转

这一节分析random_flip

    def random_flip(self, gt_boxes, points, probability=0.5):
        enable = np.random.choice(
            [False, True], replace=False, p=[1 - probability, probability])
        if enable:
            gt_boxes[:, 1] = -gt_boxes[:, 1]
            gt_boxes[:, 6] = -gt_boxes[:, 6] + np.pi
            points[:, 1] = -points[:, 1]
        return gt_boxes, points

这段代码写得挺好玩的。做翻转指的是做y轴对称的翻转。

2.4 方式三:做旋转

这一节分析global_rotation

    def global_rotation(self, gt_boxes, points):
        noise_rotation = np.random.uniform(self._global_rot_range[0], \
                                           self._global_rot_range[1])
        points[:, :3] = rotation_points_single_angle(
            points[:, :3], noise_rotation, axis=2)
        gt_boxes[:, :3] = rotation_points_single_angle(
            gt_boxes[:, :3], noise_rotation, axis=2)
        gt_boxes[:, 6] += noise_rotation
        return gt_boxes, points

做旋转指的是在z轴上做旋转。

2.5 方式四:做缩放

这一节分析global_scaling

    def global_scaling(self, gt_boxes, points):
        noise_scale = np.random.uniform(self._min_scale, self._max_scale)
        points[:, :3] *= noise_scale
        gt_boxes[:, :6] *= noise_scale
        return gt_boxes, points

做缩放指的是在xyz轴上做缩放。

3. 结语

数据增强的方式比较简单,但是在实际过程中还是很有用的。

你可能感兴趣的:(computer,vision论文代码分析)