from models import pointnet2_cls_ssg
import os
import sys
import torch
import argparse
# Absolute directory that holds this script; it doubles as the project root.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
ROOT_DIR = BASE_DIR
# Extend the import search path with <root>/models/log.
sys.path.append(os.path.join(BASE_DIR, 'models', 'log'))
def parse_args(argv=None):
    """Parse the command-line options of the ONNX export script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with the attributes ``use_cpu``, ``model``,
        ``num_category``, ``num_point`` and ``use_normals``.
    """
    parser = argparse.ArgumentParser('Testing')
    # ``--use_cpu`` is a store_true flag whose default is already True, so by
    # itself it can never turn CPU mode off. It is kept unchanged for existing
    # command lines; ``--use_gpu`` writes False into the same destination so
    # the CUDA path can actually be selected.
    parser.add_argument('--use_cpu', action='store_true', default=True, help='use cpu mode')
    parser.add_argument('--use_gpu', dest='use_cpu', action='store_false',
                        help='use cuda instead of cpu mode')
    # The help text now names the real default value.
    parser.add_argument('--model', default='pointnet2_cls_ssg',
                        help='model name [default: pointnet2_cls_ssg]')  # pointnet2_cls_ssg/pointnet_cls
    parser.add_argument('--num_category', default=3, type=int, choices=[2, 3, 10, 40],
                        help='training on ModelNet10/40')
    parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
    parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')
    return parser.parse_args(argv)
# Read the command-line options and pull out the values used below.
args = parse_args()
point_num = args.num_point
class_num = args.num_category
normal_channel = args.use_normals

# Build the classifier, move it to CUDA unless CPU mode is requested,
# and switch it to evaluation mode.
model = pointnet2_cls_ssg.get_model(class_num, normal_channel)
if not args.use_cpu:
    model = model.cuda()
model.eval()

# Load the checkpoint; in CPU mode the stored tensors are remapped to the CPU.
if args.use_cpu:
    checkpoint = torch.load('best_model.pth', map_location=torch.device('cpu'))
else:
    checkpoint = torch.load('best_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])

# Random example input: 6 channels when normals are used, otherwise 3.
x = torch.rand(1, 6 if normal_channel else 3, point_num)
if not args.use_cpu:
    x = x.cuda()

# Trace the model with the example input and export the trace to ONNX.
traced_script_module = torch.jit.trace(model, x)
export_onnx_file = "cls.onnx"
torch.onnx.export(traced_script_module, x, export_onnx_file, opset_version=11)
# Disabled alternative that saves the traced module itself:
# traced_script_module.save("cls.pt")
# For torch.onnx.export(traced_script_module, x, export_onnx_file, opset_version=11) to run correctly, pointnet2_utils.py has to be modified. The modified code is as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np
def timeit(tag, t):
    """Print the time elapsed since ``t`` and return a fresh timestamp.

    Args:
        tag: Label printed in front of the elapsed time.
        t: Earlier timestamp, as returned by ``time()``.

    Returns:
        The value of ``time()`` taken after printing.
    """
    elapsed = time() - t
    message = "{}: {}s".format(tag, elapsed)
    print(message)
    return time()
def pc_normalize(pc):
    """Centre a point cloud on its centroid and scale it by its largest radius.

    Args:
        pc: Array of points with one point per row (reduced over axis 0 for
            the centroid and over axis 1 for the per-point norm).

    Returns:
        A new array: the centred points divided by the largest distance of
        any point from the centroid. When that distance is zero (a single
        point, or all points identical) the centred points are returned
        unscaled, avoiding the division by zero that would yield NaNs.
    """
    centroid = np.mean(pc, axis=0)
    centered = pc - centroid
    # Largest Euclidean distance of any point from the centroid.
    radius = np.max(np.sqrt(np.sum(centered ** 2, axis=1)))
    if radius == 0:
        # Degenerate cloud: nothing to scale, and 0 / 0 would produce NaNs.
        return centered
    return centered / radius
def square_distance(src, dst):
    """
    Squared Euclidean distance between every source/target point pair.

    Uses the expansion
        |s - d|^2 = sum(s^2) + sum(d^2) - 2 * s^T d
    so the whole table comes from one batched matrix product plus two
    broadcast additions.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    # Pairwise inner products s^T d, shape [B, N, M].
    cross = torch.matmul(src, dst.transpose(1, 2))
    # Squared norms, shaped to broadcast along the other point axis.
    src_sq = (src ** 2).sum(dim=-1).view(B, N, 1)
    dst_sq = (dst ** 2).sum(dim=-1).view(B, 1, M)
    return (-2 * cross + src_sq) + dst_sq
def index_points(points, idx):
    """
    Gather points by per-batch index.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (further index dimensions are allowed)
    Return:
        new_points: indexed points data, [B, S, C]
    """
    batch_size = points.shape[0]
    idx_dims = list(idx.shape)
    # [idx.shape[0], 1, ..., 1] is built as a fresh list rather than by slice
    # assignment (presumably to keep the function export-friendly; confirm).
    column_shape = [idx_dims[0]] + [1] * (len(idx_dims) - 1)
    # [1, S, ...]: tile the batch ids across every non-batch index dimension.
    tile_shape = [1] + idx_dims[1:]
    batch_ids = torch.arange(batch_size, dtype=torch.long).to(points.device)
    batch_ids = batch_ids.view(column_shape).repeat(tile_shape)
    # Advanced indexing pairs each batch id with the index at that position.
    return points[batch_ids, idx, :]
def farthest_point_sample(xyz, npoint: int):
"""
Input:
xyz: pointcloud data, [B, N, 3]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [B, npoint]
"""
device = xyz.device
B, N, C = xyz.shape
centroids = torch.zeros(B, npoint, dtype=torch.long).to(device