点云数据集可视化

ModelNet40数据集可视化:


import open3d as o3d
import numpy as np
from plyfile import PlyData
from PIL import Image


# 读取 PLY 文件内容
def read_ply_file(file_path):
    """Read a PLY file from disk and return its PlyData, or None on any failure."""
    try:
        data = PlyData.read(file_path)
    except Exception as exc:
        # Best-effort loader: report the problem and signal failure with None.
        print(f"读取 PLY 文件时出错: {exc}")
        return None
    print("PLY 文件内容成功读取")
    return data


# 将 PlyData 转换为 Open3D 的点云数据结构
def convert_to_open3d_point_cloud(ply_data):
    """Build an Open3D PointCloud from PLY vertex data; falsy input yields None."""
    if not ply_data:
        return None
    vertices = ply_data['vertex'].data
    # Stack the x/y/z columns into an (N, 3) array of point coordinates.
    xyz = np.column_stack((vertices['x'], vertices['y'], vertices['z']))
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(xyz)
    return cloud


# 将点云绕 x, y, z 轴旋转
def rotate_point_cloud(point_cloud, angle_x, angle_y, angle_z):
    """Rotate the cloud in place by Euler angles in radians, applied x, then y, then z."""
    cx, sx = np.cos(angle_x), np.sin(angle_x)
    cy, sy = np.cos(angle_y), np.sin(angle_y)
    cz, sz = np.cos(angle_z), np.sin(angle_z)

    rot_x = np.array([
        [1, 0, 0],
        [0, cx, -sx],
        [0, sx, cx],
    ])
    rot_y = np.array([
        [cy, 0, sy],
        [0, 1, 0],
        [-sy, 0, cy],
    ])
    rot_z = np.array([
        [cz, -sz, 0],
        [sz, cz, 0],
        [0, 0, 1],
    ])

    # Compose into a single matrix (x-rotation first) and apply to row vectors.
    rotation = rot_z @ rot_y @ rot_x
    xyz = np.asarray(point_cloud.points)
    point_cloud.points = o3d.utility.Vector3dVector(xyz @ rotation.T)


# 可视化点云数据并保存当前视角图片
def visualize_and_save_point_cloud(point_cloud, save_path):
    """Render the point cloud in an Open3D window and save the current view to disk."""
    if not point_cloud:
        print("点云数据为空,无法可视化")
        return

    viewer = o3d.visualization.Visualizer()
    viewer.create_window()
    viewer.add_geometry(point_cloud)
    viewer.poll_events()
    viewer.update_renderer()

    # Grab the rendered frame as a float buffer with values in [0, 1].
    frame = np.asarray(viewer.capture_screen_float_buffer(do_render=True))

    # Scale to 8-bit RGB and write the image with PIL.
    Image.fromarray((frame * 255).astype(np.uint8)).save(save_path)

    viewer.destroy_window()
    print(f"当前视角图片已保存到 {save_path}")
# Paint every point a single uniform color.
def set_point_cloud_color(point_cloud, color=(0, 0, 1)):
    """Set all points of *point_cloud* to one RGB color.

    Args:
        point_cloud: Open3D PointCloud whose ``colors`` attribute is replaced.
        color: RGB triple with components in [0, 1]; defaults to blue.
    """
    # Immutable tuple default replaces the original mutable-list default
    # (``color=[0, 0, 1]``) — np.tile accepts either, so callers are unaffected.
    colors = np.tile(color, (len(point_cloud.points), 1))
    point_cloud.colors = o3d.utility.Vector3dVector(colors)

# Demo: load a PLY model, color it blue, rotate it, then render and save a snapshot.
file_path = 'C:/Users/26394/Desktop/论文画图/modelnet40可视化/radio.ply'
save_path = 'C:/Users/26394/Desktop/论文画图/modelnet40可视化/radio.png'

ply_data = read_ply_file(file_path)
point_cloud = convert_to_open3d_point_cloud(ply_data)
if not point_cloud:
    print("点云数据加载失败,无法继续")
else:
    # Blue points render well against the default light background.
    set_point_cloud_color(point_cloud, color=[0, 0, 1])

    # Rotate 45 degrees about each of the x, y and z axes.
    angle = np.deg2rad(45)
    rotate_point_cloud(point_cloud, angle_x=angle, angle_y=angle, angle_z=angle)

    visualize_and_save_point_cloud(point_cloud, save_path)

ShapeNet数据集可视化:

import os
import json
import warnings
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from mpl_toolkits.mplot3d import Axes3D

warnings.filterwarnings('ignore')


def pc_normalize(pc):
    """Center an (N, 3) point array at the origin and scale it into the unit sphere."""
    centered = pc - np.mean(pc, axis=0)
    # Furthest point distance from the centroid sets the scale factor.
    furthest = np.max(np.sqrt(np.sum(centered ** 2, axis=1)))
    return centered / furthest


