本文记录了博主学习frustum pointnets过程中遇到的2D和3D数据库显示程序。为了画出输出结果,博主希望在这个程序的基础上修改一个可以显示结果的程序。更新于2018.09.22。
本文首先给出代码原文的学习笔记,随后整理出修改后的结果显示程序;如果公开,会在这里放上链接,如果有帮助,请在代码页面点一下小星星哦。附可能有用的信息:各个集合所用的文件名列表在`kitti/image_sets`文件夹下。
把总结写在前面,根据需要判断是否需要详细看源码分析。
`objects[0].print_object()`中`[]`里的下标可以指定输出第几个目标的信息。不过要注意,每张图片中含有的目标个数是不同的,下标不能超过该图片中的目标数,否则会报错。

下面这一部分记录了代码原文中出现的语法规则,并不影响对代码功能的理解,但可能方便日后使用,因此在这里记录下来。
from __future__ import print_function
加上这句话以后,即使在Python 2.x中,也必须像Python 3.x那样把`print`当作函数使用(加括号)。类似地,如果某个新特性与当前版本的用法不兼容,也可以通过从`__future__`模块导入来提前启用。详细说明参考这里。
from PIL import Image
PIL(Python Imaging Library)是Python平台上事实标准的图像处理库;如今通常安装的是它的维护分支Pillow,导入方式仍为`from PIL import Image`。具体的使用方法说明可以参考这里。
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
其中,`__file__`是当前模块对应的文件路径,也就是`kitti_object.py`。`os.path.abspath`将该路径转换为绝对路径,比如博主的运行结果:
>>> print os.path.abspath("/home/galaxy/Work/XXX/Pointnets/frustum-pointnets-master/kitti/kitti_object.py")
/home/galaxy/Work/XXX/Pointnets/frustum-pointnets-master/kitti/kitti_object.py
而外层的`os.path.dirname`获取的是该路径所在的目录,因此`BASE_DIR`指向的就是`kitti_object.py`所在的文件夹(即`kitti`)。运行结果为:
>>> print os.path.dirname(os.path.abspath("/home/galaxy/Work/XXX/Pointnets/frustum-pointnets-master/kitti/kitti_object.py"))
/home/galaxy/Work/XXX/Pointnets/frustum-pointnets-master/kitti
sys.path.append(os.path.join(ROOT_DIR, 'mayavi'))
其中,`os.path.join`用于路径拼接;`ROOT_DIR`是`BASE_DIR`的上一级目录(即frustum-pointnets的根目录)。
`sys.path.append`:导入模块时,Python会依次在`sys.path`所列出的目录中查找(默认包括脚本所在目录、标准库目录和第三方包的安装目录)。如果要用的模块不在这些目录下,就需要把它所在的目录添加到`sys.path`中。这种修改是临时的,只对当前运行的进程有效,脚本运行结束后即失效。
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
OpenCV提供了`cvtColor`函数,用于在不同颜色空间之间转换图像;这里的`cv2.COLOR_BGR2RGB`把OpenCV默认的BGR通道顺序转换为RGB,以便后续按RGB显示。具体说明可以参照这里。
#代码作者信息
''' Helper class and functions for loading KITTI objects
Author: Charles R. Qi
Date: September 2017
'''
#加载必要的库
from __future__ import print_function
import os
import sys
import numpy as np
import cv2
from PIL import Image
# Base paths used to locate the dataset and helper modules.
# Directory that contains this file (the ``kitti`` folder).
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
# One level above BASE_DIR: the frustum-pointnets project root.
ROOT_DIR = os.path.split(BASE_DIR)[0]
# Make modules stored in <ROOT_DIR>/mayavi importable; the change to
# sys.path only lasts for the current process.
sys.path.append(os.path.join(ROOT_DIR, 'mayavi'))
import kitti_util as utils #加载论文作者写的库
# Python 2/3 compatibility: make ``raw_input`` callable on both versions.
try:
    raw_input # Python 2: the builtin already exists, nothing to do
except NameError:
    raw_input = input # Python 3: raw_input() was renamed to input()
