3D点云目标检测算法Pointnet++项目实战 Pytorch实现

        刚刚复现完成PointNet++分类和分割网络,效果还不错,分享给大家。

        Pointnet++算法的原理在此不再赘述,本文专注讲一下重要代码,从输入数据到输出结果展现复现过程。注:复现的是MSG特征拼接方法,MRG的代码作者没有公布出来,代码里还有SSG的,也就是没有做半径特征拼接的原始方法,自己也可以跑一下。

        附代码。代码里对重要的语句都进行了详细的注释喔~

链接:https://pan.baidu.com/s/10Nk4Zd3S_NklY5PJwzmnWA 
提取码:6688 
 

目录

一、配置环境

二、项目文件概述

三、参数设置

四、Pointnet++分类网络模型代码讲解

1、pointnet2_cls_msg.py

2、pointnet_util.py

(1)采样

(2)分组

(3)提取特征

(4)特征连接

 五、复现结果展示

 六、配置环境时遇到的问题总结

 七、结语


一、配置环境

(1)Windows系统

python 3.8   cuda 11.1 pytorch 1.8.0  torchvision 0.9.0

        一开始总是运行不起来,各种报错,python是3.6的,重新下了3.8的就好了,运行之前一定检查好版本,版本匹配不好的话就运行不起来喔~

(2)ubuntu系统

python 3.7 cuda 11.1  pytorch 1.8.0  torchvision 0.9.0

二、项目文件概述

1、pointnet2_cls_msg.py    

        模型文件

2、pointnet_util.py    

        其中的PointNetSetAbstractionMsg() 是核心,用来提取特征,除此以外最远点采样等方法都在这个文件中。    

3、ModelNetDataLoader.py    

       用来取数据    

4、train_cls.py  

       参数设置、创建文件夹、加载数据和模型、进行训练  

5、test_cls.py

       用来测试,可以换成自己的数据集跑

三、参数设置

train_cls.py

3D点云目标检测算法Pointnet++项目实战 Pytorch实现_第1张图片

1、论文里batchsize是24,这里设置的是8,对准确率影响不大。

2、模型是MSG的,epoch为200。

3、normal意思是法向量,可以自己设置,要不要使用法向量,使用的话初始输入的点云数据除了3个位置信息x,y,z以外还有三个法向量Nx,Ny,Nz,每个点一共是6个特征。

3D点云目标检测算法Pointnet++项目实战 Pytorch实现_第2张图片

注:每个epoch里面都是先训练再测试,先输出Train  Instance Accuracy——训练时所有单个物体目标的平均准确率,再输出测试结果,测试结果有四个,这里解释一下每个结果的含义:

(1)Instance Accuracy 所有单个物体目标的平均准确率

(2)Class Accuracy 整个类别的平均准确率

(3)Best Instance Accuracy 所有单体目标检测最高准确率

(4)Best Class Accuracy 类别的最高准确率

四、Pointnet++分类网络模型代码讲解

1、pointnet2_cls_msg.py

3D点云目标检测算法Pointnet++项目实战 Pytorch实现_第3张图片

        一共有3层,每一层都用来进行点云数据的特征提取,分别是sa1,sa2,sa3。

        第一维特征是采样点的个数,第二维特征是半径大小,一共设置了3个——0.1,0.2,0.4,第三维特征是对应半径的组里的点的个数,第四维特征是输入的点的特征个数或者说是通道大小,第五到七维特征是三个半径分别对应的特征维数。eg.从32维卷积到32维,再卷积到64维。

 3D点云目标检测算法Pointnet++项目实战 Pytorch实现_第4张图片

         sa1的输出结果作为sa2的输入,sa2的输出结果作为sa3的,sa3的输出结果输入进3个全连接层,最后通过sofmax函数输出物体分类。

2、pointnet_util.py

        下面讲一下PointNetSetAbstractionMsg(nn.Module),代码核心!

        pointnet++模型大致分为三步——采样、分组、提取特征,然后再循环两次之后输入到全连接网络。如下图所示,我已经在图里标注出了每一步对应的具体语句。下面按照这个顺序具体讲一下每个步骤。

3D点云目标检测算法Pointnet++项目实战 Pytorch实现_第5张图片

3D点云目标检测算法Pointnet++项目实战 Pytorch实现_第6张图片

