OpenPCDet 自定义数据集训练

目录

0、目标:

1、数据的预处理

2、修改数据处理部分的代码

2.1 复制对数据集进行处理的文件

2.2 对kitti_lidar_dataset.py进行修改

2.2.1 头文件修改

2.2.2 数据集对象名称修改

2.2.3 get_infos函数修改

2.2.4 .yaml文件修改

2.2.5 运行

3、修改数据集加载

3.1 去掉测试

3.2 修改__getitem__函数

 3.3 前后连起来

3.4 .yaml文件修改

3.5运行


·本文还存在错误,对点云并未进行坐标转化,选择性阅读

0、目标:

本文立足于pointpillars算法的训练,这里通过处理kitti数据集展示对自定义数据集的训练方法。

在源代码中对pointpillars的训练需要很多的数据(不晓得咋直接训练可以进入这篇博客OpenPCDet 在KITTI 训练PointPillar_辉e的博客-CSDN博客_openpcdet训练kitti)这里尤其是calib,我们对点云进行目标检测的训练,不需要啥坐标转换的,所以我这里想去除这个文件夹,只依靠velodyne和label来进行训练

1、数据的预处理

这里写了一个代码进行数据的预处理,其目的主要是对label的第12、13、14位进行处理,因为kitti数据集中这个标注的意思是在相机坐标系下其标注框的位置(x ,y ,z),而我们在使用过程中需要获得雷达坐标系下的标注,所以在这里进行预先的转化。

1、该代码写在tools文件夹中,kitti数据集在data文件夹中

2、运行下面的py文件会建立一个文件夹data/kitti/training/new_label_2,并将处理过然后产生的txt文件放入其中

3、运行完代码后将new_label_2名字改为label_2(原谅我是懒蛋,如果不改这个地方,会有很多其他地方要改)

import numpy as np
from pathlib import Path
import os

def get_calib_from_file(calib_file):
    """Read the calibration matrices from a KITTI-style calib text file.

    Rows are selected by position: the 3rd to 6th lines of the file are used
    as P2, P3, R0 and Tr_velo_to_cam, in that order. On each of these lines the
    first space-separated token (the row label) is dropped and the remaining
    tokens are parsed as float32 values.

    :param calib_file: path of the calibration text file
    :return: dict with 'P2' (3, 4), 'P3' (3, 4), 'R0' (3, 3) and
             'Tr_velo2cam' (3, 4) numpy arrays
    """
    with open(calib_file) as handle:
        rows = handle.readlines()

    def parse_row(row_no):
        # Skip the leading label token and convert the rest to float32.
        return np.array(rows[row_no].strip().split(' ')[1:], dtype=np.float32)

    # All four rows are parsed before any of them is reshaped.
    p2, p3, r0, velo_to_cam = [parse_row(row_no) for row_no in (2, 3, 4, 5)]

    return {
        'P2': p2.reshape(3, 4),
        'P3': p3.reshape(3, 4),
        'R0': r0.reshape(3, 3),
        'Tr_velo2cam': velo_to_cam.reshape(3, 4),
    }


class Calibration(object):
    """KITTI calibration holder that maps rectified-camera points to LiDAR.

    :param calib_file: either a dict that already holds the 'P2', 'R0' and
        'Tr_velo2cam' arrays, or any other value, which is passed to
        get_calib_from_file() to load such a dict
    """

    def __init__(self, calib_file):
        calib = calib_file if isinstance(calib_file, dict) else get_calib_from_file(calib_file)

        self.P2 = calib['P2']  # 3 x 4
        self.R0 = calib['R0']  # 3 x 3
        self.V2C = calib['Tr_velo2cam']  # 3 x 4

        # Camera intrinsics and extrinsics, read off the entries of P2
        self.fu = self.P2[0, 0]
        self.fv = self.P2[1, 1]
        self.cu = self.P2[0, 2]
        self.cv = self.P2[1, 2]
        self.tx = self.P2[0, 3] / (-self.fu)
        self.ty = self.P2[1, 3] / (-self.fv)

    @staticmethod
    def _with_unit_bottom_row(mat_3x4):
        """Stack the row [0, 0, 0, 1] under a 3 x 4 matrix, giving a 4 x 4 one."""
        mat_4x4 = np.vstack((mat_3x4, np.zeros((1, 4), dtype=np.float32)))
        mat_4x4[3, 3] = 1
        return mat_4x4

    def cart_to_hom(self, pts):
        """
        Append a column of ones to the given points.

        :param pts: (N, 3 or 2)
        :return pts_hom: (N, 4 or 3)
        """
        ones_column = np.ones((pts.shape[0], 1), dtype=np.float32)
        return np.hstack((pts, ones_column))

    def rect_to_lidar(self, pts_rect):
        """
        Convert points from the rectified camera frame to the LiDAR frame.

        The forward relation is rect = R0 * Tr_velo_to_cam * lidar with both
        matrices extended to 4 x 4; this method applies the inverse of that
        product to the homogeneous input points.

        :param pts_rect: (N, 3)
        :return pts_lidar: (N, 3)
        """
        pts_rect_hom = self.cart_to_hom(pts_rect)  # (N, 4)

        # R0 is 3 x 3, so it first gets a zero column, then the unit bottom row.
        r0_3x4 = np.hstack((self.R0, np.zeros((3, 1), dtype=np.float32)))
        r0_ext = self._with_unit_bottom_row(r0_3x4)  # (4, 4)
        v2c_ext = self._with_unit_bottom_row(self.V2C)  # (4, 4)

        lidar_to_rect = np.dot(r0_ext, v2c_ext)
        # Points are row vectors, hence the transpose before inverting.
        rect_to_lidar_t = np.linalg.inv(lidar_to_rect.T)
        pts_lidar = np.dot(pts_rect_hom, rect_to_lidar_t)
        return pts_lidar[:, 0:3]

