PointNet代码学习(pytorch版本)

源码地址

pointnet.pytorch(https://github.com/fxia22/pointnet.pytorch)
感谢大神!

代码结构

(pytorch) s@s:~/pointnet.pytorch$ tree -d
.
├── misc
├── pointnet
│   └── __pycache__
├── scripts
├── shapenetcore_partanno_segmentation_benchmark_v0
│   ├── 02691156
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 02773838
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 02954340
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 02958343
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 03001627
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 03261776
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 03467517
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 03624134
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 03636649
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 03642806
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 03790512
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 03797390
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 03948459
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 04099429
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 04225987
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 04379243
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   └── train_test_split
└── utils
    ├── cls
    ├── __pycache__
    └── seg

74 directories

utils

├── cls
│   ├── cls_model_0.pth
│   └── cls_model_1.pth
├── point_test.pts
├── __pycache__
│   ├── show3d_balls.cpython-36.pyc
│   ├── show3d_balls.cpython-37.pyc
│   └── show_seg.cpython-36.pyc
├── render_balls_so.cpp
├── render_balls_so.so
├── seg
│   ├── seg_model_Chair_0.pth
│   ├── seg_model_Chair_1.pth
│   ├── seg_model_Chair_2.pth
│   ├── seg_model_Chair_3.pth
│   └── seg_model_Chair_4.pth
├── show3d_balls.py
├── show_cls.py
├── show_points.py
├── show_seg.py
├── train_classification.py
└── train_segmentation.py

3 directories, 19 files

简介

  • cls和seg文件夹下存放的是训练得到的模型权重(.pth文件);
  • train_classification.py和train_segmentation.py分别是分类网络和分割网络的训练脚本;
  • show_seg.py导入分割模型,对测试集中的点云进行分割并可视化;show_cls.py导入分类模型,在测试集上输出分类损失和准确率;
  • show3d_balls.py是可视化脚本,提供点云可视化函数showpoints;
  • show_points.py是自己写的测试脚本(可以忽略);
  • render_balls_so.cpp和render_balls_so.so是可视化渲染用的C++源码及编译好的动态库;
  • point_test.pts是自己测试用的点云文件(可以忽略)。

代码详细注释

  • show_seg.py
'''
对原始点云进行分割,并可视化
例:python show_seg.py --model seg/seg_model_Chair_1.pth 
						--class_choice Airplane --idx 2
'''

from __future__ import print_function
from show3d_balls import showpoints
import argparse
import numpy as np
import torch
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable

#先把系统文件夹调到pointnet.pytorch下,防止找不到pointnet这个文件夹,
import sys
sys.path.append('/home/s/pointnet.pytorch')

from pointnet.dataset import ShapeNetDataset
from pointnet.model import PointNetDenseCls
import matplotlib.pyplot as plt




# 命令行解析
parser = argparse.ArgumentParser()

parser.add_argument('--model', type=str, default='', help='model path')
parser.add_argument('--idx', type=int, default=0, help='model index')
parser.add_argument('--dataset', type=str, default='', help='dataset path')
parser.add_argument('--class_choice', type=str, default='', help='class choice')


# 输出一行状态栏参数如下:
# 				Namespace(class_choice='Airplane', dataset='', 
# 							idx=2, model='seg/seg_model_Chair_1.pth')
opt = parser.parse_args()		
# print("opt:{}".format(opt))

# 数据预处理,得到某一类模型的集合
d = ShapeNetDataset(
#    root=opt.dataset,
	root='/home/s/pointnet.pytorch/shapenetcore_partanno_segmentation_benchmark_v0',
    class_choice=[opt.class_choice],  #选择哪一类模型
    split='test',
    data_augmentation=0)

idx = opt.idx
# print("d:{}".format(d))
print("model %d/%d" % (idx, len(d)))  #d代表全部的飞机数量
# model 2/341

# print('dir(d):{}'.format(dir(d)))
# print('d[idx]:{}'.format(d[idx]))


point, seg = d[idx]  #模型里的第idx的点云
print(point.size(), seg.size())     #seg代表每一个点的标签
# torch.Size([2500, 3]) torch.Size([2500])

point_np = point.numpy() #将torch转为numpy

# 可视化
cmap = plt.cm.get_cmap("hsv", 10)
cmap = np.array([cmap(i) for i in range(10)])[:, :3]
gt = cmap[seg.numpy() - 1, :]

# 载入模型
state_dict = torch.load(opt.model)  
classifier = PointNetDenseCls(k= state_dict['conv4.weight'].size()[0])
classifier.load_state_dict(state_dict)
classifier.eval()  #评估

# 点云转置
point = point.transpose(1, 0).contiguous()
print('point.transpose(1, 0).shape: ',point.shape)

point = Variable(point.view(1, point.size()[0], point.size()[1]))
print('--------------------')
print(point.dtype)
pred, _, _ = classifier(point)   #分割
# print('\npred.shape:',pred[0],'\n')
pred_choice = pred.data.max(2)[1]
print(pred_choice.numpy())   #输出每一个点的预测类别

# print(pred_choice.size())
print(pred_choice.numpy()[0])  #[1 1 1 ... 1 1 1]
pred_color = cmap[pred_choice.numpy()[0], :]   #根据分类结果显示颜色
print('\npred_color: ',pred_color.shape,'\n')


showpoints(point_np, gt, pred_color)  #pred_colord的为(2500, 3)的矩阵
print(point_np.shape)
  • show_cls.py
from __future__ import print_function  #使用python3的print函数
import argparse
import torch
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable

import sys
sys.path.append('/home/s/pointnet.pytorch')

from pointnet.dataset import ShapeNetDataset  #ShapeNetDataset是一个类 下面会实例化这个类
from pointnet.model import PointNetCls   #读入模型
import torch.nn.functional as F


#showpoints(np.random.randn(2500,3), c1 = np.random.uniform(0,1,size = (2500)))

parser = argparse.ArgumentParser()

parser.add_argument('--model', type=str, default = '',  help='model path')
parser.add_argument('--num_points', type=int, default=2500, help='input batch size')


opt = parser.parse_args()
print(opt)


# 实例化ShapeNetDataset类
test_dataset = ShapeNetDataset(
#    root='shapenetcore_partanno_segmentation_benchmark_v0',
	root='/home/s/pointnet.pytorch/shapenetcore_partanno_segmentation_benchmark_v0',
    split='test',
    classification=True,
    npoints=opt.num_points,
    data_augmentation=False)

# 读入测试数据
testdataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=32, shuffle=True)


