frustum pointnets训练代码学习笔记——kitti_object.py

frustum pointnets训练代码学习笔记——kitti_object.py

本文记录了博主学习frustum pointnets过程中遇到的2D和3D数据库显示程序。为了画出输出结果,博主希望在这个程序的基础上修改一个可以显示结果的程序。更新于2018.09.22。

本文首先给出代码原文的学习笔记,随后整理出修改后的结果显示程序,如果公开,会在这里放上链接,如果有帮助,请在代码页面点一下小星星哦。附可能有用的信息:各个集合所用的文件名在kitti/image_sets文件夹下。

更多内容,欢迎加入星球讨论。
frustum pointnets训练代码学习笔记——kitti_object.py_第1张图片

文章目录

  • frustum pointnets训练代码学习笔记——kitti_object.py
    • 总结
    • 用到的语法规则
      • `from __future__ import print_function`
      • `from PIL import Image`
      • `BASE_DIR = os.path.dirname(os.path.abspath(__file__))`
      • `sys.path.append(os.path.join(ROOT_DIR, 'mayavi'))`
      • `img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)`
    • 代码原文分析

总结

把总结写在前面,根据需要判断是否需要详细看源码分析。

  • 这个文件的主要功能就是将KITTI库中的2d和3d结果画出来(至于是画training还是testing,文件的kitti_object函数的初始函数和get_label_objects分别有定义和判断)。
  • 文件个数是人为在kitti_object函数中设定的,并非自动提取。
  • 画图是机械第从第一个图片一直向后显示,且如果该图片中有多个目标,仅显示txt文件中排在第一的那个的数据。通过修改objects[0].print_object()中[]里面的标号可以指定画第几个目标,不过要注意的是,每个图片中含有的目标个数是不同的。

用到的语法规则

这一部分记录了代码原文中出现的语法规则,并不影响代码功能的理解,但是可能方便日后的使用,因此在这里记录下来。

from __future__ import print_function

加上这句话以后,即使在python2.X也要像python3.X一样的语法使用print函数(加括号)。类似地,如果有其他新的功能特性且该特性与当前版本中的使用不兼容,就可以从future模块导入。详细说明参考这里。

from PIL import Image

PIL已经是python平台事实上的图像处理标准库了,全称为Python Imaging Library。具体的使用方法说明可以参考这里。

BASE_DIR = os.path.dirname(os.path.abspath(__file__))

其中,__file__就是当前所执行的文件,也就是kitti_object.pyos.path.abspath命令获取的是当前文件的绝对路径,比如博主的运行结果:

>>> print os.path.abspath("/home/galaxy/Work/XXX/Pointnets/frustum-pointnets-master/kitti/kitti_object.py")
/home/galaxy/Work/XXX/Pointnets/frustum-pointnets-master/kitti/kitti_object.py

而前面的os.path.dirname获取的则是当前路径所存在于的文件夹,因此,BASE_DIR指向的就是kitti_object.py所处的文件夹了。运行结果为:

>>> print os.path.dirname(os.path.abspath("/home/galaxy/Work/XXX/Pointnets/frustum-pointnets-master/kitti/kitti_object.py"))
/home/galaxy/Work/XXX/Pointnets/frustum-pointnets-master/kitti

sys.path.append(os.path.join(ROOT_DIR, 'mayavi'))

其中,os.path.join用于路径拼接。
sys.path.append:在导入一个模块时,默认情况下python会搜索当前目录、已安装的内置模块和第三方模块,搜索路径存放在sys模块的path中。如果要用的模块和当前脚本不在一个目录下,就需要将其添加到path中。这种修改是临时的,脚本运行后失效。

img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

opencv中提供了cvtColor函数用于实现图像格式类型的相互转换,具体说明可以参照这里。


代码原文分析

#代码作者信息
''' Helper class and functions for loading KITTI objects

Author: Charles R. Qi
Date: September 2017
'''

#加载必要的库
from __future__ import print_function			

import os
import sys
import numpy as np
import cv2
from PIL import Image
#定义基础路径
BASE_DIR = os.path.dirname(os.path.abspath(__file__))			#指向当前文件所在文件夹(kitti)
ROOT_DIR = os.path.dirname(BASE_DIR)			#指向frustum文件夹
sys.path.append(os.path.join(ROOT_DIR, 'mayavi'))
import kitti_util as utils			#加载论文作者写的库

try:
	raw_input          # Python 2
except NameError:
   	raw_input = input  # Python 3


