【代码阅读】mmdetection3d的数据预处理

数据预处理流程和数据集之间是互相分离的两个部分,通常数据集定义了如何处理标注信息,而数据预处理流程定义了准备数据项字典的所有步骤。数据集预处理流程包含一系列的操作,每个操作将一个字典作为输入,并输出应用于下一个转换的一个新的字典。
蓝色框表示预处理流程中的各项操作。随着预处理的进行,每一个操作都会添加新的键值(图中标记为绿色)到输出字典中,或者更新当前存在的键值(图中标记为橙色)。
【代码阅读】mmdetection3d的数据预处理_第1张图片
以之前的pointpillar模型的数据预处理流程为例:

train_pipeline = [
    dict(
        type='LoadPointsFromFile',
        load_dim=5,
        use_dim=5,
        file_client_args=file_client_args),
    dict(
        type='LoadPointsFromMultiSweeps',
        sweeps_num=10,
        file_client_args=file_client_args),
    dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
    dict(
        type='GlobalRotScaleTrans',
        rot_range=[-0.3925, 0.3925],
        scale_ratio_range=[0.95, 1.05],
        translation_std=[0, 0, 0]),
    dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
    dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='ObjectNameFilter', classes=class_names),
    dict(type='PointShuffle'),
    dict(type='DefaultFormatBundle3D', class_names=class_names),
    dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
    dict(
        type='LoadPointsFromFile',
        load_dim=5,
        use_dim=5,
        file_client_args=file_client_args),
    dict(
        type='LoadPointsFromMultiSweeps',
        sweeps_num=10,
        file_client_args=file_client_args),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        pts_scale_ratio=1.0,
        flip=False,
        pcd_horizontal_flip=False,
        pcd_vertical_flip=False,
        transforms=[
            dict(
                type='GlobalRotScaleTrans',
                rot_range=[0, 0],
                scale_ratio_range=[1., 1.],
                translation_std=[0, 0, 0]),
            dict(type='RandomFlip3D'),
            dict(
                type='PointsRangeFilter', point_cloud_range=point_cloud_range),
            dict(
                type='DefaultFormatBundle3D',
                class_names=class_names,
                with_label=False),
            dict(type='Collect3D', keys=['points'])
        ])
]

主要可以分为数据加载、预处理、格式化及数据增强三个部分:

数据加载

LoadPointsFromFile

  • 添加point

LoadPointsFromMultiSweeps

  • 更新:points

LoadAnnotations3D

  • 添加:gt_bboxes_3d, gt_labels_3d, gt_bboxes, gt_labels, pts_instance_mask, pts_semantic_mask, bbox3d_fields, pts_mask_fields, pts_seg_fields

数据预处理

GlobalRotScaleTrans

  • 添加:pcd_trans, pcd_rotation, pcd_scale_factor

  • 更新:points, *bbox3d_fields

RandomFlip3D

  • 添加:flip, pcd_horizontal_flip, pcd_vertical_flip

    更新:points, *bbox3d_fields

PointsRangeFilter

  • 更新:points

ObjectRangeFilter

  • 更新:gt_bboxes_3d, gt_labels_3d

ObjectNameFilter

  • 更新:gt_bboxes_3d, gt_labels_3d

PointShuffle

  • 更新:points

PointsRangeFilter

  • 更新:points

格式化

DefaultFormatBundle3D

  • 更新:points, gt_bboxes_3d, gt_labels_3d, gt_bboxes, gt_labels

Collect3D

  • 添加:img_meta (由 meta_keys 指定的键值构成的 img_meta)
  • 移除:所有除 keys 指定的键值以外的其他键值

数据增强

MultiScaleFlipAug

  • 更新: scale, pcd_scale_factor, flip, flip_direction, pcd_horizontal_flip, pcd_vertical_flip (与这些指定的参数对应的增强后的数据列表)

官方文档给出了怎么自己定义一些数据预处理的方法 ,首先我们定义一个函数 my_pipieline.py:

from mmdet.datasets import PIPELINES

@PIPELINES.register_module()
class MyTransform:

    def __call__(self, results):
        results['dummy'] = True
        return results

导入预定义好的类:

from .my_pipeline import MyTransform

随后便可以在pipeline直接添加即可:

train_pipeline = [
dict(
type=‘LoadPointsFromFile’,
load_dim=5,
use_dim=5,
file_client_args=file_client_args),
dict(
type=‘LoadPointsFromMultiSweeps’,
sweeps_num=10,
file_client_args=file_client_args),
dict(type=‘LoadAnnotations3D’, with_bbox_3d=True, with_label_3d=True),
dict(
type=‘GlobalRotScaleTrans’,
rot_range=[-0.3925, 0.3925],
scale_ratio_range=[0.95, 1.05],
translation_std=[0, 0, 0]),
dict(type=‘RandomFlip3D’, flip_ratio_bev_horizontal=0.5),
dict(type=‘PointsRangeFilter’, point_cloud_range=point_cloud_range),
dict(type=‘ObjectRangeFilter’, point_cloud_range=point_cloud_range),
dict(type=‘ObjectNameFilter’, classes=class_names),
dict(type=‘MyTransform’),
dict(type=‘PointShuffle’),
dict(type=‘DefaultFormatBundle3D’, class_names=class_names),
dict(type=‘Collect3D’, keys=[‘points’, ‘gt_bboxes_3d’, ‘gt_labels_3d’])
]

你可能感兴趣的:(代码阅读,深度学习)