# 导入模型
classifier = PointNetCls(k=len(test_dataset.classes))
classifier.cuda()
classifier.load_state_dict(torch.load(opt.model))
classifier.eval()


for i, data in enumerate(testdataloader, 0):
    points, target = data
    points, target = Variable(points), Variable(target[:, 0])
    points = points.transpose(2, 1)
    points, target = points.cuda(), target.cuda()
    pred, _, _ = classifier(points)  #进行分类
    loss = F.nll_loss(pred, target)  #计算损失函数

# 计算准确率
    pred_choice = pred.data.max(1)[1]
    correct = pred_choice.eq(target.data).cpu().sum()
    print('i:%d  loss: %f accuracy: %f' % (i, loss.data.item(), correct / float(32))) 

  • show_points.py
'''
自己写的,用来测试
可视化文件夹下的点云数据
输入:n*3的矩阵
'''

from __future__ import print_function
from show3d_balls import showpoints
import argparse
import numpy as np
import torch
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable
from pointnet.model import PointNetDenseCls

# with open('./point_test.pts') as file:
#     for line in file:
#         print(len(line))

import matplotlib.pylab as plt
import sys

# from utils.show_seg import seg

sys.path.append('/home/s/pointnet.pytorch')

# points=np.loadtxt('./point_test.pts')
points=np.loadtxt('./point_test.pts',dtype=np.float32)  #预测只能输入float32的格式的数据

print(points.shape)

# 可视化
cmap = plt.cm.get_cmap("hsv", 10)
cmap = np.array([cmap(i) for i in range(10)])[:, :3]
# gt = cmap[seg.numpy() - 1, :]

# 可视化点云
showpoints(points)

#采样到2500个点
choice = np.random.choice(len(points), 2500, replace=True)
# print('choice:{}'.format(choice))
points = points[choice, :]
print('points[choice, :]:{}'.format(points))
point_np=points



# 载入模型
state_dict = torch.load('./seg/seg_model_Chair_1.pth')
classifier = PointNetDenseCls(k= state_dict['conv4.weight'].size()[0])
classifier.load_state_dict(state_dict)
classifier.eval()  #设置为评估状态

# 点云转置

points=torch.from_numpy(points)
print(points.shape)
point = points.transpose(1, 0).contiguous()
print('point.transpose(1, 0).shape: ',point.shape)

point = Variable(point.view(1, point.size()[0], point.size()[1]))  #转为torch变量1,3,2500
print('--------------------')

print(point.dtype)
# point=torch.tensor(point,dtype=torch.float32)
pred, _, _ = classifier(point)   #分割
print(pred)

pred_choice = pred.data.max(2)[1]
print(pred_choice.numpy())   #输出每一个点的预测类别

# print(pred_choice.size())
print(pred_choice.numpy()[0])  #[1 1 1 ... 1 1 1]
pred_color = cmap[pred_choice.numpy()[0], :]   #根据分类结果显示颜色
print('\npred_color: ',pred_color.shape,'\n')
print(pred_color.dtype)

# point_np=point.numpy().reshape(2500,3)
print(point_np.shape)
print(point_np.dtype)
showpoints(point_np, pred_color,pred_color)  #pred_colord的为(2500, 3)的矩阵
  • show3d_balls.py
'''
点云数据可视化
'''

import numpy as np
import ctypes as ct
import cv2
import sys
# Side length, in pixels, of the square 'show3d' window.
showsz = 800
# Last mouse position, normalised by showsz; it drives the view rotation.
mousex = 0.5
mousey = 0.5
# Zoom factor applied to the rotation matrix ('n' / 'm' / 'r' keys).
zoom = 1.0
# Set whenever the view needs to be re-rendered.
changed = True

def onmouse(*args):
    """OpenCV mouse callback: store the pointer position and request a redraw.

    OpenCV invokes the callback as (event, x, y, flags, param). args[1] is
    stored, divided by showsz, in ``mousey``, and args[2] likewise in
    ``mousex``. The names are swapped on purpose: ``mousey`` feeds the
    x-axis rotation and ``mousex`` the y-axis rotation.
    """
    global mousex, mousey, changed
    scale = float(showsz)
    mousey, mousex = args[1] / scale, args[2] / scale
    changed = True

# Create the viewer window at import time and route its mouse events to onmouse.
cv2.namedWindow('show3d')
cv2.moveWindow('show3d', 0, 0)
cv2.setMouseCallback('show3d', onmouse)