#用于获取数据库各项路径路径(training或testing)
class kitti_object(object):
	'''Load and parse object data into a usable format.'''

	def __init__(self, root_dir, split='training'):
    	'''root_dir contains training and testing folders'''
    	self.root_dir = root_dir
    	self.split = split
    	self.split_dir = os.path.join(root_dir, split)

    	if split == 'training':
        	self.num_samples = 7481
    	elif split == 'testing':
        	self.num_samples = 7518
    	else:
        	print('Unknown split: %s' % (split))
        	exit(-1)

    	self.image_dir = os.path.join(self.split_dir, 'image_2')
    	self.calib_dir = os.path.join(self.split_dir, 'calib')
    	self.lidar_dir = os.path.join(self.split_dir, 'velodyne')
    	self.label_dir = os.path.join(self.split_dir, 'label_2')

	# 用于后面获取样本库内的样本总数
	def __len__(self):
    	return self.num_samples

	def get_image(self, idx):
    	assert(idx=xmin) & \
    	(pts_2d[:,1]=ymin)
	fov_inds = fov_inds & (pc_velo[:,0]>clip_distance)
	imgfov_pc_velo = pc_velo[fov_inds,:]
	if return_more:
    	return imgfov_pc_velo, pts_2d, fov_inds
	else:
    	return imgfov_pc_velo

def show_lidar_with_boxes(pc_velo, objects, calib,
                      img_fov=False, img_width=None, img_height=None): 
	''' Show all LiDAR points.
    	Draw 3d box in LiDAR point cloud (in velo coord system) '''
	if 'mlab' not in sys.modules: import mayavi.mlab as mlab
	from viz_util import draw_lidar_simple, draw_lidar, draw_gt_boxes3d

	print(('All point num: ', pc_velo.shape[0]))
	fig = mlab.figure(figure=None, bgcolor=(0,0,0),
    	fgcolor=None, engine=None, size=(1000, 500))
	if img_fov:
    	pc_velo = get_lidar_in_image_fov(pc_velo, calib, 0, 0,
        	img_width, img_height)
    	print(('FOV point num: ', pc_velo.shape[0]))
	draw_lidar(pc_velo, fig=fig)

	for obj in objects:
    	if obj.type=='DontCare':continue
    	# Draw 3d bounding box
    	box3d_pts_2d, box3d_pts_3d = utils.compute_box_3d(obj, calib.P) 
    	box3d_pts_3d_velo = calib.project_rect_to_velo(box3d_pts_3d)
    	# Draw heading arrow
    	ori3d_pts_2d, ori3d_pts_3d = utils.compute_orientation_3d(obj, calib.P)
    	ori3d_pts_3d_velo = calib.project_rect_to_velo(ori3d_pts_3d)
    	x1,y1,z1 = ori3d_pts_3d_velo[0,:]
    	x2,y2,z2 = ori3d_pts_3d_velo[1,:]
    	draw_gt_boxes3d([box3d_pts_3d_velo], fig=fig)
   	 	mlab.plot3d([x1, x2], [y1, y2], [z1,z2], color=(0.5,0.5,0.5),
        	tube_radius=None, line_width=1, figure=fig)
	mlab.show(1)

def show_lidar_on_image(pc_velo, img, calib, img_width, img_height):
	''' Project LiDAR points to image '''
	imgfov_pc_velo, pts_2d, fov_inds = get_lidar_in_image_fov(pc_velo,
    	calib, 0, 0, img_width, img_height, True)
	imgfov_pts_2d = pts_2d[fov_inds,:]
	imgfov_pc_rect = calib.project_velo_to_rect(imgfov_pc_velo)

	import matplotlib.pyplot as plt
	cmap = plt.cm.get_cmap('hsv', 256)
	cmap = np.array([cmap(i) for i in range(256)])[:,:3]*255

	for i in range(imgfov_pts_2d.shape[0]):
    	depth = imgfov_pc_rect[i,2]
    	color = cmap[int(640.0/depth),:]
    	cv2.circle(img, (int(np.round(imgfov_pts_2d[i,0])),
        	int(np.round(imgfov_pts_2d[i,1]))),
        	2, color=tuple(color), thickness=-1)
	Image.fromarray(img).show() 
	return img

def dataset_viz():
	dataset = kitti_object(os.path.join(ROOT_DIR, 'dataset/KITTI/object'))			#获取数据库各项路径

	for data_idx in range(len(dataset)):			#从0开始到len获取的数据库样本总数
    	# 从数据库中加载数据
    	objects = dataset.get_label_objects(data_idx)			#获取data_idx对应的结果
    	objects[0].print_object()			#在屏幕上输出data_idx对应的第一个结果(如果有多个,修改[]内的值就可以变成对应的结果
    	img = dataset.get_image(data_idx)			#获取data_idx对应的图片
    	img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 			#图像格式转换
    	img_height, img_width, img_channel = img.shape
    	print(('Image shape: ', img.shape))
    	pc_velo = dataset.get_lidar(data_idx)[:,0:3]			#获取data_idx对应的3D点云
    	calib = dataset.get_calibration(data_idx)

    	# 在图像上画出2d和3dboxes
    	show_image_with_boxes(img, objects, calib, False)
    	raw_input()
    	# Show all LiDAR points. Draw 3d box in LiDAR point cloud
    	show_lidar_with_boxes(pc_velo, objects, calib, True, img_width, img_height)
    	raw_input()

if __name__=='__main__':
	import mayavi.mlab as mlab
	from viz_util import draw_lidar_simple, draw_lidar, draw_gt_boxes3d
	dataset_viz()			#显示数据

你可能感兴趣的:(论文代码学习)