class Object3d(object):
    """One line of a label file, split into the three parts the script uses.

    Attributes set by the constructor:
        top:  array with the first 11 space-separated fields, kept as text
        loc:  float32 array of shape (3,) parsed from the 12th-14th fields
        last: one-element array holding the 15th field as read
    """

    def __init__(self, line):
        fields = line.strip().split(' ')
        # Appending to an empty array keeps the fields as strings in a 1-D array.
        self.top = np.append(np.array([]), [fields[k] for k in range(0, 11)])
        xyz = [float(fields[k]) for k in (11, 12, 13)]
        self.loc = np.array(xyz, dtype=np.float32)
        self.last = np.array([fields[14]])


def get_calib(root_split_path, idx):
    """Load the calibration of one frame from ``<root_split_path>/calib/<idx>.txt``.

    :param root_split_path: path object of the split folder (joined with ``/``)
    :param idx: frame name without the ``.txt`` suffix
    :return: Calibration built from that file
    :raises AssertionError: when the calibration file does not exist
    """
    calib_dir = root_split_path / 'calib'
    calib_path = calib_dir / ('%s.txt' % idx)
    assert calib_path.exists()
    return Calibration(calib_path)

def get_objects_from_label(label_file):
    """Parse every annotation line of a label file into an Object3d.

    Empty and whitespace-only lines (for example a trailing blank line) are
    skipped: Object3d indexes the split fields directly and would raise
    IndexError on such a line.

    :param label_file: path of the label text file
    :return: list of Object3d, one per non-blank line, in file order
    """
    with open(label_file, 'r') as f:
        return [Object3d(line) for line in f if line.strip()]

def get_label(root_split_path, idx):
    """Load the labelled objects of one frame from ``<root_split_path>/label_2/<idx>.txt``.

    :param root_split_path: path object of the split folder (joined with ``/``)
    :param idx: frame name without the ``.txt`` suffix
    :return: whatever get_objects_from_label() returns for that file
    :raises AssertionError: when the label file does not exist
    """
    label_dir = root_split_path / 'label_2'
    label_path = label_dir / ('%s.txt' % idx)
    assert label_path.exists()
    return get_objects_from_label(label_path)

def write_new_libel(root_split_path, idx, save_num):
    """Append one label line to ``<root_split_path>/new_label_2/<idx>.txt``.

    The entries of ``save_num`` are converted with ``str()``, joined by single
    spaces and terminated with ``"\\r\\n"``.

    :param root_split_path: path object of the split folder (joined with ``/``)
    :param idx: frame name without the ``.txt`` suffix
    :param save_num: 1-D array with the fields of the line
    :raises IndexError: when ``save_num`` has no entries
    """
    if save_num.shape[0] == 0:
        # An empty row would only add a blank line to the label file.
        raise IndexError('save_num is empty, there is no label line to write')

    new_libel_file = root_split_path / 'new_label_2' / ('%s.txt' % idx)
    row = ' '.join(str(save_num[i]) for i in range(save_num.shape[0]))
    # newline='' switches off newline translation, so the terminator is written
    # as exactly "\r\n" on every platform (a translating text mode would turn
    # it into "\r\r\n" on Windows).
    with open(new_libel_file, "a", newline='') as f:
        f.write(row + '\r\n')