# Resolves the directory layout of one KITTI object split (training or testing).
class kitti_object(object):
    '''Load and parse object data into a usable format.'''

    # Sample count of each supported split; its keys are also the set of
    # valid split names.
    _NUM_SAMPLES = {'training': 7481, 'testing': 7518}

    def __init__(self, root_dir, split='training'):
        '''root_dir contains training and testing folders

        Args:
            root_dir: dataset root holding one sub-folder per split.
            split: name of the split to use, 'training' or 'testing'.

        Raises:
            ValueError: if ``split`` is not a known split name. An exception
                is raised (rather than printing and calling the
                interactive-only ``exit`` helper) so that callers can handle
                a bad argument instead of having the interpreter killed.
        '''
        if split not in self._NUM_SAMPLES:
            raise ValueError('Unknown split: %s' % (split,))
        self.root_dir = root_dir
        self.split = split
        self.split_dir = os.path.join(root_dir, split)
        self.num_samples = self._NUM_SAMPLES[split]
        # Per-modality sub-folders of the split directory.
        self.image_dir = os.path.join(self.split_dir, 'image_2')
        self.calib_dir = os.path.join(self.split_dir, 'calib')
        self.lidar_dir = os.path.join(self.split_dir, 'velodyne')
        self.label_dir = os.path.join(self.split_dir, 'label_2')

    # Number of samples in the split; used later to iterate over the dataset.
    def __len__(self):
        return self.num_samples
def get_image(self, idx):
assert(idx=xmin) & \
(pts_2d[:,1]=ymin)
fov_inds = fov_inds & (pc_velo[:,0]>clip_distance)
imgfov_pc_velo = pc_velo[fov_inds,:]
if return_more:
return imgfov_pc_velo, pts_2d, fov_inds
else:
return imgfov_pc_velo
def show_lidar_with_boxes(pc_velo, objects, calib,
                          img_fov=False, img_width=None, img_height=None):
    ''' Show all LiDAR points.
    Draw 3d box in LiDAR point cloud (in velo coord system)

    Args:
        pc_velo: LiDAR points, one point per row.
        objects: label objects; entries whose ``type`` is 'DontCare' are
            skipped.
        calib: calibration object providing ``P`` and
            ``project_rect_to_velo``.
        img_fov: if True, only points inside the image field of view
            (as decided by get_lidar_in_image_fov) are drawn.
        img_width, img_height: image size, required when ``img_fov`` is True.
    '''
    # Imported unconditionally: Python caches modules, so repeat imports are
    # cheap, and a conditional import here could leave ``mlab`` unbound.
    import mayavi.mlab as mlab
    from viz_util import draw_lidar, draw_gt_boxes3d

    print('All point num: ', pc_velo.shape[0])
    fig = mlab.figure(figure=None, bgcolor=(0, 0, 0),
                      fgcolor=None, engine=None, size=(1000, 500))
    if img_fov:
        pc_velo = get_lidar_in_image_fov(pc_velo, calib, 0, 0,
                                         img_width, img_height)
        print('FOV point num: ', pc_velo.shape[0])
    draw_lidar(pc_velo, fig=fig)

    for obj in objects:
        if obj.type == 'DontCare':
            continue
        # 3D bounding box corners, moved into the velodyne frame for drawing.
        _, box3d_pts_3d = utils.compute_box_3d(obj, calib.P)
        box3d_pts_3d_velo = calib.project_rect_to_velo(box3d_pts_3d)
        # Heading arrow: a segment between the first two orientation points,
        # also moved into the velodyne frame.
        _, ori3d_pts_3d = utils.compute_orientation_3d(obj, calib.P)
        ori3d_pts_3d_velo = calib.project_rect_to_velo(ori3d_pts_3d)
        x1, y1, z1 = ori3d_pts_3d_velo[0, :]
        x2, y2, z2 = ori3d_pts_3d_velo[1, :]
        draw_gt_boxes3d([box3d_pts_3d_velo], fig=fig)
        mlab.plot3d([x1, x2], [y1, y2], [z1, z2], color=(0.5, 0.5, 0.5),
                    tube_radius=None, line_width=1, figure=fig)
    # NOTE(review): passes 1 as mlab.show's first argument; presumably this
    # avoids blocking in the GUI event loop -- confirm against the installed
    # mayavi version.
    mlab.show(1)
