从零实现一个3D目标检测算法(2):点云数据预处理

在2D目标检测中,一般不需要对图像进行预处理操作,直接输入原始图像即可得到最终的检测结果。
但是在点云3D目标检测中,往往需要对点云进行一定的预处理,本文将介绍在PointPillars模型中如何对点云进行预处理。这里的点云数据预处理操作同样也适用其它的基于Voxels的3D检测模型中

文章目录

    • 1. 模型配置文件config.py
      • 1.1将模型参数保存在日志文件
      • 1.2加载模型配置文件
      • 1.3解析终端命令修改模型配置参数
    • 2. 点云数据预处理
      • 2.1 DatasetTemplate类
      • 2.2 KittiDataset类
      • 2.3 KITTI数据加载器

1. 模型配置文件config.py

在这里我们将首先编写在整个工程中最重要的config.py文件,该文件主要包括三个函数。

作用分别是:加载模型配置文件pointpillar.yaml、将模型参数保存在日志文件中、以及解析终端命令修改模型配置参数。

关于以下三个函数,只需要会使用即可。首先导入需要的Python库:

from easydict import EasyDict
from pathlib import Path
import yaml

1.1将模型参数保存在日志文件

这一部分是将整个网络模型的全部参数保存到日志文件中,这在平时的日常工作中也是一项必须要做的事。

在开发过程中一个项目或一个模块的代码往往会修改很多次。有了日志文件,我们就能很方便地查看每次修改的地方,如果出现问题的话,也可以借助日志快速定位问题,代码如下:

def log_config_to_file(cfg, pre='cfg', logger=None):
    for key, val in cfg.items():
        if isinstance(cfg[key], EasyDict):
            logger.info('\n%s.%s = edict()' % (pre, key))
            log_config_to_file(cfg[key], pre=pre+ '.' + key, logger=logger)
            continue
        logger.info('%s.%s: %s' % (pre, key, val))

1.2加载模型配置文件

下面一个函数是从配置文件pointpillar.yaml中加载网络模型参数。

在Python中我们使用字典这种数据类型来存储网络的各种参数,只需要命名好参数名称即可,如测试集,训练集名称,网络各子模块名称,损失函数名称等,在修改时也只需要修改参数对应的变量值即可,这是一个很方便的调参方式。代码如下:

def cfg_from_yaml_file(cfg_file, config):
    with open(cfg_file, 'r') as f:
        try:
            new_config = yaml.load(f, Loader=yaml.FullLoader)
        except:
            new_config = yaml.load(f)
        config.update(EasyDict(new_config))
    
    return config

1.3解析终端命令修改模型配置参数

除了对模型配置文件.yaml进行修改外,也可以在执行时通过终端来修改模型的参数。

这时就要求程序能够获取终端信息,包括参数名称以及参数值,这通常是成对出现,代码如下:

def cfg_from_list(cfg_list, config):
    """Set config keys via list (e.g., from command line)."""
    from ast import literal_eval
    assert len(cfg_list) % 2 == 0
    for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
        key_list = k.split('.')
        d = config
        for subkey in key_list[:-1]:
            assert subkey in d, 'NotFoundKey: %s' % subkey
            d = d[subkey]
        subkey = key_list[-1]
        assert subkey in d, 'NotFoundKey: %s' % subkey
        try: 
            value = literal_eval(v)
        except:
            value = v 
        
        if type(value) != type(d[subkey]) and isinstance(d[subkey], EasyDict):
            key_val_list = value.split('.')
            for src in key_val_list:
                cur_key, cur_val = src.split(':')
                val_type = type(d[subkey][cur_key])
                cur_val = val_type(cur_val)
                d[subkey][cur_key] = cur_val
        elif type(value) != type(d[subkey]) and isinstance(d[subkey], list):
            val_list = value.split('.')
            for k, x in enumerate(val_list):
                val_list[k] = type(d[subkey][0])(x)
            d[subkey] = val_list
        else:
            assert type(value) == type(d[subkey]), \
                'type {} dose not match original type {}'.format(type(value), type(d[subkey]))
            d[subkey] = value