# Load the compiled ball renderer from this module's own directory, where
# render_balls_so lives, rather than from the current working directory.
# The CWD lookup broke when scripts were launched from elsewhere.
import os
dll = np.ctypeslib.load_library(
    'render_balls_so', os.path.dirname(os.path.abspath(__file__)))

# 该函数输入为n*3的矩阵
def showpoints(xyz,c_gt=None, c_pred = None, waittime=0, 
    showrot=False, magnifyBlue=0, freezerot=False, background=(0,0,0), 
    normalizecolor=True, ballradius=10):
    global showsz, mousex, mousey, zoom, changed
    xyz=xyz-xyz.mean(axis=0)
    radius=((xyz**2).sum(axis=-1)**0.5).max()
    xyz/=(radius*2.2)/showsz
    if c_gt is None:
        c0 = np.zeros((len(xyz), ), dtype='float32') + 255
        c1 = np.zeros((len(xyz), ), dtype='float32') + 255
        c2 = np.zeros((len(xyz), ), dtype='float32') + 255
    else:
        c0 = c_gt[:, 0]
        c1 = c_gt[:, 1]
        c2 = c_gt[:, 2]


    if normalizecolor:
        c0 /= (c0.max() + 1e-14) / 255.0
        c1 /= (c1.max() + 1e-14) / 255.0
        c2 /= (c2.max() + 1e-14) / 255.0


    c0 = np.require(c0, 'float32', 'C')
    c1 = np.require(c1, 'float32', 'C')
    c2 = np.require(c2, 'float32', 'C')

    show = np.zeros((showsz, showsz, 3), dtype='uint8')
    def render():
        rotmat=np.eye(3)
        if not freezerot:
            xangle=(mousey-0.5)*np.pi*1.2
        else:
            xangle=0
        rotmat = rotmat.dot(
            np.array([
                [1.0, 0.0, 0.0],
                [0.0, np.cos(xangle), -np.sin(xangle)],
                [0.0, np.sin(xangle), np.cos(xangle)],
            ]))
        if not freezerot:
            yangle = (mousex - 0.5) * np.pi * 1.2
        else:
            yangle = 0
        rotmat = rotmat.dot(
            np.array([
                [np.cos(yangle), 0.0, -np.sin(yangle)],
                [0.0, 1.0, 0.0],
                [np.sin(yangle), 0.0, np.cos(yangle)],
            ]))
        rotmat *= zoom
        nxyz = xyz.dot(rotmat) + [showsz / 2, showsz / 2, 0]

        ixyz = nxyz.astype('int32')
        show[:] = background
        dll.render_ball(
            ct.c_int(show.shape[0]), ct.c_int(show.shape[1]),
            show.ctypes.data_as(ct.c_void_p), ct.c_int(ixyz.shape[0]),
            ixyz.ctypes.data_as(ct.c_void_p), c0.ctypes.data_as(ct.c_void_p),
            c1.ctypes.data_as(ct.c_void_p), c2.ctypes.data_as(ct.c_void_p),
            ct.c_int(ballradius))

        if magnifyBlue > 0:
            show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(
                show[:, :, 0], 1, axis=0))
            if magnifyBlue >= 2:
                show[:, :, 0] = np.maximum(show[:, :, 0],
                                           np.roll(show[:, :, 0], -1, axis=0))
            show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(
                show[:, :, 0], 1, axis=1))
            if magnifyBlue >= 2:
                show[:, :, 0] = np.maximum(show[:, :, 0],
                                           np.roll(show[:, :, 0], -1, axis=1))
        if showrot:
            cv2.putText(show, 'xangle %d' % (int(xangle / np.pi * 180)),
                        (30, showsz - 30), 0, 0.5, cv2.cv.CV_RGB(255, 0, 0))
            cv2.putText(show, 'yangle %d' % (int(yangle / np.pi * 180)),
                        (30, showsz - 50), 0, 0.5, cv2.cv.CV_RGB(255, 0, 0))
            cv2.putText(show, 'zoom %d%%' % (int(zoom * 100)), (30, showsz - 70), 0,
                        0.5, cv2.cv.CV_RGB(255, 0, 0))
    changed = True
    while True:
        if changed:
            render()
            changed = False
        cv2.imshow('show3d', show)
        if waittime == 0:
            cmd = cv2.waitKey(10) % 256
        else:
            cmd = cv2.waitKey(waittime) % 256
        if cmd == ord('q'):
            break
        elif cmd == ord('Q'):
            sys.exit(0)

        if cmd == ord('t') or cmd == ord('p'):
            if cmd == ord('t'):
                if c_gt is None:
                    c0 = np.zeros((len(xyz), ), dtype='float32') + 255
                    c1 = np.zeros((len(xyz), ), dtype='float32') + 255
                    c2 = np.zeros((len(xyz), ), dtype='float32') + 255
                else:
                    c0 = c_gt[:, 0]
                    c1 = c_gt[:, 1]
                    c2 = c_gt[:, 2]
            else:
                if c_pred is None:
                    c0 = np.zeros((len(xyz), ), dtype='float32') + 255
                    c1 = np.zeros((len(xyz), ), dtype='float32') + 255
                    c2 = np.zeros((len(xyz), ), dtype='float32') + 255
                else:
                    c0 = c_pred[:, 0]
                    c1 = c_pred[:, 1]
                    c2 = c_pred[:, 2]
            if normalizecolor:
                c0 /= (c0.max() + 1e-14) / 255.0
                c1 /= (c1.max() + 1e-14) / 255.0
                c2 /= (c2.max() + 1e-14) / 255.0
            c0 = np.require(c0, 'float32', 'C')
            c1 = np.require(c1, 'float32', 'C')
            c2 = np.require(c2, 'float32', 'C')
            changed = True

        if cmd==ord('n'):
            zoom*=1.1
            changed=True
        elif cmd==ord('m'):
            zoom/=1.1
            changed=True
        elif cmd==ord('r'):
            zoom=1.0
            changed=True
        elif cmd==ord('s'):
            cv2.imwrite('show3d.png',show)
        if waittime!=0:
            break
    return cmd

