Before training, the author uses generate_gt_database.py to generate a file that stores the information of every Car gt box in the dataset; for each gt box it records a set of fields (sketched below).
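As a minimal sketch of what this file contains: from my reading of the repository's generate_gt_database.py, it is a pickled list with one dict per gt box. The file name and dict keys below are taken from that reading and should be treated as assumptions, not as facts stated in this post:

import pickle

# load the generated gt database (path and keys assumed from generate_gt_database.py)
with open('gt_database/train_gt_database_3level_Car.pkl', 'rb') as f:
    gt_database = pickle.load(f)  # list with one dict per Car gt box

box_info = gt_database[0]
print(box_info['sample_id'])        # frame the box was taken from
print(box_info['cls_type'])         # 'Car'
print(box_info['gt_box3d'])         # (x, y, z, h, w, l, ry)
print(box_info['points'].shape)     # lidar points inside the box (rect coords)
print(box_info['intensity'].shape)  # per-point reflectance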
First, KittiDataset is defined. It provides the generic interface and initializes the paths where the data is looked up:
# lib/datasets/kitti_dataset.py
import os
import torch.utils.data as torch_data

class KittiDataset(torch_data.Dataset):
    def __init__(self, root_dir, split='train'):
        self.split = split
        is_test = self.split == 'test'
        self.imageset_dir = os.path.join(root_dir, 'KITTI', 'object', 'testing' if is_test else 'training')

        split_dir = os.path.join(root_dir, 'KITTI', 'ImageSets', split + '.txt')
        self.image_idx_list = [x.strip() for x in open(split_dir).readlines()]
        self.num_sample = self.image_idx_list.__len__()

        self.image_dir = os.path.join(self.imageset_dir, 'image_2')
        self.lidar_dir = os.path.join(self.imageset_dir, 'velodyne')
        self.calib_dir = os.path.join(self.imageset_dir, 'calib')
        self.label_dir = os.path.join(self.imageset_dir, 'label_2')
        self.plane_dir = os.path.join(self.imageset_dir, 'planes')

    # accessor stubs (bodies omitted here): each loads the file for frame idx
    def get_image(self, idx): ...
    def get_image_shape(self, idx): ...
    def get_lidar(self, idx): ...
    def get_calib(self, idx): ...
    def get_label(self, idx): ...
    def get_road_plane(self, idx): ...

    def __len__(self): ...
    def __getitem__(self, item): ...
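As an example of what these accessors do, here is roughly what get_lidar boils down to (KITTI stores each velodyne scan as a flat float32 binary of (x, y, z, intensity) records; the exact body in the repo may differ slightly from this sketch):

import os
import numpy as np

def get_lidar(self, idx):
    # KITTI stores each scan as a flat float32 binary: N * (x, y, z, intensity)
    lidar_file = os.path.join(self.lidar_dir, '%06d.bin' % idx)
    assert os.path.exists(lidar_file)
    return np.fromfile(lidar_file, dtype=np.float32).reshape(-1, 4)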
Then the PointRCNN-specific dataset is defined; it mainly handles loading the data, data augmentation, and so on. Here we focus on how the samples used to train the RPN are prepared. The comments in the code are already very good, so I simply annotate what each step does:
# lib/datasets/kitti_rcnn_dataset.py
def get_rpn_sample(self, index):
    sample_id = int(self.sample_id_list[index])
    if sample_id < 10000:
        calib = self.get_calib(sample_id)
        # img = self.get_image(sample_id)
        img_shape = self.get_image_shape(sample_id)
        pts_lidar = self.get_lidar(sample_id)

        # get valid point (projected points should be in image)
        # transform the points from the lidar frame into the cam0 rectified frame
        pts_rect = calib.lidar_to_rect(pts_lidar[:, 0:3])
        pts_intensity = pts_lidar[:, 3]
    else:
        calib = self.get_calib(sample_id % 10000)
        # img = self.get_image(sample_id % 10000)
        img_shape = self.get_image_shape(sample_id % 10000)

        pts_file = os.path.join(self.aug_pts_dir, '%06d.bin' % sample_id)
        assert os.path.exists(pts_file), '%s' % pts_file
        aug_pts = np.fromfile(pts_file, dtype=np.float32).reshape(-1, 4)
        pts_rect, pts_intensity = aug_pts[:, 0:3], aug_pts[:, 3]

    # project pts_rect onto the cam2 image plane; pts_img holds the (u, v) pixel coordinates
    pts_img, pts_rect_depth = calib.rect_to_img(pts_rect)

    # drop points whose projection falls outside the image, and points whose rect
    # coordinates lie outside the range given in cfg: x ∈ [-40, 40], y ∈ [-1, 3], z ∈ [0, 70.4];
    # pts_valid_flag marks the points that are kept
    pts_valid_flag = self.get_valid_flag(pts_rect, pts_img, pts_rect_depth, img_shape)

    pts_rect = pts_rect[pts_valid_flag][:, 0:3]
    pts_intensity = pts_intensity[pts_valid_flag]

    if cfg.GT_AUG_ENABLED and self.mode == 'TRAIN':
        # all labels for checking overlapping
        # drop objects labeled 'DontCare'
        all_gt_obj_list = self.filtrate_dc_objects(self.get_label(sample_id))
        all_gt_boxes3d = kitti_utils.objs_to_boxes3d(all_gt_obj_list)  # Nx7 (x, y, z, h, w, l, ry)

        gt_aug_flag = False
        if np.random.rand() < cfg.GT_AUG_APPLY_PROB:
            # augment one scene: paste gt objects sampled from other scenes into this one.
            # gt_aug_flag == True means new objects were actually added; in that case
            # pts_rect / pts_intensity already include the pasted points, and
            # extra_gt_boxes3d / extra_gt_obj_list hold the newly added boxes and objects
            gt_aug_flag, pts_rect, pts_intensity, extra_gt_boxes3d, extra_gt_obj_list = \
                self.apply_gt_aug_to_one_scene(sample_id, pts_rect, pts_intensity, all_gt_boxes3d)

    # generate inputs
    # down-sample or pad the point cloud to exactly self.npoints (16384) points
    if self.mode == 'TRAIN' or self.random_select:
        if self.npoints < len(pts_rect):
            pts_depth = pts_rect[:, 2]
            pts_near_flag = pts_depth < 40.0
            far_idxs_choice = np.where(pts_near_flag == 0)[0]
            near_idxs = np.where(pts_near_flag == 1)[0]
            near_idxs_choice = np.random.choice(near_idxs, self.npoints - len(far_idxs_choice), replace=False)

            choice = np.concatenate((near_idxs_choice, far_idxs_choice), axis=0) \
                if len(far_idxs_choice) > 0 else near_idxs_choice
            np.random.shuffle(choice)
        else:
            choice = np.arange(0, len(pts_rect), dtype=np.int32)
            if self.npoints > len(pts_rect):
                extra_choice = np.random.choice(choice, self.npoints - len(pts_rect), replace=False)
                choice = np.concatenate((choice, extra_choice), axis=0)
            np.random.shuffle(choice)

        ret_pts_rect = pts_rect[choice, :]
        ret_pts_intensity = pts_intensity[choice] - 0.5  # translate intensity to [-0.5, 0.5]
    else:
        ret_pts_rect = pts_rect
        ret_pts_intensity = pts_intensity - 0.5

    pts_features = [ret_pts_intensity.reshape(-1, 1)]
    ret_pts_features = np.concatenate(pts_features, axis=1) if pts_features.__len__() > 1 else pts_features[0]

    sample_info = {'sample_id': sample_id, 'random_select': self.random_select}

    if self.mode == 'TEST':
        if cfg.RPN.USE_INTENSITY:
            pts_input = np.concatenate((ret_pts_rect, ret_pts_features), axis=1)  # (N, C)
        else:
            pts_input = ret_pts_rect

        sample_info['pts_input'] = pts_input
        sample_info['pts_rect'] = ret_pts_rect
        sample_info['pts_features'] = ret_pts_features
        return sample_info

    gt_obj_list = self.filtrate_objects(self.get_label(sample_id))
    if cfg.GT_AUG_ENABLED and self.mode == 'TRAIN' and gt_aug_flag:
        gt_obj_list.extend(extra_gt_obj_list)

    gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)

    gt_alpha = np.zeros((gt_obj_list.__len__()), dtype=np.float32)
    for k, obj in enumerate(gt_obj_list):
        gt_alpha[k] = obj.alpha

    # data augmentation: rotation, scaling, flip
    aug_pts_rect = ret_pts_rect.copy()
    aug_gt_boxes3d = gt_boxes3d.copy()
    if cfg.AUG_DATA and self.mode == 'TRAIN':
        aug_pts_rect, aug_gt_boxes3d, aug_method = self.data_augmentation(aug_pts_rect, aug_gt_boxes3d,
                                                                          gt_alpha, sample_id)
        sample_info['aug_method'] = aug_method

    # prepare input
    if cfg.RPN.USE_INTENSITY:
        pts_input = np.concatenate((aug_pts_rect, ret_pts_features), axis=1)  # (N, C)
    else:
        pts_input = aug_pts_rect

    if cfg.RPN.FIXED:
        sample_info['pts_input'] = pts_input
        sample_info['pts_rect'] = aug_pts_rect
        sample_info['pts_features'] = ret_pts_features
        sample_info['gt_boxes3d'] = aug_gt_boxes3d
        return sample_info

    # generate training labels
    rpn_cls_label, rpn_reg_label = self.generate_rpn_training_labels(aug_pts_rect, aug_gt_boxes3d)
    sample_info['pts_input'] = pts_input
    sample_info['pts_rect'] = aug_pts_rect
    sample_info['pts_features'] = ret_pts_features
    sample_info['rpn_cls_label'] = rpn_cls_label
    sample_info['rpn_reg_label'] = rpn_reg_label
    sample_info['gt_boxes3d'] = aug_gt_boxes3d
    return sample_info
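The body of get_valid_flag is not shown above, so here is a minimal sketch of the filtering it performs, assuming bounds matching cfg.PC_AREA_SCOPE as I read them from the repo (x ∈ [-40, 40], y ∈ [-1, 3], z ∈ [0, 70.4]); treat this as illustrative rather than the author's exact code:

import numpy as np

def get_valid_flag(pts_rect, pts_img, pts_rect_depth, img_shape):
    # keep points whose (u, v) projection lies inside the image and whose depth is non-negative
    val_flag_1 = np.logical_and(pts_img[:, 0] >= 0, pts_img[:, 0] < img_shape[1])
    val_flag_2 = np.logical_and(pts_img[:, 1] >= 0, pts_img[:, 1] < img_shape[0])
    val_flag_merge = np.logical_and(val_flag_1, val_flag_2)
    pts_valid_flag = np.logical_and(val_flag_merge, pts_rect_depth >= 0)

    # additionally restrict to the configured area scope (assumed bounds)
    x_range, y_range, z_range = [-40, 40], [-1, 3], [0, 70.4]
    pts_x, pts_y, pts_z = pts_rect[:, 0], pts_rect[:, 1], pts_rect[:, 2]
    range_flag = ((pts_x >= x_range[0]) & (pts_x <= x_range[1])
                  & (pts_y >= y_range[0]) & (pts_y <= y_range[1])
                  & (pts_z >= z_range[0]) & (pts_z <= z_range[1]))
    return pts_valid_flag & range_flag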
PointRCNN is a CVPR 2019 paper on 3D object detection. 3D object detection is a relatively new task in computer vision; for a survey of the literature, see my other blog post, 3D Object Detection 3D目标检测综述.
The paper takes a two-stage approach, with PointNet++ as the backbone network. It first performs a segmentation task to predict a label for every 3D point; for each point classified as foreground, its features are used to generate a 3D box proposal. The boxes are then ROI-cropped and refined. The network is complex and the codebase is huge; I truly admire the authors' coding skills, which put mine to shame. This post focuses on understanding the network structure. The code analyzed here is the authors' official PyTorch implementation (GitHub link).
In what follows, I will first visualize the computation process, and then list some of the points in the paper that caught my attention.
Since PointRCNN uses PointNet++ as its backbone, readers who are not familiar with the exact structure of PointNet++ can refer to my other blog post, PointNet++具体实现详解, which likewise focuses on the network structure. Let's first look at the visualization of PointRCNN's network structure.
PointRCNN is a two-stage network, so training also proceeds in two steps: the RPN is trained first, and then the RCNN.