下面我们来定义模型参数配置变量cfg,其本身是一个字典,现在我们先定义它的根路径。

至此配置文件代码编写完毕,不妨可以调用cfg_from_yaml_file函数加载yaml文件看看模型参数加载是否正确。

cfg = EasyDict()
cfg.ROOT_DIR = (Path(__file__).resolve().parent / '../').resolve()
cfg.LOCAL_RANK = 0

2. 点云数据预处理

现在我们对KITTI数据集进行预处理,最终将其加载到PyTorchDataLoader中。

2.1 DatasetTemplate类

我们使用Python中的Class来对点云数据进行预处理,数据的预处理操作都定义为Class的成员函数。

先首先定义一个DatasetTemplate类,当做点云数据的一个基本类,后面处理其它点云数据集时可以在此基础上进行不同的操作,导入必要的Python库:

import numpy as np
from collections import defaultdict
import torch.utils.data as torch_data
import sys 
sys.path.append('../')
sys.path.append('../../')
from utils import common_utils
from config import cfg

class DatasetTemplate(torch_data.Dataset):
	def __init__(self):
        super().__init__()

DatasetTemplate中我们定义两个成员函数,一个是数据准备函数prepare_data。输入的是点云数据帧编号和原始点云数据,以字典形式传输,输出为:VoxelsVoxels坐标、每个Voxels中点个数、Voxels中心坐标、原始点云数据,同样以字典形式返回。

def prepare_data(self, input_dict):
	"""
    :param input_dict:
    	sample_idx: string
        points: (N, 3 + C1)
    :return:
        voxels: (N, max_points_of_each_voxel, 3 + C2), float
        num_points: (N), int
        coordinates: (N, 3), [idx_z, idx_y, idx_x]
        voxel_centers: (N, 3)
        points: (M, 3 + C)
    """
    sample_idx = input_dict['sample_idx']
    points = input_dict['points']
    points = points[:, :cfg.DATA_CONFIG.NUM_POINT_FEATURES['use']]     

    # voxels, coordinates, num_points
    voxels, coordinates, num_points = self.voxel_generator.generate(points, \
                                      max_voxels=cfg.DATA_CONFIG[self.mode].MAX_NUMBER_OF_VOXELS)    
       
    # voxel_centers
    voxel_centers = (coordinates[:, ::-1] + 0.5) * self.voxel_generator.voxel_size \
                    + self.voxel_generator.point_cloud_range[0:3]
    print('voxel_centers.shape is: ', voxel_centers.shape)       # (11719, 3)
    if cfg.DATA_CONFIG.MASK_POINTS_BY_RANGE:
        points = common_utils.mask_points_by_range(points, cfg.DATA_CONFIG.POINT_CLOUD_RANGE)
    
    example = {
     }

    example.update({
     'voxels': voxels,
                    'num_points': num_points,
                    'coordinates': coordinates,
                    'voxel_centers': voxel_centers,
                    'points': points})
    
    return example
 
    

另一个函数是collate_batch,作用是加载数据集是如何选取数据。

@staticmethod
def collate_batch(batch_list, _unused=False):
	example_merged = defaultdict(list)
    for example in batch_list:
        for k, v in example.items():
            example_merged[k].append(v)
    ret = {
     }
    for key, elems in example_merged.items():
        if key in ['voxels', 'num_points', 'voxel_centers', 'seg_labels', 'part_labels', 'bbox_reg_labels']:
            ret[key] = np.concatenate(elems, axis=0)
        elif key in ['coordinates', 'points']:
            coors = []
            for i, coor in enumerate(elems):
                coor_pad = np.pad(coor, ((0, 0), (1, 0)), mode='constant', constant_values=i)
                coors.append(coor_pad)
            ret[key] = np.concatenate(coors, axis=0)
        elif key in ['gt_boxes']:
            max_gt = 0
            batch_size = elems.__len__()
            for k in range(batch_size):
                max_gt = max(max_gt, elems[k].__len__())
            batch_gt_boxes3d = np.zeros((batch_size, max_gt, elems[0].shape[-1]), dtype=np.float32)
            for k in range(batch_size):
                batch_gt_boxes3d[k, :elems[k].__len__(), :] = elems[k]
            ret[key] = batch_gt_boxes3d
        else:
            ret[key] = np.stack(elems, axis=0)
    ret['batch_size'] = batch_list.__len__()
    return ret