if __name__ == '__main__':
    # Demo: view a reproducible cloud of 2500 standard-normal 3-D points.
    np.random.seed(100)
    demo_xyz = np.random.randn(2500, 3)
    showpoints(demo_xyz)

  • train_classification.py(本文未贴出该脚本的代码)
  • train_segmentation.py
'''
Train the PointNet part-segmentation network on one ShapeNet category, save
a checkpoint after every epoch, and report the test-set mIoU at the end.

Example:
    python train_segmentation.py --dataset <shapenet root> --class_choice Chair
'''
import argparse
import math
import os
import random
import torch
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data

import sys
sys.path.append('/home/s/pointnet.pytorch')

from pointnet.dataset import ShapeNetDataset
from pointnet.model import PointNetDenseCls, feature_transform_regularizer
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np


parser = argparse.ArgumentParser()
parser.add_argument(
    '--batchSize', type=int, default=32, help='input batch size')  # shapes per batch
parser.add_argument(
    '--workers', type=int, help='number of data loading workers', default=4)
parser.add_argument(
    '--nepoch', type=int, default=25, help='number of epochs to train for')  # full passes over the training set
parser.add_argument('--outf', type=str, default='seg', help='output folder')
parser.add_argument('--model', type=str, default='', help='model path')
parser.add_argument('--dataset', type=str, required=True, help="dataset path")
parser.add_argument('--class_choice', type=str, default='Chair', help="class_choice")
parser.add_argument('--feature_transform', action='store_true', help="use feature transform")
parser.add_argument('--manualSeed', type=int, default=None, help='random seed (drawn at random if omitted)')

opt = parser.parse_args()

# Seed the Python, NumPy and torch RNGs. Pass --manualSeed to repeat a run.
# Without it, a fresh seed is drawn and printed so the run can be repeated.
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

dataset = ShapeNetDataset(
    root=opt.dataset,
    classification=False,
    class_choice=[opt.class_choice])
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=opt.batchSize,
    shuffle=True,
    num_workers=int(opt.workers))

test_dataset = ShapeNetDataset(
    root=opt.dataset,
    classification=False,
    class_choice=[opt.class_choice],
    split='test',
    data_augmentation=False)
testdataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=opt.batchSize,
    shuffle=True,
    num_workers=int(opt.workers))

# Number of part labels of the chosen category (e.g. 4 for Chair), as looked
# up by ShapeNetDataset.
num_classes = dataset.num_seg_classes
os.makedirs(opt.outf, exist_ok=True)


def blue(text):
    """Wrap text in ANSI escape codes so it prints in blue."""
    return '\033[94m' + text + '\033[0m'


classifier = PointNetDenseCls(k=num_classes, feature_transform=opt.feature_transform)

if opt.model != '':  # resume from a checkpoint
    classifier.load_state_dict(torch.load(opt.model))

optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
# scheduler.step() is called once per epoch: halve the learning rate every 5 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
classifier.cuda()

# Batches per epoch; the last one may be partial (drop_last is not set).
num_batch = math.ceil(len(dataset) / opt.batchSize)

for epoch in range(opt.nepoch):
    for i, (points, target) in enumerate(dataloader):
        # (batch, npoints, 3) -> (batch, 3, npoints), the layout the network expects.
        points = points.transpose(2, 1)
        points, target = points.cuda(), target.cuda()
        optimizer.zero_grad()
        classifier = classifier.train()
        pred, trans, trans_feat = classifier(points)
        pred = pred.view(-1, num_classes)  # (batch * npoints, num_classes)
        target = target.view(-1) - 1  # part labels in the data are 1-based

        loss = F.nll_loss(pred, target)
        if opt.feature_transform:
            loss += feature_transform_regularizer(trans_feat) * 0.001
        loss.backward()
        optimizer.step()
        pred_choice = pred.data.max(1)[1]
        correct = pred_choice.eq(target.data).cpu().sum()
        # Normalise by the points actually in this batch; the last batch can be smaller.
        print('[%d/%d: %d/%d] train loss: %f accuracy: %f' % (
            epoch, opt.nepoch, i, num_batch, loss.item(),
            correct.item() / float(target.numel())))

        if i % 10 == 0:
            # Every 10 batches, check one shuffled test batch.
            points, target = next(iter(testdataloader))
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            classifier = classifier.eval()
            with torch.no_grad():
                pred, _, _ = classifier(points)
                pred = pred.view(-1, num_classes)
                target = target.view(-1) - 1
                loss = F.nll_loss(pred, target)
                pred_choice = pred.max(1)[1]
                correct = pred_choice.eq(target).cpu().sum()
            print('[%d/%d: %d/%d] %s loss: %f accuracy: %f' % (
                epoch, opt.nepoch, i, num_batch, blue('test'), loss.item(),
                correct.item() / float(target.numel())))

    scheduler.step()

    # One checkpoint per epoch.
    torch.save(classifier.state_dict(), '%s/seg_model_%s_%d.pth' % (opt.outf, opt.class_choice, epoch))


