July 19, 2018 16:04:48 Vodake
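The provider.py file mainly provides data loading and point cloud preprocessing (data augmentation) utilities for PointNet.
Its imports are as follows: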
import os
import sys
import numpy as np
import h5py
Right after the imports, the file sets up the data directory:
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR)
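Here, os.path.abspath(__file__) returns the absolute path of the current file, e.g. "E:\test\provider.py", and os.path.dirname() strips the file name from that path, leaving BASE_DIR = "E:\test". Finally, BASE_DIR is appended to sys.path (Python's module search path) so that modules in the same directory can be imported. Next, the point cloud dataset is downloaded with the following code: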
""" Download dataset for point cloud classification"""
DATA_DIR = os.path.join(BASE_DIR, 'data')
if not os.path.exists(DATA_DIR):
    os.mkdir(DATA_DIR)
if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')):
    www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'
    zipfile = os.path.basename(www)  # 'modelnet40_ply_hdf5_2048.zip'
    os.system('wget %s; unzip %s' % (www, zipfile))  # download, then extract into the current directory
    os.system('mv %s %s' % (zipfile[:-4], DATA_DIR))  # move the extracted folder into DATA_DIR
    os.system('rm %s' % (zipfile))  # delete the downloaded archive
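The first line joins BASE_DIR and 'data' into the data directory path DATA_DIR. The first if checks whether this path exists and, if not, creates it with os.mkdir. The second if checks whether the folder 'modelnet40_ply_hdf5_2048' exists under DATA_DIR; if it does not, the code downloads the zip archive with wget, extracts it with unzip, moves the extracted folder (the archive name without its .zip suffix, zipfile[:-4]) into DATA_DIR, and finally deletes the archive with rm.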
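The download step above shells out to wget, unzip, mv and rm, so it only works where those tools are installed (a stock Windows machine, like the E:\test example, usually lacks them). Below is a minimal sketch of the same logic using only the Python standard library (urllib.request and zipfile). The function name ensure_dataset is hypothetical, and, like the original zipfile[:-4] trick, it assumes the archive contains a top-level folder named after the archive.

```python
import os
import tempfile
import urllib.request
import zipfile


def ensure_dataset(data_dir, url):
    """Hypothetical helper: download and extract a zip archive into data_dir
    unless the folder it contains is already there."""
    os.makedirs(data_dir, exist_ok=True)               # like the first `if ... os.mkdir`
    archive_name = os.path.basename(url)               # e.g. 'modelnet40_ply_hdf5_2048.zip'
    folder_name = os.path.splitext(archive_name)[0]    # same as archive_name[:-4]
    target = os.path.join(data_dir, folder_name)
    if os.path.exists(target):                         # like the second `if`
        return target
    # Download into a temporary directory that is deleted afterwards (replaces `rm`)
    with tempfile.TemporaryDirectory() as tmp:
        archive_path = os.path.join(tmp, archive_name)
        urllib.request.urlretrieve(url, archive_path)  # replaces `wget`
        with zipfile.ZipFile(archive_path) as zf:
            zf.extractall(data_dir)                    # replaces `unzip` + `mv`
    return target
```

Extracting straight into data_dir also avoids depending on the current working directory, which the original wget/unzip/mv sequence silently does.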
In addition, the file defines the following functions:
def shuffle_data(data, labels):
    """ Shuffle data and labels.
        Input:
          data: B,N,... numpy array
          label: B,... numpy array
        Return:
          shuffled data, label and shuffle indices
    """
    idx = np.arange(len(labels))
    np.random.shuffle(idx)
    return data[idx, ...], labels[idx], idx
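shuffle_data shuffles data and labels together. It first builds the index array 0, 1, ..., len(labels)-1 with np.arange, then shuffles this array in place with np.random.shuffle. Finally, it returns data and labels reindexed with the shuffled indices, together with the index array itself. Because the same permutation is applied to both arrays, every point cloud stays paired with its label.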
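A quick check of the pairing property: after shuffling, each cloud still carries its own label, and the returned idx can undo the shuffle because np.argsort(idx) is the inverse permutation. The toy arrays are made up for illustration, and shuffle_data is copied from above so the snippet runs on its own.

```python
import numpy as np


def shuffle_data(data, labels):
    # Copied from provider.py: one permutation indexes both arrays
    idx = np.arange(len(labels))
    np.random.shuffle(idx)
    return data[idx, ...], labels[idx], idx


np.random.seed(0)  # fixed seed only to make the demo reproducible
# Toy batch: B=5 clouds of N=4 points each; cloud i has label i*10
data = np.arange(5 * 4 * 3, dtype=np.float32).reshape(5, 4, 3)
labels = np.arange(5) * 10

s_data, s_labels, idx = shuffle_data(data, labels)

# Each shuffled cloud is still the one its label points to
for cloud, label in zip(s_data, s_labels):
    assert np.array_equal(cloud, data[label // 10])

# np.argsort(idx) inverts the permutation and restores the original order
inverse = np.argsort(idx)
restored = s_data[inverse]
```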
def rotate_point_cloud(batch_data):
    """ Randomly rotate the point clouds to augment the dataset
        rotation is per shape based along up direction
        Input:
          BxNx3 array, original batch of point clouds
        Return:
          BxNx3 array, rotated batch of point clouds
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        rotation_angle = np.random.uniform() * 2 * np.pi
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array([[cosval, 0, sinval],
                                    [0, 1, 0],
                                    [-sinval, 0, cosval]])
        shape_pc = batch_data[k, ...]
        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
    return rotated_data
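First, rotated_data, an array with the same shape as batch_data, is allocated to hold the rotated point clouds. The loop runs over the first dimension of batch_data, batch_data.shape[0], i.e. each N×3 point cloud in the batch is rotated separately, with a fresh random angle per cloud. np.random.uniform()*2*np.pi draws an angle rotation_angle in [0, 2π); its cosine cosval and sine sinval are then used to build rotation_matrix, the rotation about the Y (up) axis by rotation_angle. Finally, each point cloud is multiplied by the rotation matrix with np.dot (a matrix product, with each point as a row vector), and the result is stored in rotated_data, which is returned.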
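Written out explicitly, with each point stored as a row vector p = (x, y, z), the code computes p' = pR:

```latex
R = \begin{pmatrix} \cos\theta & 0 & \sin\theta \\ 0 & 1 & 0 \\ -\sin\theta & 0 & \cos\theta \end{pmatrix},
\qquad
p' = pR = \left( x\cos\theta - z\sin\theta,\; y,\; x\sin\theta + z\cos\theta \right)
```

Two consequences follow from this form. The y coordinate is untouched, and since R is orthogonal, every point keeps its distance to the origin. Also, (pR)^T = R^T p^T and R^T is the column-vector rotation by −θ, so in the usual right-handed, column-vector convention this is a rotation by −θ about Y. For the random augmentation this makes no difference, because −θ (mod 2π) is just as uniformly distributed over [0, 2π) as θ.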
def rotate_point_cloud_by_angle(batch_data, rotation_angle):
    """ Rotate the point cloud along up direction with certain angle.
        Input:
          BxNx3 array, original batch of point clouds
        Return:
          BxNx3 array, rotated batch of point clouds
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        #rotation_angle = np.random.uniform() * 2 * np.pi
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array([[cosval, 0, sinval],
                                    [0, 1, 0],
                                    [-sinval, 0, cosval]])
        shape_pc = batch_data[k, ...]
        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
    return rotated_data
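This function also rotates point clouds; the difference is that the angle is not drawn at random but given by the rotation_angle argument (in radians), and the same angle is applied to every cloud in the batch.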
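Because rotate_point_cloud_by_angle uses one matrix for the whole batch, the per-cloud loop is not strictly necessary: a single matmul over the (B, N, 3) array gives the same result. The sketch below shows this loop-free variant (the name rotate_by_angle_vectorized is hypothetical, not part of provider.py) and checks the properties discussed above: Y is preserved, norms are preserved, and with the row-vector convention the point (1, 0, 0) rotated by +π/2 lands on (0, 0, 1).

```python
import numpy as np


def rotate_by_angle_vectorized(batch_data, rotation_angle):
    """Hypothetical loop-free equivalent of rotate_point_cloud_by_angle."""
    c, s = np.cos(rotation_angle), np.sin(rotation_angle)
    # Same matrix as in provider.py: rotation about the Y (up) axis
    rotation_matrix = np.array([[c, 0, s],
                                [0, 1, 0],
                                [-s, 0, c]])
    # (B, N, 3) @ (3, 3) broadcasts over the batch axis; points are row vectors.
    # Cast to float32 to match the dtype of the original rotated_data.
    return (batch_data @ rotation_matrix).astype(np.float32)


batch = np.random.rand(2, 5, 3).astype(np.float32)  # toy batch: B=2, N=5
rotated = rotate_by_angle_vectorized(batch, np.pi / 2)

# A single point on the X axis, used to check the rotation direction
unit_x = np.array([[[1.0, 0.0, 0.0]]])
rotated_unit_x = rotate_by_angle_vectorized(unit_x, np.pi / 2)

# Rotating back by the opposite angle should recover the input
restored = rotate_by_angle_vectorized(rotated, -np.pi / 2)
```

Design note: the loop in provider.py is natural for rotate_point_cloud, where every cloud gets its own random angle; with a fixed angle the batched matmul does the same work in one call.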
def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    """ Randomly jitter points. jittering is per point.
        Input:
          BxNx3 array, original batch of point clouds
        Return:
          BxNx3 array, jittered batch of point clouds
    """
    B, N, C = batch_data.shape
    assert(clip > 0)
    jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip)
    jittered_data += batch_data
    return jittered_data
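This function adds independent random Gaussian noise to every coordinate of every point: sigma*np.random.randn(B,N,C) draws zero-mean noise with standard deviation sigma, np.clip() limits it to [−clip, clip], and the clipped noise is then added to batch_data.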
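A small experiment makes the two parameters concrete. With the defaults, clip = 0.05 is 5·sigma, so clipping almost never triggers (a Gaussian exceeds 5σ in absolute value with probability of roughly 6e-7); it only matters when sigma is raised. The snippet also shows a side effect that is easy to miss: np.random.randn returns float64, so the jittered batch comes back as float64 even for a float32 input. The function body is the same computation as in provider.py; the toy zero batch is made up for illustration.

```python
import numpy as np


def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    # Same computation as provider.jitter_point_cloud:
    # per-coordinate noise N(0, sigma^2), clipped to [-clip, clip], added to the input
    B, N, C = batch_data.shape
    assert clip > 0
    noise = np.clip(sigma * np.random.randn(B, N, C), -clip, clip)
    return batch_data + noise


np.random.seed(0)  # fixed seed only to make the demo reproducible
batch = np.zeros((4, 1024, 3), dtype=np.float32)  # toy batch: all points at the origin

# Default settings: offsets look like plain Gaussian noise with std ~ sigma
offset = jitter_point_cloud(batch) - batch

# Large sigma: most offsets are pushed to the +/-clip boundary
offset_big = jitter_point_cloud(batch, sigma=1.0, clip=0.05) - batch
```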
def getDataFiles(list_filename):
    return [line.rstrip() for line in open(list_filename)]

def load_h5(h5_filename):
    f = h5py.File(h5_filename)
    data = f['data'][:]
    label = f['label'][:]
    return (data, label)

def loadDataFile(filename):
    return load_h5(filename)

def load_h5_data_label_seg(h5_filename):
    f = h5py.File(h5_filename)
    data = f['data'][:]
    label = f['label'][:]
    seg = f['pid'][:]
    return (data, label, seg)

def loadDataFile_with_seg(filename):
    return load_h5_data_label_seg(filename)
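These functions read the data and labels. getDataFiles reads a text file that lists one file path per line (rstrip removes the trailing newline); load_h5 reads the 'data' and 'label' datasets from an HDF5 file; load_h5_data_label_seg additionally reads the 'pid' dataset and returns it as the segmentation labels seg; loadDataFile and loadDataFile_with_seg are thin wrappers around them.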
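The loaders above never close the HDF5 file and call h5py.File without a mode. In h5py 3.x the default mode is read-only 'r', while older 2.x releases defaulted to append mode, so passing 'r' explicitly gives the same behavior on both. Below is a minimal sketch of a safer variant (the name load_h5_data_label_seg_safe is hypothetical) that opens the file read-only inside a with block, so it is closed as soon as the arrays are copied into memory. The toy file reuses the dataset names from provider.py ('data', 'label', 'pid'), but its shapes are made up and do not describe the real datasets.

```python
import os
import tempfile

import h5py
import numpy as np


def load_h5_data_label_seg_safe(h5_filename):
    """Hypothetical variant of load_h5_data_label_seg that closes the file."""
    with h5py.File(h5_filename, 'r') as f:  # read-only; closed when the block exits
        data = f['data'][:]    # point clouds
        label = f['label'][:]  # per-shape labels
        seg = f['pid'][:]      # segmentation labels, returned as seg
    return data, label, seg


# Build a tiny toy file with the same dataset names, for illustration only
path = os.path.join(tempfile.mkdtemp(), 'toy.h5')
with h5py.File(path, 'w') as f:
    f.create_dataset('data', data=np.random.rand(2, 16, 3).astype(np.float32))
    f.create_dataset('label', data=np.array([[0], [1]], dtype=np.uint8))
    f.create_dataset('pid', data=np.zeros((2, 16), dtype=np.uint8))

data, label, seg = load_h5_data_label_seg_safe(path)
```

The same idea applies to getDataFiles, where wrapping open(list_filename) in a with block closes the list file after reading.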