def show_lidar_on_image(pc_velo, img, calib, img_width, img_height):
    ''' Project LiDAR points to image

    Every LiDAR point inside the image field of view is drawn onto ``img``
    as a filled circle of radius 2, coloured by depth through a 256-entry
    HSV lookup table (index 640/depth, so nearer points get higher indices).
    The result is also shown with PIL.

    Args:
        pc_velo: LiDAR points, one point per row.
        img: image array; it is drawn on IN PLACE by cv2.circle.
        calib: calibration object providing ``project_velo_to_rect``.
        img_width, img_height: image size used for the field-of-view test.

    Returns:
        The same ``img`` array with the points drawn on it.
    '''
    import matplotlib.pyplot as plt

    imgfov_pc_velo, pts_2d, fov_inds = get_lidar_in_image_fov(
        pc_velo, calib, 0, 0, img_width, img_height, True)
    imgfov_pts_2d = pts_2d[fov_inds, :]
    imgfov_pc_rect = calib.project_velo_to_rect(imgfov_pc_velo)

    # 256-entry HSV lookup table as RGB values in [0, 255].  pyplot.get_cmap
    # is used because matplotlib.cm.get_cmap (plt.cm.get_cmap) was removed
    # in matplotlib 3.9.
    cmap = plt.get_cmap('hsv', 256)
    cmap = np.array([cmap(i) for i in range(256)])[:, :3] * 255

    # Column 2 of the rect-frame points is used as depth.  640/depth exceeds
    # 255 for depth < ~2.51 and is inf/negative for depth <= 0, so clip the
    # index into the table instead of overrunning or wrapping around it.
    depth = imgfov_pc_rect[:, 2]
    with np.errstate(divide='ignore'):
        color_idx = np.clip(640.0 / depth, 0, 255).astype(int)

    for pt, idx in zip(imgfov_pts_2d, color_idx):
        center = (int(np.round(pt[0])), int(np.round(pt[1])))
        cv2.circle(img, center, 2, color=tuple(cmap[idx, :]), thickness=-1)

    Image.fromarray(img).show()
    return img
def dataset_viz():
    '''Step through the dataset and visualize each sample interactively.

    For every sample index: print the first label object (if any), show the
    image with 2D/3D boxes, wait for Enter, then show the LiDAR points that
    fall in the image field of view with 3D boxes, and wait for Enter again.
    '''
    # Resolve the dataset's directory layout (default split of kitti_object).
    dataset = kitti_object(os.path.join(ROOT_DIR, 'dataset/KITTI/object'))
    for data_idx in range(len(dataset)):  # indices 0 .. len-1
        # --- Load the data for this sample ---
        objects = dataset.get_label_objects(data_idx)
        # Samples hold different numbers of objects; skip printing when there
        # are none instead of raising IndexError.  Change the index to print
        # a different object.
        if len(objects) > 0:
            objects[0].print_object()
        img = dataset.get_image(data_idx)
        # Convert from OpenCV's BGR channel order to RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_height, img_width, _ = img.shape
        print('Image shape: ', img.shape)
        # Keep only the first three columns (presumably x, y, z) of the cloud.
        pc_velo = dataset.get_lidar(data_idx)[:, 0:3]
        calib = dataset.get_calibration(data_idx)

        # Draw 2D and 3D boxes on the image.
        show_image_with_boxes(img, objects, calib, False)
        raw_input()
        # Show all LiDAR points. Draw 3d box in LiDAR point cloud
        show_lidar_with_boxes(pc_velo, objects, calib, True, img_width, img_height)
        raw_input()
# Script entry point: bind mayavi and the drawing helpers at module scope,
# then run the interactive dataset viewer.
if __name__=='__main__':
    import mayavi.mlab as mlab
    from viz_util import draw_lidar_simple, draw_lidar, draw_gt_boxes3d
    dataset_viz() # step through and display every sample