# Benchmark: mean IoU over the test shapes. Each shape scores the mean IoU
# over all part labels; a part absent from both prediction and ground truth
# counts as IoU 1.
shape_ious = []
classifier = classifier.eval()
with torch.no_grad():
    for points, target in tqdm(testdataloader):
        points = points.transpose(2, 1)
        points, target = points.cuda(), target.cuda()
        pred, _, _ = classifier(points)  # (batch, npoints, num_classes)
        pred_np = pred.max(2)[1].cpu().numpy()
        target_np = target.cpu().numpy() - 1

        for shape_idx in range(target_np.shape[0]):
            part_ious = []
            for part in range(num_classes):
                in_pred = pred_np[shape_idx] == part
                in_gt = target_np[shape_idx] == part
                I = np.sum(np.logical_and(in_pred, in_gt))
                U = np.sum(np.logical_or(in_pred, in_gt))
                part_ious.append(1.0 if U == 0 else I / float(U))
            shape_ious.append(np.mean(part_ious))

print("mIOU for class {}: {}".format(opt.class_choice, np.mean(shape_ious)))

pointnet

├── dataset.py
├── __init__.py
├── model.py
└── __pycache__
    ├── dataset.cpython-36.pyc
    ├── __init__.cpython-36.pyc
    └── model.cpython-36.pyc

1 directory, 6 files

简介

  • dataset.py数据集导入相关
  • model.py模型构建相关

代码详细注释

  • dataset.py
# from __future__ import print_function
import torch.utils.data as data
import os
import os.path
import torch
import numpy as np
import sys
from tqdm import tqdm 
import json
from plyfile import PlyData, PlyElement

def get_segmentation_classes(root):
    """Count the part labels of every ShapeNet category under ``root``.

    Categories come from ``root/synsetoffset2category.txt`` (``<name> <offset>``
    per line). For each category, the count is the largest number of
    distinct labels found in any single ``.seg`` file of its
    ``points_label`` directory. The ``.seg`` file names are derived from the
    files in ``points``. Each count is printed and written as a tab-separated
    ``<name> <count>`` line to ``../misc/num_seg_classes.txt`` relative to
    this module.
    """
    category_dirs = {}
    with open(os.path.join(root, 'synsetoffset2category.txt'), 'r') as catfile:
        for row in catfile:
            fields = row.strip().split()
            category_dirs[fields[0]] = fields[1]

    # Label file for every point file, per category.
    label_files = {}
    for name, offset in category_dirs.items():
        point_dir = os.path.join(root, offset, 'points')
        label_dir = os.path.join(root, offset, 'points_label')
        label_files[name] = [
            os.path.join(label_dir, os.path.splitext(os.path.basename(entry))[0] + '.seg')
            for entry in sorted(os.listdir(point_dir))
        ]

    out_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt')
    with open(out_path, 'w') as out:
        for name in category_dirs:
            num_seg_classes = 0
            for seg_path in tqdm(label_files[name]):
                labels_here = len(np.unique(np.loadtxt(seg_path).astype(np.uint8)))
                num_seg_classes = max(num_seg_classes, labels_here)
            print("category {} num segmentation classes {}".format(name, num_seg_classes))
            out.write("{}\t{}\n".format(name, num_seg_classes))

def gen_modelnet_id(root):
    classes = []
    with open(os.path.join(root, 'train.txt'), 'r') as f:
        for line in f:
            classes.append(line.strip().split('/')[0])
    classes = np.unique(classes)
    with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/modelnet_id.txt'), 'w') as f:
        for i in range(len(classes)):
            f.write('{}\t{}\n'.format(classes[i], i))


