# ModelNet40 数据集可视化 (ModelNet40 dataset visualization)
import open3d as o3d
import numpy as np
from plyfile import PlyData
from PIL import Image
def read_ply_file(file_path):
    """Read a PLY file with plyfile, printing a status message either way.

    Args:
        file_path: Path of the ``.ply`` file to load.

    Returns:
        The ``PlyData`` object on success, or ``None`` if reading raised any
        exception (the error is printed, not propagated).
    """
    try:
        loaded = PlyData.read(file_path)
        print("PLY 文件内容成功读取")
    except Exception as e:  # best-effort loader: callers test the result for None
        print(f"读取 PLY 文件时出错: {e}")
        loaded = None
    return loaded
def convert_to_open3d_point_cloud(ply_data):
    """Build an Open3D point cloud from the ``vertex`` element of a PlyData object.

    Only the ``x``, ``y`` and ``z`` vertex properties are used; any other
    properties or elements in the file are ignored.

    Args:
        ply_data: Result of ``read_ply_file`` (a ``PlyData`` object or ``None``).

    Returns:
        An ``o3d.geometry.PointCloud``, or ``None`` when ``ply_data`` is
        ``None``/empty or has no ``vertex`` element.
    """
    if not ply_data:
        return None
    try:
        vertex_data = ply_data['vertex'].data
    except KeyError:
        # File without a 'vertex' element: report "no cloud" like the None case.
        return None
    # (N, 3) array. Cast to float64 explicitly: Vector3dVector is documented for
    # float64 (n, 3) arrays, and PLY coordinates stored as float32 would
    # otherwise go through a slower element-wise conversion.
    points = np.column_stack(
        (vertex_data['x'], vertex_data['y'], vertex_data['z'])
    ).astype(np.float64)
    point_cloud = o3d.geometry.PointCloud()
    point_cloud.points = o3d.utility.Vector3dVector(points)
    return point_cloud
def rotate_point_cloud(point_cloud, angle_x, angle_y, angle_z):
    """Rotate ``point_cloud`` in place about the origin by three angles in radians.

    The combined matrix is ``R = R_z @ R_y @ R_x``: applied to a column vector,
    the point is rotated about the x axis first, then y, then z (fixed axes).

    Args:
        point_cloud: ``o3d.geometry.PointCloud`` whose ``points`` get replaced.
        angle_x: Rotation about the x axis, radians.
        angle_y: Rotation about the y axis, radians.
        angle_z: Rotation about the z axis, radians.

    Returns:
        None; ``point_cloud`` is modified in place.
    """
    cos_x, sin_x = np.cos(angle_x), np.sin(angle_x)
    cos_y, sin_y = np.cos(angle_y), np.sin(angle_y)
    cos_z, sin_z = np.cos(angle_z), np.sin(angle_z)

    rot_x = np.array([[1, 0, 0], [0, cos_x, -sin_x], [0, sin_x, cos_x]])
    rot_y = np.array([[cos_y, 0, sin_y], [0, 1, 0], [-sin_y, 0, cos_y]])
    rot_z = np.array([[cos_z, -sin_z, 0], [sin_z, cos_z, 0], [0, 0, 1]])
    rotation = rot_z.dot(rot_y).dot(rot_x)

    # Points are rows of an (N, 3) matrix P, so p' = R p becomes P' = P R^T.
    rotated = np.asarray(point_cloud.points) @ rotation.T
    point_cloud.points = o3d.utility.Vector3dVector(rotated)