class PartNormalDataset(Dataset):
    """ShapeNet-Part point-cloud dataset for part segmentation.

    Each item is ``(point_set, cls, seg)``:
      - ``point_set``: float32 array of ``npoints`` points, columns 0:3 are
        normalized xyz (plus columns 3:6 normals when ``normal_channel``).
      - ``cls``: int32 array of shape (1,) holding the category index.
      - ``seg``: int32 per-point part labels, sampled with ``point_set``.
    """

    def __init__(self, root, npoints=2500, split='train', class_choice=None, normal_channel=False):
        # root: dataset directory with synsetoffset2category.txt, per-synset
        #       point folders, and the train_test_split JSON lists.
        # npoints: number of points sampled (with replacement) per item.
        # split: one of 'train' / 'val' / 'test' / 'trainval'.
        # class_choice: optional collection of category names to keep.
        # normal_channel: if True, keep normals (columns 3:6) alongside xyz.
        self.npoints = npoints
        self.root = root
        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
        self.cat = {}
        self.normal_channel = normal_channel

        # Map category name -> synset folder name, e.g. 'Airplane' -> '02691156'.
        with open(self.catfile, 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.cat[ls[0]] = ls[1]
        self.cat = {k: v for k, v in self.cat.items()}
        # Class indices are assigned over ALL categories, before any filtering,
        # so labels stay stable regardless of class_choice.
        self.classes_original = dict(zip(self.cat, range(len(self.cat))))

        if class_choice is not None:
            self.cat = {k: v for k, v in self.cat.items() if k in class_choice}

        self.meta = {}
        # The split files list paths like 'shape_data/<synset>/<token>';
        # index 2 extracts the per-shape token.
        with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
            train_ids = set([str(d.split('/')[2]) for d in json.load(f)])
        with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
            val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
        with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
            test_ids = set([str(d.split('/')[2]) for d in json.load(f)])

        for item in self.cat:
            self.meta[item] = []
            dir_point = os.path.join(self.root, self.cat[item])
            fns = sorted(os.listdir(dir_point))
            # fn[0:-4] strips the '.txt' extension to recover the split token.
            if split == 'trainval':
                fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))]
            elif split == 'train':
                fns = [fn for fn in fns if fn[0:-4] in train_ids]
            elif split == 'val':
                fns = [fn for fn in fns if fn[0:-4] in val_ids]
            elif split == 'test':
                fns = [fn for fn in fns if fn[0:-4] in test_ids]
            else:
                print('Unknown split: %s. Exiting..' % (split))
                exit(-1)

            for fn in fns:
                token = (os.path.splitext(os.path.basename(fn))[0])
                self.meta[item].append(os.path.join(dir_point, token + '.txt'))

        # Flatten to a list of (category_name, file_path) pairs for indexing.
        self.datapath = []
        for item in self.cat:
            for fn in self.meta[item]:
                self.datapath.append((item, fn))

        # Restrict the label map to the categories actually kept.
        self.classes = {}
        for i in self.cat.keys():
            self.classes[i] = self.classes_original[i]

        # Fixed ShapeNet-Part label ranges: category name -> its part-label ids.
        self.seg_classes = {
            'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
            'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46],
            'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27],
            'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40],
            'Chair': [12, 13, 14, 15], 'Knife': [22, 23]
        }

        # In-memory cache of parsed files (index -> (point_set, cls, seg)),
        # bounded to avoid unbounded growth on huge datasets.
        self.cache = {}
        self.cache_size = 20000

    def __getitem__(self, index):
        # Return one sample; parse from disk on a cache miss.
        if index in self.cache:
            point_set, cls, seg = self.cache[index]
        else:
            fn = self.datapath[index]
            cat = self.datapath[index][0]
            cls = self.classes[cat]
            cls = np.array([cls]).astype(np.int32)
            data = np.loadtxt(fn[1]).astype(np.float32)
            if not self.normal_channel:
                point_set = data[:, 0:3]
            else:
                point_set = data[:, 0:6]
            # Last column of each row is the per-point part label.
            seg = data[:, -1].astype(np.int32)
            if len(self.cache) < self.cache_size:
                self.cache[index] = (point_set, cls, seg)
        # NOTE(review): this writes back into the cached array, so cached
        # entries get re-normalized on later hits; pc_normalize appears
        # idempotent (centroid 0, max norm 1 after one pass) so results look
        # unchanged — confirm before relying on the cache holding raw data.
        point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])

        # Resample to exactly npoints (with replacement, so small clouds work).
        choice = np.random.choice(len(seg), self.npoints, replace=True)
        point_set = point_set[choice, :]
        seg = seg[choice]

        return point_set, cls, seg

    def __len__(self):
        # Number of (category, file) samples discovered at construction time.
        return len(self.datapath)


def plot_point_cloud(points, title, save_path=None):
    """Scatter-plot an (N, 3) point array in 3D, colored by its z coordinate."""
    figure = plt.figure()
    axes = figure.add_subplot(111, projection='3d')
    xs, ys, zs = points[:, 0], points[:, 1], points[:, 2]
    # Color by height (z) for a simple depth cue.
    axes.scatter(xs, ys, zs, c=zs, cmap='Spectral', s=1)
    axes.set_title(title, fontsize=20)
    axes.axis('off')
    if save_path:
        plt.savefig(save_path)
    plt.show()


if __name__ == '__main__':
    root = r'C:/Users/26394/Desktop/all/课题任务/点云/data/shapenetcore_partanno_segmentation_benchmark_v0_normal'
    dataset = PartNormalDataset(root=root, split='train', npoints=2048)

    # For each category, draw shuffled samples until one of that class
    # appears, plot it, save it as '<class>.png', and move on.
    for cls_name in dataset.cat.keys():
        print(f"Visualizing class: {cls_name}")
        loader = DataLoader(dataset, batch_size=1, shuffle=True)
        for point_set, cls, seg in loader:
            if cls.item() == dataset.classes[cls_name]:
                plot_point_cloud(point_set[0].numpy(), title=cls_name, save_path=f'{cls_name}.png')
                break

你可能感兴趣的:(python,点云,数据可视化)