class ShapeNetDataset(data.Dataset):
    """ShapeNet part-annotation dataset (shapenetcore_partanno_segmentation_benchmark_v0).

    Each item is one shape's point cloud. The cloud is resampled with
    replacement to ``npoints`` points, centred on its mean and scaled so its
    farthest point is at distance 1. With ``classification=True`` an item is
    ``(points, cls)``, where ``cls`` is a 1-element int64 tensor holding the
    category index. Otherwise it is ``(points, seg)``, with the per-point
    part labels exactly as stored in the ``.seg`` files.

    Args:
        root: dataset root containing ``synsetoffset2category.txt``,
            ``train_test_split/`` and one directory per category offset.
        npoints: number of points sampled per shape.
        classification: return category labels instead of part labels.
        class_choice: optional collection of category names to keep.
            None keeps all categories.
        split: selects ``train_test_split/shuffled_<split>_file_list.json``
            (e.g. 'train', 'test').
        data_augmentation: apply a random rotation in the x-z plane plus
            Gaussian jitter (sigma 0.02).

    Raises:
        ValueError: if no category is selected.
    """

    def __init__(self,
                 root,
                 npoints=2500,
                 classification=False,
                 class_choice=None,
                 split='train',
                 data_augmentation=True):
        self.npoints = npoints
        self.root = root
        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
        self.cat = {}  # category name -> offset directory, e.g. 'Airplane' -> '02691156'
        self.data_augmentation = data_augmentation
        self.classification = classification
        self.seg_classes = {}  # category name -> number of part labels

        with open(self.catfile, 'r') as f:
            for line in f:
                ls = line.strip().split()  # e.g. ['Airplane', '02691156']
                self.cat[ls[0]] = ls[1]

        if class_choice is not None:
            self.cat = {k: v for k, v in self.cat.items() if k in class_choice}
        if not self.cat:
            # Fail here with a clear message rather than an IndexError further down.
            raise ValueError('no ShapeNet category selected (class_choice=%r, catfile=%s)'
                             % (class_choice, self.catfile))

        self.id2cat = {v: k for k, v in self.cat.items()}

        # Category name -> list of (points file, labels file) in this split.
        self.meta = {item: [] for item in self.cat}
        splitfile = os.path.join(self.root, 'train_test_split', 'shuffled_{}_file_list.json'.format(split))
        with open(splitfile, 'r') as f:
            filelist = json.load(f)
        for file in filelist:
            # Entries look like 'shape_data/<offset>/<uuid>'.
            _, category, uuid = file.split('/')
            if category in self.id2cat:
                self.meta[self.id2cat[category]].append((
                    os.path.join(self.root, category, 'points', uuid + '.pts'),
                    os.path.join(self.root, category, 'points_label', uuid + '.seg')))

        # Flat list of (category name, points path, labels path).
        self.datapath = [(item, fn[0], fn[1]) for item in self.cat for fn in self.meta[item]]

        # Category name -> contiguous class index, in sorted name order.
        self.classes = dict(zip(sorted(self.cat), range(len(self.cat))))

        with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.seg_classes[ls[0]] = int(ls[1])
        # Part-label count of the first selected category. This is only
        # meaningful when a single category is chosen.
        self.num_seg_classes = self.seg_classes[list(self.cat.keys())[0]]
        print(self.seg_classes, list(self.cat.keys())[0], self.num_seg_classes)

    def __getitem__(self, index):
        """Load, resample, normalise and (optionally) augment shape ``index``."""
        category, pts_path, seg_path = self.datapath[index]
        cls = self.classes[category]

        point_set = np.loadtxt(pts_path).astype(np.float32)
        seg = np.loadtxt(seg_path).astype(np.int64)

        # Resample to self.npoints points. Sampling with replacement also
        # handles clouds with fewer points.
        choice = np.random.choice(len(seg), self.npoints, replace=True)
        point_set = point_set[choice, :]

        # Centre on the mean and scale so the farthest point is at distance 1.
        point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)
        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
        point_set = point_set / dist

        if self.data_augmentation:
            theta = np.random.uniform(0, np.pi * 2)
            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
            point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)  # random rotation in the x-z plane
            point_set += np.random.normal(0, 0.02, size=point_set.shape)  # random jitter

        seg = seg[choice]
        point_set = torch.from_numpy(point_set)
        seg = torch.from_numpy(seg)
        cls = torch.from_numpy(np.array([cls]).astype(np.int64))

        if self.classification:
            return point_set, cls
        return point_set, seg

    def __len__(self):
        return len(self.datapath)

class ModelNetDataset(data.Dataset):
    """ModelNet point-cloud classification dataset backed by .ply files.

    ``root/<split>.txt`` lists one relative file path per line; the first
    path component is the category name. Category names are mapped to integer
    labels from ``../misc/modelnet_id.txt`` (resolved relative to this source
    file), one ``<name> <id>`` pair per line.

    Each item is ``(point_set, cls)``: ``npoints`` vertices resampled with
    replacement, centred on their mean and scaled into the unit sphere
    (float32 tensor of shape (npoints, 3)), plus a 1-element int64 label tensor.
    """

    def __init__(self,
                 root,
                 npoints=2500,
                 split='train',
                 data_augmentation=True):
        self.npoints = npoints
        self.root = root
        self.split = split
        self.data_augmentation = data_augmentation
        # Skip blank lines (e.g. a trailing empty line): an empty entry would
        # otherwise become an item whose category lookup fails with KeyError.
        with open(os.path.join(root, '{}.txt'.format(self.split)), 'r') as f:
            self.fns = [line.strip() for line in f if line.strip()]

        self.cat = {}
        with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/modelnet_id.txt'), 'r') as f:
            for line in f:
                ls = line.strip().split()
                if not ls:  # blank line; indexing ls[0] would raise IndexError
                    continue
                self.cat[ls[0]] = int(ls[1])

        print(self.cat)
        self.classes = list(self.cat.keys())

    @staticmethod
    def _normalize(point_set):
        """Centre ``point_set`` on its mean and scale it into the unit sphere.

        If every point coincides with the centroid (max distance 0), the
        centred cloud is returned unscaled instead of dividing by zero (NaNs).
        """
        point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)  # center
        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
        if dist > 0:
            point_set = point_set / dist  # scale
        return point_set

    def __getitem__(self, index):
        fn = self.fns[index]
        cls = self.cat[fn.split('/')[0]]  # category name is the first path component
        with open(os.path.join(self.root, fn), 'rb') as f:
            plydata = PlyData.read(f)
        pts = np.vstack([plydata['vertex']['x'], plydata['vertex']['y'], plydata['vertex']['z']]).T
        # Resample with replacement so every item has exactly npoints points.
        choice = np.random.choice(len(pts), self.npoints, replace=True)
        point_set = self._normalize(pts[choice, :])

        if self.data_augmentation:
            # Random rotation in the x-z plane (columns 0 and 2), then Gaussian jitter.
            theta = np.random.uniform(0, np.pi * 2)
            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
            point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)  # random rotation
            point_set += np.random.normal(0, 0.02, size=point_set.shape)  # random jitter

        point_set = torch.from_numpy(point_set.astype(np.float32))
        cls = torch.from_numpy(np.array([cls]).astype(np.int64))
        return point_set, cls

    def __len__(self):
        """Return the number of listed files for this split."""
        return len(self.fns)

