3D Point Cloud Deep Learning: PointNet Source Code Walkthrough (Data Preprocessing)

provider.py provides PointNet with data loading and point cloud preprocessing utilities.
Its imports are:

import os
import sys
import numpy as np
import h5py

Right after the imports, the file sets up its directory paths:

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR)

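Here os.path.abspath(__file__) returns the absolute path of the current file, for example "E:\test\provider.py", and os.path.dirname() strips the file name and keeps only the directory part, so BASE_DIR = "E:\test". BASE_DIR is then appended to sys.path so that modules in the same directory can be imported. Next, the point cloud dataset is downloaded: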
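The path logic can be checked without running the script. The sketch below is only an illustration, not part of provider.py: it calls ntpath and posixpath (the Windows and POSIX implementations behind os.path) directly, so the result is the same on any OS. The file paths are made-up examples.

```python
import ntpath
import posixpath
from pathlib import PureWindowsPath

# Windows rules: dirname drops the last component (the file name)
win_dir = ntpath.dirname(r"E:\test\provider.py")
# POSIX rules: same idea with "/" separators (hypothetical path)
posix_dir = posixpath.dirname("/home/user/pointnet/provider.py")
# pathlib equivalent of os.path.dirname: the .parent attribute
pathlib_dir = str(PureWindowsPath(r"E:\test\provider.py").parent)

print(win_dir)      # E:\test
print(posix_dir)    # /home/user/pointnet
print(pathlib_dir)  # E:\test
```

In a real script the same pattern is usually written either as os.path.dirname(os.path.abspath(__file__)) or as Path(__file__).resolve().parent.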

""" Download dataset for point cloud classification"""
DATA_DIR = os.path.join(BASE_DIR, 'data')
if not os.path.exists(DATA_DIR):
    os.mkdir(DATA_DIR)
if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')):
    www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'
    zipfile = os.path.basename(www)
    os.system('wget %s; unzip %s' % (www, zipfile))
    os.system('mv %s %s' % (zipfile[:-4], DATA_DIR))
    os.system('rm %s' % (zipfile))

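The DATA_DIR line joins BASE_DIR and 'data' to build the directory that holds the dataset. The first if checks whether that directory exists and creates it with os.mkdir if it does not. The second if checks whether 'modelnet40_ply_hdf5_2048' exists under DATA_DIR. If it does not, the archive is downloaded with wget and unzipped, the extracted folder (zipfile[:-4], i.e. the archive name without ".zip") is moved into DATA_DIR, and the zip file is deleted. Because these steps run through os.system, they require the wget, unzip, mv and rm command-line tools, i.e. a Unix-like shell.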
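On a system without those shell tools (for example a plain Windows install) the download step fails. A portable alternative that uses only the Python standard library might look like the sketch below. download_and_extract is a hypothetical helper name, and, like the original code, it assumes the archive contains a top-level folder with the same name as the zip file.

```python
import os
import tempfile
import urllib.request
import zipfile


def download_and_extract(url, data_dir):
    """Hypothetical helper: fetch a .zip and unpack it into data_dir once.

    Assumes the archive holds a top-level folder named like the zip file,
    e.g. modelnet40_ply_hdf5_2048.zip -> modelnet40_ply_hdf5_2048/.
    """
    os.makedirs(data_dir, exist_ok=True)  # like the os.mkdir branch
    archive_name = os.path.basename(url)
    folder = os.path.join(data_dir, os.path.splitext(archive_name)[0])
    if os.path.exists(folder):
        return folder  # already present, skip the download
    with tempfile.TemporaryDirectory() as tmp:
        archive_path = os.path.join(tmp, archive_name)
        urllib.request.urlretrieve(url, archive_path)  # replaces wget
        with zipfile.ZipFile(archive_path) as zf:
            zf.extractall(data_dir)  # replaces unzip + mv
    # leaving the with block deletes the temporary zip, replacing rm
    return folder
```

It would be called as download_and_extract(www, DATA_DIR) with the URL from the original code.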


In addition, the file defines the following functions:

def shuffle_data(data, labels):
    """ Shuffle data and labels.
        Input:
          data: B,N,... numpy array
          label: B,... numpy array
        Return:
          shuffled data, label and shuffle indices
    """
    idx = np.arange(len(labels))
    np.random.shuffle(idx)
    return data[idx, ...], labels[idx], idx

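shuffle_data shuffles data and labels together. It first builds the index array 0, 1, ..., len(labels)-1 with np.arange, shuffles it in place with np.random.shuffle, and then returns data and labels reindexed with this permutation, together with the permutation itself. Because both arrays are indexed with the same idx, every sample keeps its label.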
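One way to confirm that samples stay paired with their labels is a toy batch in which every coordinate of a cloud equals its label. The sketch below re-implements the function with an optional np.random.Generator argument (rng). That argument is an addition for reproducibility and is not part of provider.py.

```python
import numpy as np


def shuffle_data(data, labels, rng=None):
    """Shuffle data and labels with the same permutation (rng is an added option)."""
    rng = np.random.default_rng() if rng is None else rng
    idx = rng.permutation(len(labels))  # shuffled copy of 0..len(labels)-1
    return data[idx, ...], labels[idx], idx


# toy batch: 5 "clouds" of 4 points, every coordinate equals the cloud's label
labels = np.arange(5)
data = np.repeat(labels[:, None, None], 4, axis=1).repeat(3, axis=2).astype(np.float32)
shuffled_data, shuffled_labels, idx = shuffle_data(data, labels, np.random.default_rng(0))
# each cloud still carries its own label after shuffling
print(np.all(shuffled_data[:, 0, 0] == shuffled_labels))  # True
```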


def rotate_point_cloud(batch_data):
    """ Randomly rotate the point clouds to augument the dataset
        rotation is per shape based along up direction
        Input:
          BxNx3 array, original batch of point clouds
        Return:
          BxNx3 array, rotated batch of point clouds
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        rotation_angle = np.random.uniform() * 2 * np.pi
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array([[cosval, 0, sinval],
                                    [0, 1, 0],
                                    [-sinval, 0, cosval]])
        shape_pc = batch_data[k, ...]
        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
    return rotated_data

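The function first allocates rotated_data, an array with the same shape as batch_data, to hold the rotated clouds. The loop runs over the first dimension, batch_data.shape[0], so each $N \times 3$ point cloud in the batch is rotated separately. np.random.uniform()*2*np.pi draws an angle rotation_angle in $[0, 2\pi)$, and its cosine cosval and sine sinval are used to build rotation_matrix, a rotation by rotation_angle about the $Y$ axis (the up direction): $R_y(\theta) = \begin{bmatrix}\cos\theta & 0 & \sin\theta \\ 0 & 1 & 0 \\ -\sin\theta & 0 & \cos\theta\end{bmatrix}$. Finally, each cloud, treated as $N$ row vectors, is multiplied by the rotation matrix with np.dot, i.e. $P' = P\,R_y(\theta)$, stored in rotated_data, and returned.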
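The per-shape Python loop can also be written as one batched matrix product. The sketch below is a vectorized variant for illustration (both function names are hypothetical). It builds one $Y$-axis rotation matrix per cloud with the same layout as rotation_matrix above and applies them with np.matmul, keeping the row-vector convention of np.dot(shape_pc, rotation_matrix).

```python
import numpy as np


def rotation_matrices_y(angles):
    """Stack of Y-axis rotation matrices, shape (B, 3, 3), same layout as provider.py."""
    c, s = np.cos(angles), np.sin(angles)
    zero, one = np.zeros_like(angles), np.ones_like(angles)
    return np.stack([
        np.stack([c, zero, s], axis=-1),       # row 0: [cos, 0, sin]
        np.stack([zero, one, zero], axis=-1),  # row 1: [0, 1, 0]
        np.stack([-s, zero, c], axis=-1),      # row 2: [-sin, 0, cos]
    ], axis=-2)


def rotate_point_cloud_batched(batch_data, rng=None):
    """Rotate each cloud of a BxNx3 batch by its own random angle in [0, 2*pi)."""
    rng = np.random.default_rng() if rng is None else rng
    angles = rng.uniform(0.0, 2 * np.pi, size=batch_data.shape[0])
    # (B, N, 3) @ (B, 3, 3) -> (B, N, 3): points are row vectors, as in np.dot(shape_pc, R)
    rotated = np.matmul(batch_data, rotation_matrices_y(angles))
    return rotated.astype(np.float32), angles
```

Returning the sampled angles is not required for augmentation, but it makes the result easy to check against the loop version.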


def rotate_point_cloud_by_angle(batch_data, rotation_angle):
    """ Rotate the point cloud along up direction with certain angle.
        Input:
          BxNx3 array, original batch of point clouds
        Return:
          BxNx3 array, rotated batch of point clouds
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        #rotation_angle = np.random.uniform() * 2 * np.pi
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array([[cosval, 0, sinval],
                                    [0, 1, 0],
                                    [-sinval, 0, cosval]])
        shape_pc = batch_data[k, ...]
        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
    return rotated_data

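This function also rotates the point clouds about the up ($Y$) axis. The difference is that the angle is not drawn at random (that line is commented out) but given by the rotation_angle argument, so every cloud in the batch is rotated by the same angle.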
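Because every cloud shares one matrix here, the loop is equivalent to a single broadcasted product. The sketch below is an illustrative rewrite, not the original code (rotate_by_angle is a hypothetical name). It also shows one typical use of a fixed angle: producing several evenly spaced rotated views of the same batch, for example for test-time voting. The value num_views = 12 is a hypothetical choice.

```python
import numpy as np


def rotate_by_angle(batch_data, rotation_angle):
    """Rotate all BxNx3 clouds by the same angle about the Y (up) axis."""
    c, s = np.cos(rotation_angle), np.sin(rotation_angle)
    rotation_matrix = np.array([[c, 0, s],
                                [0, 1, 0],
                                [-s, 0, c]])
    # (B, N, 3) @ (3, 3) broadcasts over the batch, replacing the per-shape loop
    return (batch_data @ rotation_matrix).astype(np.float32)


num_views = 12  # hypothetical number of evenly spaced views
batch = np.array([[[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]], dtype=np.float32)  # B=1, N=2
views = [rotate_by_angle(batch, v / num_views * 2 * np.pi) for v in range(num_views)]
# quarter turn: (1, 0, 0) maps to (0, 0, 1); (0, 2, 0) lies on the axis and is unchanged
quarter_turn = rotate_by_angle(batch, np.pi / 2)
```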


def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    """ Randomly jitter points. jittering is per point.
        Input:
          BxNx3 array, original batch of point clouds
        Return:
          BxNx3 array, jittered batch of point clouds
    """
    B, N, C = batch_data.shape
    assert(clip > 0)
    jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip)
    jittered_data += batch_data
    return jittered_data

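This function adds independent, normally distributed random noise to every coordinate of every point. sigma*np.random.randn(B,N,C) produces Gaussian noise with standard deviation sigma (0.01 by default), np.clip() limits each noise value to $[-\mathrm{clip}, \mathrm{clip}]$ (0.05 by default), and the result is added to batch_data.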
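The clipping guarantees that no point moves by more than clip along any axis, which is easy to check. The sketch below mirrors jitter_point_cloud but takes an np.random.Generator (an added assumption for reproducibility) instead of using the global np.random state.

```python
import numpy as np


def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05, rng=None):
    """Add Gaussian noise N(0, sigma^2), clipped to [-clip, clip], to every coordinate."""
    assert clip > 0
    rng = np.random.default_rng() if rng is None else rng
    noise = np.clip(sigma * rng.standard_normal(batch_data.shape), -clip, clip)
    return batch_data + noise


batch = np.zeros((4, 1024, 3), dtype=np.float32)
jittered = jitter_point_cloud(batch, rng=np.random.default_rng(0))
print(np.abs(jittered - batch).max() <= 0.05)  # True: every offset is bounded by clip
print(np.std(jittered - batch))  # roughly sigma, since clipping at 5 sigma rarely triggers
```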


def getDataFiles(list_filename):
    # one file path per line; the with block closes the list file
    with open(list_filename) as f:
        return [line.rstrip() for line in f]

def load_h5(h5_filename):
    # open read-only; the with block closes the HDF5 file after reading
    with h5py.File(h5_filename, 'r') as f:
        data = f['data'][:]
        label = f['label'][:]
    return (data, label)

def loadDataFile(filename):
    return load_h5(filename)

def load_h5_data_label_seg(h5_filename):
    with h5py.File(h5_filename, 'r') as f:
        data = f['data'][:]
        label = f['label'][:]
        seg = f['pid'][:]  # per-point part labels for segmentation
    return (data, label, seg)


def loadDataFile_with_seg(filename):
    return load_h5_data_label_seg(filename)

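These functions read the data and labels. getDataFiles reads a list file and returns one file name per line with the trailing newline stripped. load_h5 (wrapped by loadDataFile) opens an HDF5 file and returns the point clouds stored under 'data' and the class labels stored under 'label'. load_h5_data_label_seg (wrapped by loadDataFile_with_seg) additionally returns the per-point part labels stored under 'pid', which are used for segmentation.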
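To try the loaders without downloading ModelNet40, one can write a small HDF5 file with the same dataset names ('data', 'label') and read it back. The file names, shapes and values below are made up for illustration (the folder name modelnet40_ply_hdf5_2048 suggests 2048 points per shape in the real files). The loader repeats load_h5 with an explicit read-only mode and a with block so the file is closed.

```python
import os
import tempfile

import h5py
import numpy as np


def load_h5(h5_filename):
    # read-only open; the arrays are copied out before the with block closes the file
    with h5py.File(h5_filename, "r") as f:
        return f["data"][:], f["label"][:]


def getDataFiles(list_filename):
    # one HDF5 path per line, trailing newline stripped
    with open(list_filename) as f:
        return [line.rstrip() for line in f]


tmp = tempfile.mkdtemp()
h5_path = os.path.join(tmp, "toy_train0.h5")  # hypothetical file name
with h5py.File(h5_path, "w") as f:
    f.create_dataset("data", data=np.random.rand(3, 16, 3).astype(np.float32))  # 3 shapes, 16 points
    f.create_dataset("label", data=np.array([[0], [5], [7]], dtype=np.uint8))
list_path = os.path.join(tmp, "train_files.txt")  # hypothetical list file
with open(list_path, "w") as f:
    f.write(h5_path + "\n")

for fn in getDataFiles(list_path):
    data, label = load_h5(fn)
    print(fn, data.shape, label.shape)  # shapes (3, 16, 3) and (3, 1)
```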