2.2 KittiDataset类

现在我们编写kitti_dataset.py,主要目的是创造KittiDataset类,首先是导入所需库:

import os
import sys
import pickle 
import copy
import numpy as np
from pathlib import Path 
import torch 
import sys 
sys.path.append('../')
sys.path.append('../../')
# from utils import common_utils
from config import cfg 
from spconv.utils import VoxelGenerator
from ..dataset import DatasetTemplate

在这里我们首先定义一个BaseKittiDataset类,这里初始化只有一个参数,就是点云数据的存储路径root_path

class BaseKittiDataset(DatasetTemplate):
    def __init__(self, root_path):
        super().__init__()
        self.root_path = root_path

现在我们编写获取点云数据的get_lidar函数,KITTI中点云数据是以二进制格式保存的,每个点有4个信息: ( x , y , z , r ) (x,y,z,r) (x,y,z,r),数据类型为float32,代码如下:

def get_lidar(self, idx):
	lidar_file = os.path.join(self.root_path, 'velodyne', '%06d.bin' % idx)
    assert os.path.exists(lidar_file)
    return np.fromfile(lidar_file, dtype=np.float32).reshape([-1, 4])      

此外我们也可以编写函数get_infos来获取点云信息,具体为:

def get_infos(self, idx):
	import concurrent.futures as futures

    info = {
     }
    pc_info = {
     'num_features':4, 'lidar_idx': idx}
    info['point_cloud'] = pc_info
    return info

这里有一个生成最终预测结果的函数,因为模型计算时使用的是GPU,而要保存时需要转化为CPU可访问的数据。

预测信息有box尺寸box3d_lidar,分值scores,目标类型标签label_preds,以及点云编号sample_idx

@staticmethod
def generate_prediction_dict(input_dict, index, record_dict):
	# finally generate predictions.
    sample_idx = input_dict['sample_idx'][index] if 'sample_idx' in input_dict else -1
    boxes3d_lidar_preds = record_dict['boxes'].cpu().numpy()

    if boxes3d_lidar_preds.shape[0] == 0:
    	return {
     'sample_idx': sample_idx}
     
    predictions_dict ={
     
        'box3d_lidar': boxes3d_lidar_preds,
        'scores': record_dict['scores'].cpu.numpy(),
        'label_preds': record_dict['labels'].cpu().numpy(),
        'sample_idx': sample_idx
    }

    return predictions_dict

现在我们就可以创建KittiDataset类了,同样初始化时需要设置数据路径,我们需要将模式设置为TEST

class KittiDataset(BaseKittiDataset):
    def __init__(self, root_path, logger=None):
        super().__init__(root_path=root_path)

        self.logger = logger
        self.mode = 'TEST'
        self.kitti_infos = []
        self.include_kitti_data(self.mode, logger)
        self.dataset_init(logger)

在初始化时,有一个dataset_init函数,这个函数是用来生成voxel_generator的,使用的库为Spconv,在prepare_data函数中会使用这个voxel_generator,代码如下:

def dataset_init(self, logger):
	voxel_generator_cfg = cfg.DATA_CONFIG.VOXEL_GENERATOR
        
    self.voxel_generator = VoxelGenerator(voxel_size=voxel_generator_cfg.VOXEL_SIZE,
                                          point_cloud_range=cfg.DATA_CONFIG.POINT_CLOUD_RANGE,
                                          max_num_points=voxel_generator_cfg.MAX_POINTS_PER_VOXEL)

include_kitti_data函数是用来加载pkl文件的,我们会将待处理的点云信息存储在pkl文件中,这样测试模型时只需使用这一个文件就可以访问全部点云数据了:

def include_kitti_data(self, mode, logger):
	if cfg.LOCAL_RANK == 0 and logger is not None:
		logger.info('Loading KITTI dataset')
        kitti_infos = []

        for info_path in cfg.DATA_CONFIG[mode].INFO_PATH:        
            info_path = cfg.ROOT_DIR / info_path
            with open(info_path, 'rb') as f:
                infos = pickle.load(f)
                kitti_infos.append(infos)
        self.kitti_infos.extend(kitti_infos)

        if cfg.LOCAL_RANK == 0 and logger is not None:
            logger.info('Total samples for KITTI dataset: %d' % (len(kitti_infos)))

此外我们也可以对点云进行筛选,下面的代码为选取 x x x [ 0 , 70.4 ] [0, 70.4] [0,70.4] y y y [ − 40 , 40 ] [-40, 40] [40,40] z z z [ − 3 , 1 ] [-3, 1] [3,1]范围的点,这个一般要根据具体应用场景来设置。

@staticmethod
def get_valid_flag(pts_lidar):
	'''
	 Valid points should be in the PC_AREA_SCOPE
	 '''
	 val_flag_x = np.logical_and(pts_lidar[:, 0]>=0, pts_lidar[:, 0]<=70.4)
	 val_flag_y = np.logical_and(pts_lidar[:, 1]>=-40, pts_lidar[:, 1]<=40)
	 val_flag_z = np.logical_and(pts_lidar[:, 2]>=-3, pts_lidar[:, 2]<=1)
	 val_flag_merge = np.logical_and(val_flag_x, val_flag_y, val_flag_z)
	 pts_valid_flag = val_flag_merge
	 return pts_valid_flag

最后,就是编写__getitem__函数

def __len__(self):
	return len(self.kitti_infos)
    
def __getitem__(self, index):
    info = copy.deepcopy(self.kitti_infos[index])
    sample_idx = info['point_cloud']['lidar_idx']
       
    points = self.get_lidar(sample_idx)
    pts_valid_flag = self.get_valid_flag(points[:, 0:3])
    points = points[pts_valid_flag]
    input_dict = {
     'points': points, 'sample_idx': sample_idx}
    example = self.prepare_data(input_dict=input_dict)
    example['sample_idx'] = sample_idx
          
    return example

最后编写main函数,函数主要作用是获取终端信息,生成kitti_infos

if __name__=='__main__':
    if sys.argv.__len__() > 1 and sys.argv[1] == 'create_kitti_infos':
        create_kitti_infos(data_path=cfg.ROOT_DIR / 'data',
                           save_path=cfg.ROOT_DIR / 'data')

生成后的kitti_infos如下:

{
     'point_cloud': {
     'num_features': 4, 'lidar_idx': '000010'}}

2.3 KITTI数据加载器

这里的作用是通过DataLoader加载点云数据,这在PyTorch是十分常见的,代码如下:

import os
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from .kitti.kitti_dataset import KittiDataset, BaseKittiDataset
from config import cfg

__all__ = {
     'BaseKittiDataset': BaseKittiDataset,
           'KittiDataset': KittiDataset}

def build_dataloader(data_dir, batch_size, logger=None):
    data_dir = Path(data_dir) if os.path.isabs(data_dir) else cfg.ROOT_DIR / data_dir

    dataset = __all__[cfg.DATA_CONFIG.DATASET](root_path=data_dir, logger=logger)

    dataloader = DataLoader(dataset, batch_size=batch_size, pin_memory=True, 
                            shuffle=False, collate_fn=dataset.collate_batch, drop_last=False)

    return dataset, dataloader

至此,点云数据预处理部分我们就已经完成了,预处理后的点云数据将变成如下形式:原始pointsvoxels及其坐标,中心位置,点云数量,点云编号等。下一篇文章中我们将开始实现PointPillars的网络部分。

input_dict`:{
     `voxels`, `num_points`, `coordinates`, `voxel_centers` ,  `points`, `sample_idx`,  `batch_size`}

你可能感兴趣的:(无人驾驶汽车进阶)