if __name__ == '__main__':
    # Smoke test: <script> {shapenet|modelnet} <dataset_root>
    if len(sys.argv) < 3:
        sys.exit('usage: {} {{shapenet|modelnet}} <datapath>'.format(sys.argv[0]))
    dataset = sys.argv[1]
    datapath = sys.argv[2]

    if dataset == 'shapenet':
        # Segmentation view of a single category.
        d = ShapeNetDataset(root = datapath, class_choice = ['Bag'])
        print('len(d):{}'.format(len(d)))
        ps, seg = d[0]
        print(ps.size(), ps.type(), seg.size(),seg.type())

        print('--------------------------------------------------------')
        # Classification view over all categories.
        d = ShapeNetDataset(root = datapath, classification = True)
        print('len(d):{}'.format(len(d)))
        ps, cls = d[0]
        print(ps.size(), ps.type(), cls.size(),cls.type())
        # get_segmentation_classes(datapath)

    elif dataset == 'modelnet':
        gen_modelnet_id(datapath)
        d = ModelNetDataset(root=datapath)
        print(len(d))
        print(d[0])

    else:
        sys.exit('unknown dataset {!r}; expected shapenet or modelnet'.format(dataset))

  • model.py

  • 包含以下结构
    • 模型
      • STN3d
      • STNkd
      • PointNetfeat
      • PointNetCls ---- 分类网络
      • PointNetDenseCls ------分割网络
    • 函数
      • feature_transform_regularizer
# from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F


class STN3d(nn.Module):
    """Input transform network: regresses a 3x3 matrix per point cloud.

    Input ``x`` has 3 channels (fixed by ``Conv1d(3, ...)``) and a trailing
    point dimension, i.e. (batch, 3, n_points). The output has shape
    (batch, 3, 3). The identity is added to the regressed values, so a head
    that outputs zeros yields the identity transform.
    """

    def __init__(self):
        super(STN3d, self).__init__()
        # Shared per-point MLP implemented as kernel-size-1 convolutions.
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        # Fully connected head producing the 9 entries of the 3x3 matrix.
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.relu = nn.ReLU()  # not used in forward; kept so the module tree is unchanged

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)


    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]  # max-pool over points -> global feature
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # Flattened identity, created directly on x's device and dtype. The old
        # `.cuda()` call always targeted the *current* CUDA device (device
        # mismatch for a model on e.g. cuda:1) and ignored non-CUDA backends.
        iden = torch.eye(3, dtype=x.dtype, device=x.device).view(1, 9).repeat(batchsize, 1)
        x = x + iden
        x = x.view(-1, 3, 3)
        return x


class STNkd(nn.Module):
    """Feature transform network: regresses a k x k matrix per point cloud.

    Input ``x`` has ``k`` channels (fixed by ``Conv1d(k, ...)``) and a trailing
    point dimension, i.e. (batch, k, n_points). The output has shape
    (batch, k, k). The identity is added to the regressed values, so a head
    that outputs zeros yields the identity transform.

    Args:
        k: number of input feature channels and size of the output matrix.
    """

    def __init__(self, k=64):
        super(STNkd, self).__init__()
        # Shared per-point MLP implemented as kernel-size-1 convolutions.
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        # Fully connected head producing the k*k matrix entries.
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k*k)
        self.relu = nn.ReLU()  # not used in forward; kept so the module tree is unchanged

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k

    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]  # max-pool over points -> global feature
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # Flattened identity, created directly on x's device and dtype. The old
        # `.cuda()` call always targeted the *current* CUDA device (device
        # mismatch for a model on e.g. cuda:1) and ignored non-CUDA backends.
        iden = torch.eye(self.k, dtype=x.dtype, device=x.device).view(1, self.k * self.k).repeat(batchsize, 1)
        x = x + iden
        x = x.view(-1, self.k, self.k)
        return x

class PointNetfeat(nn.Module):
    """PointNet encoder: input transform, shared per-point MLP, max-pooling.

    Input ``x`` is (batch, 3, n_points). ``forward`` returns a triple
    ``(features, trans, trans_feat)``:

    * ``global_feat=True``: ``features`` is the (batch, 1024) max-pooled vector.
    * ``global_feat=False``: the global vector is tiled over every point and
      concatenated with the 64-channel per-point features, giving
      (batch, 1088, n_points).

    ``trans`` is the 3x3 input transform. ``trans_feat`` is the 64x64 feature
    transform when ``feature_transform`` is enabled, otherwise ``None``.
    """

    def __init__(self, global_feat = True, feature_transform = False):
        super().__init__()
        # Submodule creation order is kept as-is: it fixes the parameter
        # initialisation order (RNG draws) and the state_dict key order.
        self.stn = STN3d()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        num_points = x.size(2)

        # Align the raw coordinates: (B, 3, N) -> (B, N, 3) @ (B, 3, 3) -> (B, 3, N).
        trans = self.stn(x)
        x = torch.bmm(x.transpose(2, 1), trans).transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))

        # Optionally align the 64-d per-point features in the same way.
        trans_feat = None
        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = torch.bmm(x.transpose(2, 1), trans_feat).transpose(2, 1)

        per_point = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))  # no ReLU after the last layer, as in the original
        global_vec = torch.max(x, 2, keepdim=True)[0].view(-1, 1024)

        if not self.global_feat:
            tiled = global_vec.view(-1, 1024, 1).repeat(1, 1, num_points)
            return torch.cat([tiled, per_point], 1), trans, trans_feat
        return global_vec, trans, trans_feat