# Strip the line terminator at the very end of the file.
def del_n(root_split_path, idx):
    """Remove the final line terminator of ``<root_split_path>/new_label_2/<idx>.txt``.

    A trailing ``"\\r\\n"`` (the terminator written by write_new_libel) or a
    trailing ``"\\n"`` is cut off. A file that is empty or does not end with a
    line terminator is left unchanged, so no label data is ever truncated and
    calling the function twice is harmless.

    :param root_split_path: path object of the split folder (joined with ``/``)
    :param idx: frame name without the ``.txt`` suffix
    """
    new_libel_file = root_split_path / 'new_label_2' / ('%s.txt' % idx)
    # The with-block closes the handle even if seeking or truncating fails.
    with open(new_libel_file, "rb+") as file_object:
        size = file_object.seek(0, 2)  # offset of the end of the file
        for terminator in (b'\r\n', b'\n'):
            width = len(terminator)
            if size < width:
                continue
            file_object.seek(size - width)
            if file_object.read(width) == terminator:
                file_object.truncate(size - width)
                break

def get_allfile(path):
    """List the entries of a directory, sorted, with the extension removed.

    :param path: directory to list
    :return: list with the name of every entry (sorted by its full name), each
             one cut before its last extension
    """
    return [os.path.splitext(entry)[0] for entry in sorted(os.listdir(path))]
    
def clean_file(root_split_path, idx):
    """Leave ``<root_split_path>/new_label_2/<idx>.txt`` as an empty file.

    Opening in "w" mode creates the file when it is missing and discards its
    content when it already exists.

    :param root_split_path: path object of the split folder (joined with ``/``)
    :param idx: frame name without the ``.txt`` suffix
    """
    target = root_split_path / 'new_label_2' / ('%s.txt' % idx)
    with open(target, "w"):
        pass

def mkdir_new_label_2(root_split_path):
    """Create the ``new_label_2`` folder under ``root_split_path`` if it is missing.

    A notice is printed only when the folder is actually created; an existing
    folder is left as it is.

    :param root_split_path: path object of the split folder (joined with ``/``)
    """
    target_dir = root_split_path / 'new_label_2'
    if os.path.exists(target_dir):
        return
    print("-------mkdir%s-------" % target_dir)
    os.mkdir(target_dir)

# Conversion script: for every label file, rewrite the object location from the
# camera frame to the LiDAR frame and store the result in new_label_2/.
# The path is relative, so it is resolved against the current working directory.
root_split_path=Path('../data/kitti/training')

mkdir_new_label_2(root_split_path)
all_file=get_allfile(root_split_path/'label_2')  # names of all label files, without extension
print("-------All name loaded-------")
#print(all_file)

for file_idx in all_file:
    # Start from an empty output file: write_new_libel() appends, so output left
    # by an earlier run would otherwise be duplicated.
    clean_file(root_split_path,file_idx)
    print("This is the %s.txt"%file_idx)
    calib=get_calib(root_split_path,file_idx)
    obj_list=get_label(root_split_path,file_idx)
    annotations = {}
    for obj in obj_list:
        # rect_to_lidar() takes an (N, 3) array, so the single location is shaped (1, 3).
        annotations['location'] = np.concatenate([obj.loc.reshape(1, 3)], axis=0)
        #print(annotations['location'])
        loc_lidar = calib.rect_to_lidar(annotations['location'])
        # Back to a flat (3,) array so it can be joined with the other fields.
        loc_lidar=loc_lidar.reshape(-1)
        #print("top",obj.top[0])
        # New line = first 11 fields unchanged + converted location + last field unchanged.
        temp=np.concatenate([obj.top,loc_lidar,obj.last],axis=0)
        #print("concatenate",temp)
        write_new_libel(root_split_path, file_idx, temp)
    # Stripping the final line terminator is left disabled.
    #del_n(root_split_path, file_idx)

2、修改数据处理部分的代码

OpenPCDet中首先对数据进行了一波预处理,我们仿照着写一下,这一步主要是对pcdet/datasets这个文件夹进行处理

2.1 复制对数据集进行处理的文件

把pcdet/datasets/kitti文件夹复制并改名为pcdet/datasets/kitti_lidar,然后把pcdet/utils/object3d_kitti.py复制为pcdet/utils/object3d_kitti_lidar.py

2.2 对kitti_lidar_dataset.py进行修改

pcdet/datasets/kitti_lidar/kitti_lidar_dataset.py

2.2.1 头文件修改

这一行修改最后的object3d_kitti为object3d_kitti_lidar