(1)采样

        pointnet++模型首先进行采样,从1024个点里采样512个,再从512个点里采样128个,最后把这128个点当做一个组进行特征提取。作者采用最远点采样方法进行采样,通过farthest_point_sample(xyz, npoint)实现(第63行),下面是代码。

3D点云目标检测算法Pointnet++项目实战 Pytorch实现_第7张图片

       首先初始化8*512的矩阵,一共有8个batch,每个batch里有512个中心点,最终要返回的。再定义一个8*1024的距离矩阵,里面存储的是除了中心点以外的每个点距离当前所有已采样点的最小距离。batch里每个样本随机初始化一个最远点的索引(第一个点是随机选择的)。进入循环,开始采样512个中心点。

       现在已经随机选择了第一个点,第二点为距离第一个最远的点。第三个点的确定比较麻烦,要计算除了这两个点以外的所有点与这两个点之间的距离,每个点都有两个距离,选择这两个距离里最小的的距离作为该点到中心点的距离,一共有1022个距离,再从这1022个距离里面选择最大的那个距离,该距离对应的点就是第三个点。

(2)分组

        第二步分组通过query_ball_point(radius, nsample, xyz, new_xyz)完成,在代码第87行。

3D点云目标检测算法Pointnet++项目实战 Pytorch实现_第8张图片

       在分组的时候,遵循的原则就是要使得每个组里的点的个数一致。因此,当组里的点的个数大于给定值的时候,就删除点;小于的时候就复制点,增加点的个数。作者具体的做法如下:

       找到距离大于给定半径的设置成一个N值(1024)索引,值为1024表示这个点不在组当中。然后做升序排序,排序个数为规定的圆圈里的个数就行了。后面的都是大的值(1024)。可能有很多点到中心点的距离都小于半径,而我们只需要16个,所以排序一下,取前16个离中心点最近的点。如果半径内的点没那么多,则复制离中心点最近的那个点即第一个点来代替值为1024的点,使得组里的点的个数增加为规定的个数。

(3)提取特征

        pointnet_util.py

3D点云目标检测算法Pointnet++项目实战 Pytorch实现_第9张图片

        分好组以后,进行卷积操作,每种半径各3次。之后做max操作,这个是本篇论文的创新点,解决了点云的无序性。

3D点云目标检测算法Pointnet++项目实战 Pytorch实现_第10张图片

        我画了一个图,更形象地解释一下卷积,其实poinetnet++的卷积操作与2维图像的卷积类似,得益于作者将点云分组的想法,这样就有了区域的概念,而pointnet算法就没有区域的概念,缺少点与点之间的联系。以第一次为例,从1024个原始点里采样了512个点,半径为0.1的组里设置了16个点,每个点有6个特征信息,在进行完卷积之后,每个点从6个特征变为了32个特征,通过max操作,从每个维度中选取最大的数作为这一个组的特征,最终输出512个点,每个点是32维特征。其他两种半径也是这样,不再赘述。

 (4)特征连接

         每种半径都输出不同的特征维度,但是点的个数都是512个。0.1半径对应的特征维度是64;0.2半径对应的特征维度是128;0.4半径对应的特征维度是128。这里运用的是MSG的特征拼接方法,因此要把每个半径对应的特征拼接起来在输入到sa2中,运用pytorch的cat函数。拼接完成后,sa1输出512个采样点,每个点的特征个数是64+128+128=320,sa2再从512个点中采样128个中心点。

 五、复现结果展示

         从logs文件中截的图。数据集为modelnet40,训练数据9843条,测试数据2468条。一共是40个分类。

         测试准确率为92.96%,pointnet++的源代码后续又进行了优化,所以结果会高于原论文的结果。下图是原论文的实验结果。

3D点云目标检测算法Pointnet++项目实战 Pytorch实现_第11张图片

六、配置环境时遇到的问题总结

1、python安装包报错:PackagesNotFoundError: The following packages are not available from current channels

anaconda search -t conda xxx #查询包所存在的版本
anaconda show 123 #123为想下载的包的名字
conda install --channel https://。。。 xxx  #https://。。为下载链接

 七、结语

        刚刚入门点云目标检测,还有很多需要学习的,如果有不恰当的地方还请各位大神批评指正。

        这是自己成功复现的第一篇论文,也是写的第一篇文章,希望有人看呐!嘿嘿。希望能帮助到也在研究点云的小伙伴们~

        感觉记录下复现的过程和学习笔记还是很有必要的。继续努力!

你可能感兴趣的:(自动驾驶,算法,pycharm)