def visualize_and_save_point_cloud(point_cloud, save_path):
    """Render ``point_cloud`` in an Open3D window and save a screenshot of the view.

    The window is created, rendered once, captured and then always destroyed,
    even if rendering or capturing raises.

    Args:
        point_cloud: Open3D geometry to render. A falsy value (e.g. ``None``)
            only prints a message and returns.
        save_path: Destination image path; PIL picks the format from the extension.
    """
    if not point_cloud:
        print("点云数据为空,无法可视化")
        return

    vis = o3d.visualization.Visualizer()
    vis.create_window()
    try:
        vis.add_geometry(point_cloud)
        vis.poll_events()
        vis.update_renderer()
        # Float RGB buffer of the current viewpoint; copied so it owns its data.
        image = np.array(vis.capture_screen_float_buffer(do_render=True))
    finally:
        # Release the native window even when rendering/capture fails.
        vis.destroy_window()

    # Buffer values are expected in [0, 1]; clip and round (rather than
    # truncate) when converting to 8-bit.
    image_u8 = np.round(np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    Image.fromarray(image_u8).save(save_path)
    print(f"当前视角图片已保存到 {save_path}")
def set_point_cloud_color(point_cloud, color=(0, 0, 1)):
    """Paint every point of ``point_cloud`` with a single RGB color, in place.

    Args:
        point_cloud: ``o3d.geometry.PointCloud`` whose ``colors`` are overwritten.
        color: RGB triple (Open3D colors use components in [0, 1]); defaults to
            blue. A tuple default avoids the shared-mutable-default pitfall.

    Raises:
        ValueError: If ``color`` does not have exactly three components.
    """
    rgb = np.asarray(color, dtype=np.float64).reshape(1, 3)
    # One row per point: (N, 3) float64, the layout Vector3dVector expects.
    colors = np.tile(rgb, (len(point_cloud.points), 1))
    point_cloud.colors = o3d.utility.Vector3dVector(colors)
# Example: load a ModelNet40 PLY, paint it blue, rotate it 45 degrees about the
# x, y and z axes, then render it and save a screenshot of the view.
file_path = 'C:/Users/26394/Desktop/论文画图/modelnet40可视化/radio.ply'
save_path = 'C:/Users/26394/Desktop/论文画图/modelnet40可视化/radio.png'
ply_data = read_ply_file(file_path)
point_cloud = convert_to_open3d_point_cloud(ply_data)
if not point_cloud:
    print("点云数据加载失败,无法继续")
else:
    set_point_cloud_color(point_cloud, color=[0, 0, 1])  # blue
    # Positional angles in radians, in (x, y, z) order.
    rotate_point_cloud(point_cloud, *np.deg2rad([45, 45, 45]))
    visualize_and_save_point_cloud(point_cloud, save_path)
# ShapeNet 数据集可视化 (ShapeNet dataset visualization)
import os
import json
import warnings
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from mpl_toolkits.mplot3d import Axes3D
# Silence every warning category process-wide for this script.
warnings.simplefilter('ignore')
def pc_normalize(pc):
    """Center a point cloud at its centroid and scale it into the unit sphere.

    Args:
        pc: (N, D) array of point coordinates (this file calls it with D == 3).

    Returns:
        A new array of the same shape and dtype: ``pc`` minus its column mean,
        divided by the largest point-to-centroid distance. If all points
        coincide (distance 0) the centered array (all zeros) is returned instead
        of dividing by zero, and an empty input is returned as an empty copy.
    """
    pc = np.asarray(pc)
    if pc.shape[0] == 0:
        # np.mean/np.max are undefined on an empty set of points.
        return pc.copy()
    centered = pc - np.mean(pc, axis=0)
    radius = np.max(np.sqrt(np.sum(centered ** 2, axis=1)))
    if radius == 0:
        # Degenerate cloud: avoid 0/0 producing NaNs.
        return centered
    return centered / radius
class PartNormalDataset(Dataset):
    """ShapeNet part-segmentation dataset read from plain-text point files.

    Files read under ``root``:

    * ``synsetoffset2category.txt``: whitespace-separated
      ``<category name> <sub-folder>`` pairs, one per line;
    * ``train_test_split/shuffled_{train,val,test}_file_list.json``: JSON lists of
      ``'/'``-separated paths whose third component is a shape id;
    * ``<sub-folder>/<shape id>.txt``: numeric rows loaded with ``np.loadtxt``.
      Columns 0-2 are xyz; columns 3-5 are also returned when ``normal_channel``
      is True (presumably normals -- confirm against the data); the last column
      is the per-point part label.

    Args:
        root: Dataset root directory.
        npoints: Number of points sampled (with replacement) for every item.
        split: ``'train'``, ``'val'``, ``'test'`` or ``'trainval'``.
        class_choice: Optional collection of category names to keep; ``None``
            keeps every category.
        normal_channel: Return 6 feature columns instead of 3.

    Raises:
        ValueError: If ``split`` is not one of the names above.
    """

    # Which split files make up each supported split name.
    _SPLIT_PARTS = {
        'train': ('train',),
        'val': ('val',),
        'test': ('test',),
        'trainval': ('train', 'val'),
    }

    def __init__(self, root, npoints=2500, split='train', class_choice=None, normal_channel=False):
        if split not in self._SPLIT_PARTS:
            # Raise instead of print + exit(): a dataset constructor must not
            # terminate the whole process, and the check must not depend on
            # whether any category survived class_choice filtering.
            raise ValueError('Unknown split: %s (expected one of: %s)'
                             % (split, ', '.join(sorted(self._SPLIT_PARTS))))
        self.npoints = npoints
        self.root = root
        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
        self.normal_channel = normal_channel

        # Category name -> sub-folder name, in file order.
        self.cat = {}
        with open(self.catfile, 'r') as f:
            for line in f:
                ls = line.strip().split()
                if not ls:  # tolerate blank (e.g. trailing) lines
                    continue
                self.cat[ls[0]] = ls[1]
        # Labels are assigned over ALL categories before class_choice filtering,
        # so a category's label does not depend on which classes were selected.
        self.classes_original = dict(zip(self.cat, range(len(self.cat))))
        if class_choice is not None:
            self.cat = {k: v for k, v in self.cat.items() if k in class_choice}

        wanted_ids = set()
        for part in self._SPLIT_PARTS[split]:
            wanted_ids |= self._read_split_ids(os.path.join(
                self.root, 'train_test_split', 'shuffled_%s_file_list.json' % part))

        # Per category: sorted .txt paths whose shape id belongs to the split.
        # fn[0:-4] drops a 4-character extension such as '.txt'.
        self.meta = {}
        for item, folder in self.cat.items():
            dir_point = os.path.join(self.root, folder)
            self.meta[item] = [
                os.path.join(dir_point, os.path.splitext(os.path.basename(fn))[0] + '.txt')
                for fn in sorted(os.listdir(dir_point))
                if fn[0:-4] in wanted_ids
            ]

        # Flat (category name, file path) index used by __getitem__.
        self.datapath = [(item, fn) for item in self.cat for fn in self.meta[item]]
        # Label of each kept category.
        self.classes = {k: self.classes_original[k] for k in self.cat}
        self.seg_classes = {
            'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
            'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46],
            'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27],
            'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40],
            'Chair': [12, 13, 14, 15], 'Knife': [22, 23]
        }
        self.cache = {}  # index -> (normalized point_set, cls, seg)
        self.cache_size = 20000

    @staticmethod
    def _read_split_ids(path):
        """Return the set of shape ids (third '/'-separated component) listed in a split JSON file."""
        with open(path, 'r') as f:
            return {str(d.split('/')[2]) for d in json.load(f)}

    def __getitem__(self, index):
        """Return ``(point_set, cls, seg)`` for one shape.

        ``point_set`` is float32 of shape (npoints, 3) or (npoints, 6) with the
        xyz columns normalized by ``pc_normalize``; ``cls`` is an int32 array of
        shape (1,); ``seg`` is int32 of shape (npoints,). Points are resampled
        with replacement on every call.
        """
        if index in self.cache:
            point_set, cls, seg = self.cache[index]
        else:
            cat, path = self.datapath[index]
            cls = np.array([self.classes[cat]], dtype=np.int32)
            # ndmin=2 keeps a single-row file 2-D so the column slicing works.
            data = np.loadtxt(path, ndmin=2).astype(np.float32)
            point_set = data[:, 0:6] if self.normal_channel else data[:, 0:3]
            # Normalize BEFORE caching so cached entries are stored normalized
            # once, instead of being re-normalized in place on every cache hit.
            point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
            seg = data[:, -1].astype(np.int32)
            if len(self.cache) < self.cache_size:
                self.cache[index] = (point_set, cls, seg)
        choice = np.random.choice(len(seg), self.npoints, replace=True)
        return point_set[choice, :], cls, seg[choice]

    def __len__(self):
        """Number of shapes in the selected split and categories."""
        return len(self.datapath)
def plot_point_cloud(points, title, save_path=None):
    """Scatter-plot a point cloud in 3-D, optionally save it, then show it.

    Points are colored by their z coordinate using the ``Spectral`` colormap,
    and the axes are hidden.

    Args:
        points: (N, >=3) array; columns 0-2 are plotted as x, y, z.
        title: Figure title.
        save_path: If truthy, the figure is written to this path before showing.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], c=points[:, 2], cmap='Spectral', s=1)
    ax.set_title(title, fontsize=20)
    ax.axis('off')
    if save_path:
        fig.savefig(save_path)
    plt.show()
    # Close explicitly: callers plot one figure per class in a loop, and with a
    # non-interactive backend plt.show() leaves figures open, so they pile up.
    plt.close(fig)
if __name__ == '__main__':
    root = r'C:/Users/26394/Desktop/all/课题任务/点云/data/shapenetcore_partanno_segmentation_benchmark_v0_normal'
    dataset = PartNormalDataset(root=root, split='train', npoints=2048)
    for cls_name in dataset.cat:
        print(f"Visualizing class: {cls_name}")
        # Pick a random sample of this class straight from the (category, path)
        # index, instead of scanning -- and loading every file of -- a shuffled
        # DataLoader until a sample of the class turns up.
        candidates = [i for i, (cat, _) in enumerate(dataset.datapath) if cat == cls_name]
        if not candidates:
            print(f"No samples found for class: {cls_name}")
            continue
        point_set, _, _ = dataset[int(np.random.choice(candidates))]
        plot_point_cloud(point_set, title=cls_name, save_path=f'{cls_name}.png')