from ...utils import box_utils, calibration_kitti, common_utils, object3d_kitti_lidar

2.2.2 数据集对象名称修改

头文件下面一行修改为(原类名为KittiDataset)

class KittiLidarDataset(DatasetTemplate):

2.2.3 get_infos函数修改

这里其他的地方不要改,直接到这个函数,然后替换为下面的代码

    def get_infos(self, num_workers=4, has_label=True, count_inside_pts=True, sample_id_list=None):
        """Build the info dict of every sample without images or calibration.

        The label location is used directly as the position in the LiDAR frame
        (no ``rect_to_lidar`` call); only the height offset and the heading
        conversion are applied when building ``gt_boxes_lidar``.

        :param num_workers: size of the thread pool that processes the samples
        :param has_label: when True, read the labels and fill ``info['annos']``
        :param count_inside_pts: kept for signature compatibility; not used here
        :param sample_id_list: ids to process; ``self.sample_id_list`` when None
        :return: list with one info dict per sample id, in the order of the ids
        """
        import concurrent.futures as futures

        def process_single_scene(sample_idx):
            print('%s sample_idx: %s' % (self.split, sample_idx))
            info = {'point_cloud': {'num_features': 4, 'lidar_idx': sample_idx}}

            if not has_label:
                return info

            obj_list = self.get_label(sample_idx)
            annotations = {}
            annotations['name'] = np.array([obj.cls_type for obj in obj_list])
            annotations['truncated'] = np.array([obj.truncation for obj in obj_list])
            annotations['occluded'] = np.array([obj.occlusion for obj in obj_list])
            annotations['alpha'] = np.array([obj.alpha for obj in obj_list])
            annotations['bbox'] = np.concatenate([obj.box2d.reshape(1, 4) for obj in obj_list], axis=0)
            annotations['dimensions'] = np.array([[obj.l, obj.h, obj.w] for obj in obj_list])  # lhw(camera) format
            annotations['location'] = np.concatenate([obj.loc.reshape(1, 3) for obj in obj_list], axis=0)
            annotations['rotation_y'] = np.array([obj.ry for obj in obj_list])
            annotations['score'] = np.array([obj.score for obj in obj_list])
            annotations['difficulty'] = np.array([obj.level for obj in obj_list], np.int32)

            # Real objects get a running index, the remaining entries get -1.
            # NOTE(review): the [:num_objects] slices below only select the real
            # objects if all 'DontCare' entries come last in the label file - confirm.
            num_objects = sum(1 for obj in obj_list if obj.cls_type != 'DontCare')
            num_gt = len(annotations['name'])
            index = list(range(num_objects)) + [-1] * (num_gt - num_objects)
            annotations['index'] = np.array(index, dtype=np.int32)

            loc = annotations['location'][:num_objects]
            dims = annotations['dimensions'][:num_objects]
            rots = annotations['rotation_y'][:num_objects]
            # `loc` is a view of annotations['location']; work on a copy so the
            # in-place height shift below does not change the stored location.
            loc_lidar = loc.copy()
            l, h, w = dims[:, 0:1], dims[:, 1:2], dims[:, 2:3]
            # Raise z by half the box height.
            loc_lidar[:, 2] += h[:, 0] / 2
            # Box layout: x, y, z, l, w, h, heading with heading = -(pi/2 + rotation_y).
            gt_boxes_lidar = np.concatenate([loc_lidar, l, w, h, -(np.pi / 2 + rots[..., np.newaxis])], axis=1)
            annotations['gt_boxes_lidar'] = gt_boxes_lidar

            info['annos'] = annotations
            return info

        sample_id_list = sample_id_list if sample_id_list is not None else self.sample_id_list
        with futures.ThreadPoolExecutor(num_workers) as executor:
            infos = executor.map(process_single_scene, sample_id_list)
        return list(infos)

2.2.4 .yaml文件修改

老规矩,先cv,将tools/cfgs/dataset_configs/kitti_dataset.yaml复制为tools/cfgs/dataset_configs/kitti_lidar.yaml。然后修改一下第一行,修改为

DATASET: 'KittiLidarDataset'

2.2.5 运行

终端输入

python -m pcdet.datasets.kitti_lidar.kitti_lidar_dataset create_kitti_infos tools/cfgs/dataset_configs/kitti_lidar.yaml

结果展示:

OpenPCDet 自定义数据集训练_第1张图片

 然后我们的pkl文件就存放在data/kitti里面啦

3、修改数据集加载

3.0 复制数据集加载的文件