class PointNetCls(nn.Module):
    """PointNet classifier: global-feature encoder followed by a 3-layer MLP.

    Args:
        k: number of output classes.
        feature_transform: forwarded to ``PointNetfeat`` to enable the
            64x64 feature transform.

    ``forward`` returns ``(log_probs, trans, trans_feat)``, where
    ``log_probs`` has shape (batch, k) and is the log-softmax over classes.
    """

    def __init__(self, k=2, feature_transform=False):
        super().__init__()
        # Creation order kept: it fixes parameter init order and state_dict order.
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.3)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()  # not used in forward

    def forward(self, x):
        global_vec, trans, trans_feat = self.feat(x)
        hidden = F.relu(self.bn1(self.fc1(global_vec)))
        # Dropout is applied to fc2's output before its batch norm.
        hidden = self.fc2(hidden)
        hidden = F.relu(self.bn2(self.dropout(hidden)))
        logits = self.fc3(hidden)
        return F.log_softmax(logits, dim=1), trans, trans_feat


class PointNetDenseCls(nn.Module):
    """PointNet segmentation network: per-point class scores.

    Uses ``PointNetfeat(global_feat=False)``, whose 1088-channel output
    (1024 global + 64 per-point) feeds a shared per-point MLP of
    kernel-size-1 convolutions.

    ``forward`` returns ``(log_probs, trans, trans_feat)``, where
    ``log_probs`` has shape (batch, n_points, k): a log-softmax over the k
    classes for every point.
    """

    def __init__(self, k = 2, feature_transform=False):
        super().__init__()
        # Creation order kept: it fixes parameter init order and state_dict order.
        self.k = k
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
        self.conv1 = torch.nn.Conv1d(1088, 512, 1)
        self.conv2 = torch.nn.Conv1d(512, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 128, 1)
        self.conv4 = torch.nn.Conv1d(128, self.k, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

    def forward(self, x):
        batch_size = x.size(0)
        num_points = x.size(2)
        h, trans, trans_feat = self.feat(x)

        layers = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, bn in layers:
            h = F.relu(bn(conv(h)))
        scores = self.conv4(h)  # (B, k, N)

        # (B, k, N) -> (B, N, k); contiguous() makes the view() below legal and
        # leaves the result contiguous for callers that reshape it again.
        scores = scores.transpose(1, 2).contiguous()
        log_probs = F.log_softmax(scores.view(-1, self.k), dim=-1)
        return log_probs.view(batch_size, num_points, self.k), trans, trans_feat

def feature_transform_regularizer(trans):
    """Orthogonality penalty for predicted transform matrices.

    Args:
        trans: batch of square matrices, shape (batch, d, d), for example
            the output of ``STN3d`` or ``STNkd``.

    Returns:
        Scalar tensor: the batch mean of ||A A^T - I||_F over each matrix A.
        It is 0 exactly when every matrix is orthogonal.
    """
    d = trans.size(1)
    # Identity on trans's own device/dtype. The old `.cuda()` call targeted the
    # *current* CUDA device, not trans's device. The leading None axis
    # broadcasts the identity over the batch.
    I = torch.eye(d, dtype=trans.dtype, device=trans.device)[None, :, :]
    loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - I, dim=(1, 2)))
    return loss

if __name__ == '__main__':
    # Shape smoke test on random data: 32 clouds of 2500 points.
    # Plain tensors are used; torch.autograd.Variable has been a deprecated
    # no-op wrapper since PyTorch 0.4.
    sim_data = torch.rand(32, 3, 2500)
    trans = STN3d()
    out = trans(sim_data)
    print('stn', out.size())
    print('loss', feature_transform_regularizer(out))

    sim_data_64d = torch.rand(32, 64, 2500)
    trans = STNkd(k=64)
    out = trans(sim_data_64d)
    print('stn64d', out.size())
    print('loss', feature_transform_regularizer(out))

    pointfeat = PointNetfeat(global_feat=True)
    out, _, _ = pointfeat(sim_data)
    print('global feat', out.size())

    pointfeat = PointNetfeat(global_feat=False)
    out, _, _ = pointfeat(sim_data)
    print('point feat', out.size())

    cls = PointNetCls(k = 5)
    out, _, _ = cls(sim_data)
    print('class', out.size())

    seg = PointNetDenseCls(k = 3)
    out, _, _ = seg(sim_data)
    print('seg', out.size())

    print(seg)

Reference

  • 言简意赅python系列—if not x: 和 if x is not None: 和 if not x is None: 的区别
  • Python 字典(Dictionary) items()方法
  • Python - 两个列表(list)组成字典(dict)
  • Python.__getitem__方法
  • Python 四个魔法方法 __len__、__getitem__、__setitem__、__delitem__

开源框架PointNet 代码详解——/pointnet/sem_seg/train.py
【3D计算机视觉】从PointNet到PointNet++理论及pytorch代码

Tools

  • 每 5 秒刷新一次 GPU 状态
watch -n 5 nvidia-smi 

你可能感兴趣的:(深度学习,PointCloud)