MMDETECTION3D 使用kitti格式的数据集跑centerpoint模型

MMDETECTION3D 使用kitti格式的数据集跑centerpoint模型

1 修改配置文件configs/centerpoint/centerpoint_pillar02_kitti_3d.py

如下

# """configs/centerpoint/centerpoint_pillar02_kitti_3d.py"""
# Inherit dataset, model, schedule and runtime settings; the definitions
# below override fields of those base files.
_base_ = [
    '../_base_/datasets/centerpoint_kitii_3d_3class.py',
    '../_base_/models/centerpoint_pillar02_second_secfpn_kitti.py',
    '../_base_/schedules/cyclic-20e.py', '../_base_/default_runtime.py'
]

# If point cloud range is changed, the models should also change their point
# cloud range accordingly. With the base model's 0.2 m pillars, this 102.4 m
# x/y span gives the 512 x 512 BEV grid that the base model's
# `output_shape=(512, 512)` and `grid_size=[512, 512, 1]` assume.
point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
# nuScenes only: converting the LiDAR-frame range to the ego frame with the
# calibration info gave a small improvement there.
# point_cloud_range = [-51.2, -52, -5.0, 51.2, 50.4, 3.0]
# Original nuScenes 10-class setup, kept for reference:
# class_names = [
#     'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
#     'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
# ]

# KITTI 3-class setup. NOTE(review): the list order presumably fixes the label
# ids (Pedestrian=0, Cyclist=1, Car=2); keep the head's task list in sync.
class_names = ['Pedestrian', 'Cyclist', 'Car']
# nuScenes data prefix (LIDAR_TOP samples and sweeps); not used for KITTI.
# data_prefix = dict(pts='samples/LIDAR_TOP', img='', sweeps='sweeps/LIDAR_TOP')
# Only the range-dependent fields are set here; the rest of the model comes
# from the base model config listed in `_base_`.
model = dict(
    data_preprocessor=dict(
        voxel_layer=dict(point_cloud_range=point_cloud_range)),
    pts_voxel_encoder=dict(point_cloud_range=point_cloud_range),
    pts_bbox_head=dict(bbox_coder=dict(pc_range=point_cloud_range[:2])),
    # model training and testing settings
    train_cfg=dict(pts=dict(point_cloud_range=point_cloud_range)),
    test_cfg=dict(pts=dict(pc_range=point_cloud_range[:2])))

# KITTI replaces the nuScenes defaults ('NuScenesDataset', 'data/nuscenes/').
dataset_type = 'KittiDataset'
data_root = 'data/kitti/'
backend_args = None

# GT-database sampler for the (currently disabled) `ObjectSample` transform.
# The per-class keys must be the KITTI category names exactly as stored in
# kitti_dbinfos_train.pkl and in `class_names` (case-sensitive): the sampler
# indexes its database and its class->label map with these keys, so e.g.
# 'car' instead of 'Car' raises KeyError.
db_sampler = dict(
    data_root=data_root,
    info_path=data_root + 'kitti_dbinfos_train.pkl',
    rate=1.0,
    prepare=dict(
        filter_by_difficulty=[-1],
        # Minimum number of LiDAR points for a database object to be kept.
        filter_by_min_points=dict(Car=5, Cyclist=5, Pedestrian=5)),
    classes=class_names,
    # Per-class sampling targets used by the sampler.
    sample_groups=dict(Car=2, Cyclist=4, Pedestrian=4),
    points_loader=dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        # KITTI scans store x, y, z, intensity. Keep all four so the pasted
        # object points have the same width as the scene points loaded by the
        # pipelines (LoadPointsFromFile otherwise keeps only x, y, z).
        load_dim=4,
        use_dim=4,
        backend_args=backend_args),
    backend_args=backend_args)

# Training pipeline: load the KITTI scan and its 3D boxes, augment with global
# rotation/scaling and BEV flipping, crop to `point_cloud_range`, then pack.
train_pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        # nuScenes points have 5 values; KITTI scans store 4 per point
        # (x, y, z, intensity) and all 4 are used as features.
        # load_dim=5,
        # use_dim=5,
        load_dim=4,
        use_dim=4,
        backend_args=backend_args),
    # nuScenes multi-sweep aggregation, disabled for KITTI:
    # dict(
    #     type='LoadPointsFromMultiSweeps',
    #     sweeps_num=9,
    #     use_dim=[0, 1, 2, 3, 4],
    #     pad_empty_sweeps=True,
    #     remove_close=True,
    #     backend_args=backend_args),
    dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
    # GT-database sampling is disabled; enabling it uses `db_sampler` above.
    # dict(type='ObjectSample', db_sampler=db_sampler),
    # Original nuScenes augmentation (+/-0.3925 rad = +/-22.5 deg):
    # dict(
    #     type='GlobalRotScaleTrans',
    #     rot_range=[-0.3925, 0.3925],
    #     scale_ratio_range=[0.95, 1.05],
    #     translation_std=[0, 0, 0]),
    # Random global rotation within +/-pi/4 rad (45 deg) and 0.95-1.05 scaling.
    dict(
        type='GlobalRotScaleTrans',
        rot_range=[-0.78539816, 0.78539816],
        scale_ratio_range=[0.95, 1.05]),
    # Horizontal BEV flip with probability 0.5 (vertical flip disabled).
    dict(
        type='RandomFlip3D',
        # sync_2d=False,
        flip_ratio_bev_horizontal=0.5,
        # flip_ratio_bev_vertical=0.5
        ),
    # Drop points and GT boxes outside `point_cloud_range`.
    dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
    # dict(type='ObjectNameFilter', classes=class_names),
    dict(type='PointShuffle'),
    dict(
        type='Pack3DDetInputs',
        keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
# Test pipeline: same point loading. The wrapped transforms use zero
# rotation, unit scale, zero translation and flip=False, so they are
# presumably no-ops and only the range crop changes the points.
test_pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        # load_dim=5,
        # use_dim=5,
        load_dim=4,
        use_dim=4,
        backend_args=backend_args),
    # nuScenes multi-sweep aggregation, disabled for KITTI:
    # dict(
    #     type='LoadPointsFromMultiSweeps',
    #     sweeps_num=9,
    #     use_dim=[0, 1, 2, 3, 4],
    #     pad_empty_sweeps=True,
    #     remove_close=True,
    #     backend_args=backend_args),
    dict(
        type='MultiScaleFlipAug3D',
        img_scale=(1333, 800),
        pts_scale_ratio=1,
        flip=False,
        transforms=[
            dict(
                type='GlobalRotScaleTrans',
                rot_range=[0, 0],
                scale_ratio_range=[1., 1.],
                translation_std=[0, 0, 0]),
            dict(type='RandomFlip3D'),
            dict( type='PointsRangeFilter', point_cloud_range=point_cloud_range)
        ]),
    dict(type='Pack3DDetInputs', keys=['points'])
]

# Original nuScenes-style dataloader (CBGS class-balanced resampling), kept
# for reference only; it is not used:
# train_dataloader = dict(
#     _delete_=True,
#     batch_size=4,
#     num_workers=4,
#     persistent_workers=True,
#     sampler=dict(type='DefaultSampler', shuffle=True),
#     dataset=dict(
#         type='CBGSDataset',
#         dataset=dict(
#             type=dataset_type,
#             data_root=data_root,
#             # ann_file='nuscenes_infos_train.pkl',
#             ann_file='kitti_infos_train.pkl',
#             pipeline=train_pipeline,
#             metainfo=dict(classes=class_names),
#             test_mode=False,
#             data_prefix=data_prefix,
#             use_valid_flag=True,
#             # we use box_type_3d='LiDAR' in kitti and nuscenes dataset
#             # and box_type_3d='Depth' in sunrgbd and scannet dataset.
#             box_type_3d='LiDAR',
#             backend_args=backend_args)))


# Partial overrides: only the pipelines and the class list are replaced.
# Batch size, sampler, annotation files and the RepeatDataset wrapper come
# from the base dataset config; that wrapper is why the training override
# is nested as `dataset=dict(dataset=...)`.
train_dataloader = dict(
    dataset=dict(dataset=dict(pipeline=train_pipeline, metainfo=dict(classes=class_names))))


test_dataloader = dict(
    dataset=dict(pipeline=test_pipeline, metainfo=dict(classes=class_names)))
val_dataloader = dict(
    dataset=dict(pipeline=test_pipeline, metainfo=dict(classes=class_names)))


lr = 0.0001
epoch_num = 200
# The cyclic schedule spends 40% of the epochs in its first phase and 60% in
# the second. The split is rounded to a whole epoch so the LR and momentum
# schedulers share identical integer boundaries and the two phases always
# add up to exactly `epoch_num`.
cycle_split_epoch = int(epoch_num * 0.4)
optim_wrapper = dict(optimizer=dict(lr=lr), clip_grad=dict(max_norm=35, norm_type=2))
param_scheduler = [
    # Learning rate: cosine from `lr` towards lr * 10 in the first phase,
    # then cosine towards lr * 1e-4 for the remaining epochs.
    dict(
        type='CosineAnnealingLR',
        T_max=cycle_split_epoch,
        eta_min=lr * 10,
        begin=0,
        end=cycle_split_epoch,
        by_epoch=True,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR',
        T_max=epoch_num - cycle_split_epoch,
        eta_min=lr * 1e-4,
        begin=cycle_split_epoch,
        end=epoch_num,
        by_epoch=True,
        convert_to_iter_based=True),
    # Momentum: cosine towards 0.85 / 0.95 in the first phase, then towards 1.
    dict(
        type='CosineAnnealingMomentum',
        T_max=cycle_split_epoch,
        eta_min=0.85 / 0.95,
        begin=0,
        end=cycle_split_epoch,
        by_epoch=True,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingMomentum',
        T_max=epoch_num - cycle_split_epoch,
        eta_min=1,
        begin=cycle_split_epoch,
        end=epoch_num,
        # Stated explicitly like the three schedulers above;
        # `convert_to_iter_based` only applies to epoch-based schedulers.
        by_epoch=True,
        convert_to_iter_based=True)
]

# Train for `epoch_num` epochs and validate every 20 epochs.
train_cfg = dict(by_epoch=True, max_epochs=epoch_num, val_interval=20)
val_cfg = dict()

2 根据报错一步一步修改

2.1 FileNotFoundError: [Errno 2]

(1)centerpoint_kitii_3d_3class.py 拷贝到 configs/_base_/datasets
(2)centerpoint_pillar02_second_secfpn_kitti.py 拷贝到 configs/_base_/models

如下文件

configs/_base_/datasets/centerpoint_kitii_3d_3class.py

#########"""configs/_base_/datasets/centerpoint_kitii_3d_3class.py"""
# dataset settings
dataset_type = 'KittiDataset'
data_root = 'data/kitti/'
# The three standard KITTI detection classes, in the same order as
# `class_names` in configs/centerpoint/centerpoint_pillar02_kitti_3d.py.
# ('Bicycle' is not a KITTI category; KITTI labels riders as 'Cyclist'.)
class_names = ['Pedestrian', 'Cyclist', 'Car']
point_cloud_range = [0, -40, -3, 70.4, 40, 1]
input_modality = dict(use_lidar=True, use_camera=False)
metainfo = dict(classes=class_names)

# Example to use different file client
# Method 1: simply set the data root and let the file I/O module
# automatically infer from prefix (not support LMDB and Memcache yet)

# data_root = 's3://openmmlab/datasets/detection3d/kitti/'

# Method 2: Use backend_args, file_client_args in versions before 1.1.0
# backend_args = dict(
#     backend='petrel',
#     path_mapping=dict({
#         './data/': 's3://openmmlab/datasets/detection3d/',
#          'data/': 's3://openmmlab/datasets/detection3d/'
#      }))
backend_args = None

# GT-database sampler for `ObjectSample` (commented out in train_pipeline).
# Its class keys must be KITTI category names exactly as stored in
# kitti_dbinfos_train.pkl (case-sensitive) and must match `class_names`;
# otherwise the sampler raises KeyError when it filters or samples.
db_sampler = dict(
    data_root=data_root,
    info_path=data_root + 'kitti_dbinfos_train.pkl',
    rate=1.0,
    prepare=dict(
        filter_by_difficulty=[-1],
        filter_by_min_points=dict(Car=5, Pedestrian=10, Cyclist=10)),
    classes=class_names,
    sample_groups=dict(Car=12, Pedestrian=6, Cyclist=6),
    points_loader=dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        load_dim=4,
        use_dim=4,
        backend_args=backend_args),
    backend_args=backend_args)

# Training pipeline of this base config. The CenterPoint KITTI config that
# inherits from this file defines its own train/test pipelines.
train_pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        load_dim=4,  # x, y, z, intensity
        use_dim=4,
        backend_args=backend_args),
    dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
    # GT-database sampling (uses `db_sampler` above) is disabled.
    #dict(type='ObjectSample', db_sampler=db_sampler),
    # Per-box noise: random translation (std 1.0 / 1.0 / 0.5 along x / y / z)
    # and yaw noise within +/-pi/4; no extra global rotation here.
    # num_try: attempts per box (presumably to find a collision-free offset).
    dict(
        type='ObjectNoise',
        num_try=100,
        translation_std=[1.0, 1.0, 0.5],
        global_rot_range=[0.0, 0.0],
        rot_range=[-0.78539816, 0.78539816]),
    dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
    # Random global rotation within +/-pi/4 rad and 0.95-1.05 scaling.
    dict(
        type='GlobalRotScaleTrans',
        rot_range=[-0.78539816, 0.78539816],
        scale_ratio_range=[0.95, 1.05]),
    # Drop points and GT boxes outside `point_cloud_range`.
    dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='PointShuffle'),
    dict(
        type='Pack3DDetInputs',
        keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
# Test pipeline: the wrapped transforms use zero rotation, unit scale, zero
# translation and flip=False (presumably no-ops), so effectively only the
# range crop is applied before packing the points.
test_pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        load_dim=4,
        use_dim=4,
        backend_args=backend_args),
    dict(
        type='MultiScaleFlipAug3D',
        img_scale=(1333, 800),
        pts_scale_ratio=1,
        flip=False,
        transforms=[
            dict(
                type='GlobalRotScaleTrans',
                rot_range=[0, 0],
                scale_ratio_range=[1., 1.],
                translation_std=[0, 0, 0]),
            dict(type='RandomFlip3D'),
            dict(
                type='PointsRangeFilter', point_cloud_range=point_cloud_range)
        ]),
    dict(type='Pack3DDetInputs', keys=['points'])
]
# construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client)
eval_pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        load_dim=4,
        use_dim=4,
        backend_args=backend_args),
    dict(type='Pack3DDetInputs', keys=['points'])
]
# Training data: kitti_infos_train.pkl wrapped in RepeatDataset(times=2), so
# each epoch iterates the training set twice. Points are read from
# training/velodyne_reduced (presumably the camera-FOV-cropped scans written
# by mmdet3d's KITTI data preparation).
train_dataloader = dict(
    batch_size=4,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type='RepeatDataset',
        times=2,
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            ann_file='kitti_infos_train.pkl',
            data_prefix=dict(pts='training/velodyne_reduced'),
            pipeline=train_pipeline,
            modality=input_modality,
            test_mode=False,
            metainfo=metainfo,
            # we use box_type_3d='LiDAR' in kitti and nuscenes dataset
            # and box_type_3d='Depth' in sunrgbd and scannet dataset.
            box_type_3d='LiDAR',
            backend_args=backend_args)))
# Validation and test both run on the KITTI val split (kitti_infos_val.pkl,
# whose scans also live under training/), unshuffled and in test mode.
val_dataloader = dict(
    batch_size=4,
    num_workers=1,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(pts='training/velodyne_reduced'),
        ann_file='kitti_infos_val.pkl',
        pipeline=test_pipeline,
        modality=input_modality,
        test_mode=True,
        metainfo=metainfo,
        box_type_3d='LiDAR',
        backend_args=backend_args))
test_dataloader = dict(
    batch_size=4,
    num_workers=1,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(pts='training/velodyne_reduced'),
        ann_file='kitti_infos_val.pkl',
        pipeline=test_pipeline,
        modality=input_modality,
        test_mode=True,
        metainfo=metainfo,
        box_type_3d='LiDAR',
        backend_args=backend_args))
# KITTI bbox evaluation against the val annotations; the test evaluator
# reuses the same settings object.
val_evaluator = dict(
    type='KittiMetric',
    ann_file=data_root + 'kitti_infos_val.pkl',
    metric='bbox',
    backend_args=backend_args)
test_evaluator = val_evaluator

# Local visualisation backend used by Det3DLocalVisualizer.
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')

configs/_base_/models/centerpoint_pillar02_second_secfpn_kitti.py

#############"""configs/_base_/models/centerpoint_pillar02_second_secfpn_kitti.py"""
voxel_size = [0.2, 0.2, 8]
model = dict(
    type='CenterPoint',
    data_preprocessor=dict(
        type='Det3DDataPreprocessor',
        voxel=True,
        voxel_layer=dict(
            max_num_points=20,
            voxel_size=voxel_size,
            max_voxels=(90000, 120000))),
    pts_voxel_encoder=dict(
        type='PillarFeatureNet',
        # KITTI points have 4 features (x, y, z, intensity); nuScenes had 5.
        in_channels=4,
        feat_channels=[64],
        with_distance=False,
        # Shares `voxel_size` with the voxelizer so the two cannot drift apart.
        voxel_size=voxel_size,
        norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
        legacy=False),
    pts_middle_encoder=dict(
        type='PointPillarsScatter', in_channels=64, output_shape=(512, 512)),
    pts_backbone=dict(
        type='SECOND',
        in_channels=64,
        out_channels=[64, 128, 256],
        layer_nums=[3, 5, 5],
        layer_strides=[2, 2, 2],
        norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
        conv_cfg=dict(type='Conv2d', bias=False)),
    pts_neck=dict(
        type='SECONDFPN',
        in_channels=[64, 128, 256],
        out_channels=[128, 128, 128],
        upsample_strides=[0.5, 1, 2],
        norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
        upsample_cfg=dict(type='deconv', bias=False),
        use_conv_for_no_stride=True),
    pts_bbox_head=dict(
        type='CenterHead',
        in_channels=sum([128, 128, 128]),
        # One task head per KITTI class, replacing the 6 nuScenes task heads
        # that covered 10 classes. CenterHead assigns gt labels to tasks by
        # position (running offset over the task lists), so the tasks must
        # follow the dataset's label order:
        # class_names = ['Pedestrian', 'Cyclist', 'Car'].
        tasks=[
            dict(num_class=1, class_names=['Pedestrian']),
            dict(num_class=1, class_names=['Cyclist']),
            dict(num_class=1, class_names=['Car']),
        ],
        # KITTI boxes have no velocity, so the nuScenes `vel=(2, 2)` head is
        # dropped: 2 (reg) + 1 (height) + 3 (dim) + 2 (rot sin/cos) = 8 outputs.
        common_heads=dict(
            reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2)),
        share_conv_channel=64,
        bbox_coder=dict(
            type='CenterPointBBoxCoder',
            post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
            max_num=500,
            score_threshold=0.1,
            out_size_factor=4,
            voxel_size=voxel_size[:2],
            # Decoded boxes are (x, y, z, dx, dy, dz, yaw): 7 values, no
            # velocity (nuScenes used 9).
            code_size=7),
        separate_head=dict(
            type='SeparateHead', init_bias=-2.19, final_kernel=3),
        loss_cls=dict(type='mmdet.GaussianFocalLoss', reduction='mean'),
        loss_bbox=dict(
            type='mmdet.L1Loss', reduction='mean', loss_weight=0.25),
        norm_bbox=True),
    # model training and testing settings
    train_cfg=dict(
        pts=dict(
            grid_size=[512, 512, 1],
            voxel_size=voxel_size,
            out_size_factor=4,
            dense_reg=1,
            gaussian_overlap=0.1,
            max_objs=500,
            min_radius=2,
            # One weight per regression output of `common_heads` (8); the
            # nuScenes config had two extra 0.2 weights for velocity.
            code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])),
    test_cfg=dict(
        pts=dict(
            post_center_limit_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
            max_per_img=500,
            max_pool_nms=False,
            # One radius per task head (Pedestrian, Cyclist, Car), taken from
            # the nuScenes pedestrian / bicycle / car heads; consulted for
            # circle NMS (nms_type='circle').
            min_radius=[0.175, 0.85, 4],
            score_threshold=0.1,
            pc_range=[-51.2, -51.2],
            out_size_factor=4,
            voxel_size=voxel_size[:2],
            nms_type='rotate',
            pre_max_size=1000,
            post_max_size=83,
            nms_thr=0.2)))

2.2 ValueError

报错如下:

mmdetection3d/mmdet3d/models/dense_heads/centerpoint_head.py", line 568, in get_targets_single
    vx, vy = task_boxes[idx][k][7:]
ValueError: not enough values to unpack (expected 2, got 0)

修改:注释掉568行如下

567                    # TODO: support other outdoor dataset
568                    # vx, vy = task_boxes[idx][k][7:]
569                    rot = task_boxes[idx][k][6]
570                    box_dim = task_boxes[idx][k][3:6]

2.3 NameError

报错如下:

/mmdetection3d/mmdet3d/models/dense_heads/centerpoint_head.py", line 578, in get_targets_single
    vx.unsqueeze(0),
NameError: name 'vx' is not defined

修改:注释掉578、579行如下

573                    anno_box[new_idx] = torch.cat([
574                        center - torch.tensor([x, y], device=device),
575                        z.unsqueeze(0), box_dim,
576                        torch.sin(rot).unsqueeze(0),
577                        torch.cos(rot).unsqueeze(0),
578                        # vx.unsqueeze(0),
579                        # vy.unsqueeze(0)
580                    ])

2.4 RuntimeError

报错如下:

mmdetection3d/mmdet3d/models/dense_heads/centerpoint_head.py", line 573, in get_targets_single
    anno_box[new_idx] = torch.cat([
RuntimeError: The expanded size of the tensor (10) must match the existing size (8) at non-singleton dimension 0.  Target sizes: [10].  Tensor sizes: [8]

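修改:将 anno_box[new_idx] = torch.cat([...]) 改为先构造列表 anno_elems = [...],再用 torch.cat(anno_elems) 拼接。这一步只是写法上的调整,不改变结果。真正消除该报错的是 2.6 中的修改:anno_box 的列数改为 len(code_weights)(即 8),不再固定为 10。因此需要先完成 2.6 部分的修改。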

                   anno_elems = [
                        center - torch.tensor([x, y], device=device),
                        z.unsqueeze(0), box_dim,
                        torch.sin(rot).unsqueeze(0),
                        torch.cos(rot).unsqueeze(0)
                    ]
                    anno_box[new_idx] = torch.cat(anno_elems)

2.5 KeyError

报错如下:

/mmdetection3d/mmdet3d/models/dense_heads/centerpoint_head.py", line 649, in loss_by_feat
    preds_dict[0]['vel']),
KeyError: 'vel'

修改如下:

646            preds_dict[0]['anno_box'] = torch.cat(
647                (preds_dict[0]['reg'], preds_dict[0]['height'],
648                 preds_dict[0]['dim'], preds_dict[0]['rot'],
649                 preds_dict[0]['vel']),
650                dim=1)

改为:

            anno_box = [
                preds_dict[0]['reg'], preds_dict[0]['height'],
                preds_dict[0]['dim'], preds_dict[0]['rot']
            ]
            # Key assumed to exist for bbox annotations with 9 values
            if 'vel' in preds_dict[0]:
                anno_box.append(preds_dict[0]['vel'])
            preds_dict[0]['anno_box'] = torch.cat(anno_box, dim=1)

2.6 RuntimeError

报错如下:

/home/user/anaconda3/envs/mmlab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
mmdetection3d/mmdet3d/models/detectors/base.py", line 75, in forward
    return self.loss(inputs, data_samples, **kwargs)
mmdetection3d/mmdet3d/models/detectors/mvx_two_stage.py", line 274, in loss
    losses_pts = self.pts_bbox_head.loss(pts_feats, batch_data_samples,
mmdetection3d/mmdet3d/models/dense_heads/centerpoint_head.py", line 612, in loss
    losses = self.loss_by_feat(outs, batch_gt_instance_3d)
/mmdetection3d/mmdet3d/models/dense_heads/centerpoint_head.py", line 677, in loss_by_feat
    bbox_weights = mask * mask.new_tensor(code_weights)
RuntimeError: The size of tensor a (10) must match the size of tensor b (8) at non-singleton dimension 2

修改:增加如下476行

473        grid_size = torch.tensor(self.train_cfg['grid_size']).to(device)
474        pc_range = torch.tensor(self.train_cfg['point_cloud_range'])
475        voxel_size = torch.tensor(self.train_cfg['voxel_size'])
476        gt_annotation_num = len(self.train_cfg['code_weights'])
477
478        feature_map_size = grid_size[:2] // self.train_cfg['out_size_factor']
479
480        # reorganize the gt_dict by tasks
481        task_masks = []

屏蔽如下511行,改为如下512行

506        for idx, task_head in enumerate(self.task_heads):
507            heatmap = gt_bboxes_3d.new_zeros(
508                (len(self.class_names[idx]), feature_map_size[1],
509                 feature_map_size[0]))
510
511            # anno_box = gt_bboxes_3d.new_zeros((max_objs, 10), dtype=torch.float32)
512            anno_box = gt_bboxes_3d.new_zeros((max_objs, gt_annotation_num), dtype=torch.float32)
513            
514            ind = gt_labels_3d.new_zeros((max_objs), dtype=torch.int64)
515            mask = gt_bboxes_3d.new_zeros((max_objs), dtype=torch.uint8)

上述修改主要是因为数据格式不一致:KITTI 数据集的标注框没有速度分量,只有 x, y, z, l, w, h, yaw 共 7 个值,而 nuScenes 为 9 个。所以需要在配置文件和模型文件“mmdet3d/models/dense_heads/centerpoint_head.py”中修改读取 Box 标签信息的部分。

3 程序正常运行

4 模型推理可视化

################""" point_visual.py  """

from mmdet3d.apis import inference_detector, init_model
import open3d as o3d
import numpy as np
import os
import time


# Rotation matrix about an arbitrary axis (Rodrigues' formula).
def rotation_matrix(axis, angle):
    """Return the 3x3 matrix that rotates by ``angle`` radians about ``axis``.

    Rodrigues' formula is only valid for a unit axis, so ``axis`` is
    normalised first; any non-zero 3-vector may be passed.

    Args:
        axis: 3-element sequence, the rotation axis (need not be unit length).
        angle: rotation angle in radians (right-hand rule about ``axis``).

    Returns:
        A (3, 3) ``numpy.ndarray``.

    Raises:
        ValueError: if ``axis`` has zero length.
    """
    axis = np.asarray(axis, dtype=float)
    norm = np.linalg.norm(axis)
    if norm == 0:
        raise ValueError('rotation axis must be non-zero')
    x, y, z = axis / norm
    c = np.cos(angle)
    s = np.sin(angle)
    t = 1 - c  # the shared (1 - cos) term of Rodrigues' formula
    return np.array([
        [x * x * t + c, x * y * t - z * s, x * z * t + y * s],
        [x * y * t + z * s, y * y * t + c, y * z * t - x * s],
        [x * z * t - y * s, y * z * t + x * s, z * z * t + c]
    ])

# Given box extents (dx, dy, dz), return the box's 8 corner points.
def pt2corner(pt):
    """Return the eight corners of an axis-aligned box as an (8, 3) array.

    ``pt`` holds the box extents along x, y and z. The box is centred on the
    origin in x and y and spans ``0 .. pt[2]`` in z, so the origin is the
    centre of its bottom face. The four bottom corners come first, then the
    four top corners; within each face the order is (+x, +y), (-x, +y),
    (+x, -y), (-x, -y).
    """
    half_x, half_y, height = pt[0] / 2, pt[1] / 2, pt[2]
    rows = [
        np.array([x_off, y_off, z_off])
        for z_off in (0, height)
        for y_off in (half_y, -half_y)
        for x_off in (half_x, -half_x)
    ]
    return np.concatenate(rows).reshape(-1, 3)

# Corner-index pairs forming the 12 edges of a box, for the corner order
# produced by pt2corner (bottom face = corners 0-3, top face = corners 4-7).
_BOX_EDGES = np.array([
    [0, 1], [0, 2], [1, 3], [2, 3],  # bottom face
    [4, 5], [4, 6], [5, 7], [6, 7],  # top face
    [0, 4], [1, 5], [2, 6], [3, 7],  # vertical edges
])


def create_lines(bboxes_3d):
    """Return the 12 world-frame edges of every 3D box as line segments.

    Args:
        bboxes_3d: array-like of shape (N, >=7) (a single 1-D box row is also
            accepted). Each row is (x, y, z, dx, dy, dz, yaw, ...), where
            (x, y, z) is taken as the centre of the box's bottom face (matching
            ``pt2corner``, whose corners span 0..dz in z) and ``yaw`` is a
            rotation about +z in radians.

    Returns:
        A list with one ``(12, 2, 3)`` array per box; entry ``[k]`` holds the
        start and end point of edge ``k``. An empty list when there are no
        boxes.
    """
    boxes = np.atleast_2d(np.asarray(bboxes_3d, dtype=float))
    if boxes.size == 0:
        return []
    z_axis = (0.0, 0.0, 1.0)
    lines = []
    for box in boxes:
        rot = rotation_matrix(z_axis, box[6])
        # Rotate the box-local corners, then translate them to the box origin.
        corners = pt2corner(box[3:6]) @ rot.T + box[:3]
        lines.append(corners[_BOX_EDGES])
    return lines


# Trained config/checkpoint and the folder of KITTI ".bin" scans to visualise.
CONFIG_FILE = 'work_dirs/centerpoint_pillar02_kitti_3d/centerpoint_pillar02_kitti_3d.py'
CHECKPOINT_FILE = 'work_dirs/centerpoint_pillar02_kitti_3d/epoch_200.pth'
POINTS_DIR = './data/kitti/training/velodyne'
# Seconds each frame stays on screen.
FRAME_DELAY = 0.8


def _load_kitti_points(file_path):
    """Return the (N, 3) xyz coordinates of a KITTI velodyne ``.bin`` scan.

    The file stores 4 float32 values per point (x, y, z, intensity); the
    intensity column is dropped.
    """
    return np.fromfile(file_path, dtype=np.float32).reshape(-1, 4)[:, :3]


def _make_o3d_box(box):
    """Convert one predicted box row into a red Open3D OrientedBoundingBox.

    ``box`` is a row of ``bboxes_3d.cpu().numpy()``: (x, y, z, dx, dy, dz,
    yaw, ...). mmdet3d LiDAR boxes store the centre of the *bottom* face in
    (x, y, z), whereas an Open3D box is defined by its geometric centre, so z
    is raised by half the box height. ``yaw`` is a rotation about +z.
    """
    extent = np.array(box[3:6], dtype=np.float64)
    center = np.array(box[:3], dtype=np.float64)
    center[2] += extent[2] / 2
    rot = o3d.geometry.OrientedBoundingBox.get_rotation_matrix_from_xyz(
        (0, 0, float(box[6])))
    obb = o3d.geometry.OrientedBoundingBox(center, rot, extent)
    obb.color = [1, 0, 0]
    return obb


def _show_frame(vis, model, file_path):
    """Detect boxes in one scan, draw scan + boxes, and keep them up for FRAME_DELAY.

    Returns:
        False once the user has closed the window, True otherwise.
    """
    result, _ = inference_detector(model, file_path)
    pred = result.pred_instances_3d
    print('bboxes_3d: ', pred.bboxes_3d)   # box coordinates
    print('scores_3d: ', pred.scores_3d)   # confidence scores
    print('labels_3d: ', pred.labels_3d)   # class labels
    boxes = pred.bboxes_3d.cpu().numpy()

    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(_load_kitti_points(file_path))

    vis.clear_geometries()
    vis.add_geometry(cloud)
    for box in boxes:
        vis.add_geometry(_make_o3d_box(box))

    # add_geometry() resets the view by default, so re-apply the camera for
    # every frame.
    view_control = vis.get_view_control()
    view_control.set_lookat([5, 5, 0])
    view_control.set_up([1, 0, 2])
    view_control.set_zoom(1.5)

    # Keep processing window events while the frame is displayed instead of
    # sleeping, so the window stays responsive and closing it is noticed.
    deadline = time.monotonic() + FRAME_DELAY
    while time.monotonic() < deadline:
        if not vis.poll_events():
            return False
        vis.update_renderer()
        time.sleep(0.01)
    return True


def main():
    """Run the detector on every scan in POINTS_DIR and replay the results.

    The scans are shown in sorted order, looping until the window is closed.

    Raises:
        FileNotFoundError: if POINTS_DIR contains no ``.bin`` files.
    """
    model = init_model(CONFIG_FILE, CHECKPOINT_FILE)

    bin_files = sorted(f for f in os.listdir(POINTS_DIR) if f.endswith('.bin'))
    if not bin_files:
        raise FileNotFoundError(f'no .bin files found in {POINTS_DIR}')

    vis = o3d.visualization.Visualizer()
    vis.create_window()
    opt = vis.get_render_option()
    opt.point_size = 1
    opt.background_color = np.asarray([0, 0, 0])

    try:
        while True:
            for bin_file in bin_files:
                if not _show_frame(vis, model, os.path.join(POINTS_DIR, bin_file)):
                    return
    finally:
        vis.destroy_window()


if __name__ == '__main__':
    main()
    



你可能感兴趣的:(雷达,点云数据,人工智能,大数据,python)