继续使用 2.1 中复制得到的 pcdet/datasets/kitti_lidar 文件夹(无需再复制一次),并把里面的文件相应改名,例如把 kitti_dataset.py 改名为 kitti_lidar_dataset.py

OpenPCDet 自定义数据集训练_第2张图片

3.1 去掉测试

在tools/train.py这个文件内部,去掉(注释掉)训练结束后进行测试的代码,这样我们需要修改的地方更少

"""
    logger.info('**********************Start evaluation %s/%s(%s)**********************' %
                (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))
    test_set, test_loader, sampler = build_dataloader(
        dataset_cfg=cfg.DATA_CONFIG,
        class_names=cfg.CLASS_NAMES,
        batch_size=args.batch_size,
        dist=dist_train, workers=args.workers, logger=logger, training=False
    )
    eval_output_dir = output_dir / 'eval' / 'eval_with_train'
    eval_output_dir.mkdir(parents=True, exist_ok=True)
    args.start_epoch = max(args.epochs - args.num_epochs_to_eval, 0)  # Only evaluate the last args.num_epochs_to_eval epochs

    repeat_eval_ckpt(
        model.module if dist_train else model,
        test_loader, args, eval_output_dir, logger, ckpt_dir,
        dist_test=dist_train
    )
    logger.info('**********************End evaluation %s/%s(%s)**********************' %
                (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))
"""

3.2 修改__getitem__函数

pcdet/datasets/kitti_lidar/kitti_lidar_dataset.py,修改数据加载的文件,这里主要把图像和calib的加载去掉,然后把我们新的数据集文件(label_2)导入

    def __getitem__(self, index):
        """Assemble one training sample from LiDAR data only.

        Neither the image shape nor the calibration is loaded; the boxes come
        from the ``gt_boxes_lidar`` entry stored in the info dict.

        :param index: position in ``self.kitti_infos``; wrapped with a modulo
            when ``self._merge_all_iters_to_one_epoch`` is set
        :return: the result of ``self.prepare_data`` for the assembled input dict
        """
        if self._merge_all_iters_to_one_epoch:
            index = index % len(self.kitti_infos)

        # Work on a copy so the cached info entry is never modified.
        info = copy.deepcopy(self.kitti_infos[index])

        sample_idx = info['point_cloud']['lidar_idx']
        get_item_list = self.dataset_cfg.get('GET_ITEM_LIST', ['points'])

        input_dict = {'frame_id': sample_idx}

        if 'annos' in info:
            annos = common_utils.drop_info_with_name(info['annos'], name='DontCare')
            input_dict.update({
                'gt_names': annos['name'],
                'gt_boxes': annos['gt_boxes_lidar']
            })

            # NOTE(review): input_dict carries no 'calib' entry any more - confirm
            # that nothing downstream of 'road_plane' still expects one.
            road_plane = self.get_road_plane(sample_idx)
            if road_plane is not None:
                input_dict['road_plane'] = road_plane

        if "points" in get_item_list:
            input_dict['points'] = self.get_lidar(sample_idx)

        return self.prepare_data(data_dict=input_dict)

 3.3 前后连起来

pcdet/datasets/__init__.py,将前面的部分和后面的部分连起来

# Add this import at the top of the file (the class renamed in section 2.2.2)
from .kitti_lidar.kitti_lidar_dataset import KittiLidarDataset

# Maps a dataset name to its dataset class.
# NOTE(review): presumably looked up with the DATASET value of the yaml config
# (e.g. 'KittiLidarDataset') - verify against the dataloader builder.
__all__ = {
    'DatasetTemplate': DatasetTemplate,
    'KittiDataset': KittiDataset,
    'KittiLidarDataset':KittiLidarDataset,  # the new LiDAR-only dataset is registered here as well
    'NuScenesDataset': NuScenesDataset,
    'WaymoDataset': WaymoDataset,
    'PandasetDataset': PandasetDataset,
    'LyftDataset': LyftDataset
}

3.4 .yaml文件修改

将tools/cfgs/kitti_models/pointpillar.yaml复制到tools/cfgs/kitti_lidar_models/pointpillar.yaml,kitti_lidar_models这个文件夹自己建立

其中修改_BASE_CONFIG_

DATA_CONFIG: 
    _BASE_CONFIG_: cfgs/dataset_configs/kitti_lidar.yaml

3.5运行

cd tools
python train.py --cfg_file=cfgs/kitti_lidar_models/pointpillar.yaml --batch_size=3 --epochs=100

运行结果:

你可能感兴趣的:(机器学习,深度学习,pytorch,人工智能)