[Paper Overview + Interpretation + PyTorch Implementation] PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation

As the seminal work on deep learning for point clouds, this paper presents a neural network framework that takes raw point clouds directly as input for tasks such as classification and segmentation. Despite its simple structure, the network is both efficient and effective.

Paper: https://arxiv.org/abs/1612.00593

Code: https://github.com/fxia22/pointnet.pytorch

Before processing point cloud data, we first need to examine the properties of point clouds. A point cloud has the following three properties: (1) it is unordered; (2) there is interaction among points; (3) it should be invariant under transformations. Each property is explained below.

  • Unordered:

A point cloud is a set of points with no particular order, so for an input of N points, the output should remain the same for all N! orderings of those points (permutation invariance); see the sketch after this list.

  • Interaction among points:

The points come from a space with a distance metric, which means they are not independent of one another: the neighborhood around each point is also meaningful.

  • Invariance under transformations:

After a point cloud is translated or rotated, what the network has learned should be unaffected: neither the predicted category of the whole cloud nor the per-point segmentation result should change.
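Here is a minimal sketch of permutation invariance (my own illustration, not code from the paper): the same points in a different order, passed through a shared per-point function and then max-pooled, produce an identical output.

import torch

h = torch.nn.Conv1d(3, 16, 1)                   # a shared per-point function h
pts = torch.randn(1, 3, 128)                    # (batch, xyz, N points)
perm = torch.randperm(128)                      # an arbitrary reordering of the points
out1 = torch.max(h(pts), dim=2)[0]              # symmetric aggregation over points
out2 = torch.max(h(pts[:, :, perm]), dim=2)[0]
print(torch.allclose(out1, out2))               # True: the order does not matter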

Paper Overview

1. Motivation:

Point cloud data differs from RGB image data: images are regular grids, whereas point clouds are irregular. Conventional convolutions can only handle highly regular data, and although there are ways to make point cloud data regular, such data conversions inevitably introduce some error into the data itself. At the same time, unlike images, a point cloud is a geometric data structure, which makes it more effective for 3D perception tasks.

2. Aim:

Is there a simple, efficient network that can process raw point cloud data directly?

3. Contributions:

• Designed a novel deep network architecture suitable for consuming unordered point sets in 3D;

• Showed how such a network can be trained to perform tasks such as 3D shape classification and shape segmentation;

• Provided thorough empirical and theoretical analysis of the method's stability and efficiency;

• Illustrated the 3D features computed by selected neurons in the network and developed an intuitive explanation of its performance.

4. The Architecture of PointNet

The network has three key modules:

1. The max pooling layer, which acts as a symmetric function aggregating information from all points;

2. A structure combining global and local information, so that in segmentation each point's feature is aware of both local and global context;

3. Two joint alignment networks that align the input point space and the feature space, respectively.

5. Theoretical Analysis

  • Universal approximation: this result shows that the designed network can approximate any continuous set function arbitrarily well. Paraphrasing Theorem 1 of the paper: for any continuous set function f and any ε > 0, there exist a continuous function h and a symmetric function g = γ ∘ MAX such that |f(S) − γ(MAX_{x_i ∈ S}{h(x_i)})| < ε for any point set S.

[Figure: Theorem 1 (universal approximation) from the paper]
  • Bottleneck dimension and stability: Theorem 2 of the paper shows that the network's output is fully determined by a critical subset C_S of at most K input points, where K is the bottleneck dimension of the max-pooled feature (1024 here). This explains the robustness to missing or extra points.

[Figure: Theorem 2 (bottleneck dimension and stability) from the paper]

6. Framework:

Figure (1) shows the overall PointNet pipeline:


Figure (1): PointNet architecture

The top half of the figure is the Classification Network; the bottom half is the Segmentation Network.

Paper Interpretation

Q1. Given the properties of point clouds, what solutions are available, and which does the paper adopt?

First, for the unordered property there are three candidate approaches, illustrated in Figure (2): (1) impose a particular ordering on the points (sorting); the drawback is that a stable ordering is hard to find in a high-dimensional space. (2) Use an RNN to learn over all orderings (sequential model); the drawback is that RNNs struggle to stay robust when the point set is large. (3) Aggregate each point's information with a symmetric function (symmetry function); this is simpler and easier to implement. The paper adopts the third approach, using MAX (max pooling) as the symmetric function.


Figure (2): Three approaches to achieve order invariance

Second, the solution for transformation invariance is a Joint Alignment Network, which aims to align all inputs to a canonical space before feature extraction, so that the representation learned from the point set is invariant to these transformations. The paper uses a T-Net, structurally a miniature PointNet, to predict the transformation matrix; an alignment network is also applied in the feature space. One issue deserves attention here: the feature-space transformation matrix has a much higher dimension, which makes optimization difficult, so a regularization term is introduced (quoting the paper: "However, transformation matrix in the feature space has much higher dimension than the spatial transform matrix, which greatly increases the difficulty of optimization. We therefore add a regularization term to our softmax training loss. We constrain the feature transformation matrix to be close to orthogonal matrix"):

L_reg = ||I − AAᵀ||²_F

Here A is the feature alignment matrix predicted by the mini-network. When A is (close to) orthogonal, multiplying the point features by A does not lose information.
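The reference repository implements this regularizer roughly as follows (a sketch based on its model.py; verify against the actual source). Note that it penalizes the Frobenius norm rather than its square, which applies the same orthogonality pressure:

import torch

def feature_transform_regularizer(trans):
    # trans: (B, k, k) batch of predicted feature alignment matrices A
    d = trans.size()[1]
    I = torch.eye(d)[None, :, :]
    if trans.is_cuda:
        I = I.cuda()
    # Mean ||A A^T - I||_F over the batch: pushes each A toward an orthogonal matrix
    loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - I, dim=(1, 2)))
    return loss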

Q2. Why are the Classification Network and the Segmentation Network designed this way?

This requires some understanding of the two tasks. Intuitively, classification takes a set of points as input and outputs the category of that set, while segmentation takes a set of points and outputs a shape segmentation for them. In fact, classification relies on a global representation of the point set, whereas segmentation predicts a representation for every point (i.e., classifies each point), so both are ultimately classification problems. For the Classification Network, once max pooling produces the global feature, fully connected layers yield the classification result. For the Segmentation Network, classifying every point requires local as well as global information, so each point's feature is concatenated with the global feature and then passed through further layers to produce the segmentation result, as the sketch below illustrates.
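A shape-level sketch of that concatenation (tensor names here are mine; the real code is in PointNetfeat below):

import torch

B, N = 4, 1024
pointfeat = torch.randn(B, 64, N)                      # per-point local features
globalfeat = torch.randn(B, 1024)                      # max-pooled global feature
tiled = globalfeat.view(B, 1024, 1).repeat(1, 1, N)    # copy the global feature onto every point
segfeat = torch.cat([tiled, pointfeat], dim=1)         # (B, 1088, N); cf. Conv1d(1088, ...) below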

Q3. What does PointNet actually learn?

Quoting the paper: "Intuitively, our network learns to summarize a shape by a sparse set of key points. In experiment section we see that the key points form the skeleton of an object." Simply put, the network learns a sparse set of key points, and the visualization in Figure (3) shows that these key points form the skeleton or silhouette of the objects. This also explains the network's robustness: even with 50% of the points missing, it still produces good results.


Figure (3): Visualization of critical points
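The critical points are simply the points whose features survive the max pooling, so they can be recovered from the argmax indices. A minimal sketch (my own, not from the repo):

import torch

feats = torch.randn(1, 1024, 2048)           # (B, C, N) per-point features, hypothetical values
_, idx = torch.max(feats, dim=2)             # idx[b, c] = the point that wins channel c
critical_idx = torch.unique(idx[0])          # indices of the critical point set
print(critical_idx.numel())                  # at most 1024 points, per Theorem 2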

PyTorch Implementation

Code examples (from the repository linked above):

1. Spatial (input) alignment module

# Shared imports for all of the code blocks below
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class STN3d(nn.Module):
    def __init__(self):
        super(STN3d, self).__init__()
        # Point-wise convolutions (kernel size 1 acts as a shared MLP across points)
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        # Fully connected layers
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.relu = nn.ReLU()
        # Batch normalization
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)


    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
  
        # Symmetric max pooling over the N points -> (B, 1024, 1), then flatten
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)
        # Flattened 3x3 identity: the network predicts a residual, so the
        # initial transform is the identity matrix
        iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1)
        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden
        x = x.view(-1, 3, 3)
        return x
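A quick smoke test of STN3d with hypothetical shapes (using the imports above): a (B, 3, N) cloud goes in, a (B, 3, 3) transform comes out.

stn = STN3d()
pts = torch.randn(8, 3, 1024)   # batch of 8 clouds with 1024 points each
trans = stn(pts)
print(trans.shape)              # torch.Size([8, 3, 3])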

2. Feature-space alignment module

class STNkd(nn.Module):
    def __init__(self, k=64):
        super(STNkd, self).__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k*k)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k

    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # Flattened k x k identity, so the initial transform is the identity matrix
        iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1)
        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden
        x = x.view(-1, self.k, self.k)
        return x

3. Feature encoding module

# Global (and optionally per-point) feature extractor
class PointNetfeat(nn.Module):
    def __init__(self, global_feat = True, feature_transform = False):
        super(PointNetfeat, self).__init__()
        self.stn = STN3d()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        n_pts = x.size()[2]
        trans = self.stn(x)
        x = x.transpose(2, 1)
        x = torch.bmm(x, trans)
        x = x.transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))

        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = x.transpose(2,1)
            x = torch.bmm(x, trans_feat)
            x = x.transpose(2,1)
        else:
            trans_feat = None

        pointfeat = x  # keep the per-point features, (B, 64, N)
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        x = torch.max(x, 2, keepdim=True)[0]  # symmetric max pooling over points
        x = x.view(-1, 1024)                  # global feature, (B, 1024)
        if self.global_feat:
            return x, trans, trans_feat
        else:
            # Tile the global feature onto every point and concatenate with the
            # per-point features -> (B, 1088, N) for segmentation
            x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
            return torch.cat([x, pointfeat], 1), trans, trans_feat
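A hypothetical shape check for both modes of PointNetfeat:

feat_g = PointNetfeat(global_feat=True)
feat_s = PointNetfeat(global_feat=False)
pts = torch.randn(8, 3, 1024)
print(feat_g(pts)[0].shape)     # torch.Size([8, 1024]): one global descriptor per cloud
print(feat_s(pts)[0].shape)     # torch.Size([8, 1088, 1024]): concat feature per point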

4. Classification Network

class PointNetCls(nn.Module):
    def __init__(self, k=2, feature_transform=False):
        super(PointNetCls, self).__init__()
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.3)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, x):
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
        x = self.fc3(x)
        # log_softmax output pairs with nn.NLLLoss / F.nll_loss during training
        return F.log_softmax(x, dim=1), trans, trans_feat
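A hypothetical training step: log_softmax pairs with nll_loss, and the feature-transform regularizer sketched earlier in this post is added with weight 0.001, as in the repo's training script (verify the weight against the source).

cls = PointNetCls(k=16, feature_transform=True)
pts = torch.randn(8, 3, 1024)
labels = torch.randint(0, 16, (8,))
pred, trans, trans_feat = cls(pts)
loss = F.nll_loss(pred, labels) + 0.001 * feature_transform_regularizer(trans_feat)
loss.backward()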

5. Segmentation Network

class PointNetDenseCls(nn.Module):
    def __init__(self, k = 2, feature_transform=False):
        super(PointNetDenseCls, self).__init__()
        self.k = k
        self.feature_transform=feature_transform
        self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
        self.conv1 = torch.nn.Conv1d(1088, 512, 1)
        self.conv2 = torch.nn.Conv1d(512, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 128, 1)
        self.conv4 = torch.nn.Conv1d(128, self.k, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

    def forward(self, x):
        batchsize = x.size()[0]
        n_pts = x.size()[2]
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.conv4(x)
        x = x.transpose(2,1).contiguous()
        x = F.log_softmax(x.view(-1,self.k), dim=-1)
        x = x.view(batchsize, n_pts, self.k)
        return x, trans, trans_feat
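A hypothetical shape check: per-point log-probabilities come back as (B, N, k) and are flattened for the loss.

seg = PointNetDenseCls(k=4)
pts = torch.randn(8, 3, 1024)
pred, trans, trans_feat = seg(pts)
print(pred.shape)                            # torch.Size([8, 1024, 4])
target = torch.randint(0, 4, (8, 1024))
loss = F.nll_loss(pred.view(-1, 4), target.view(-1))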
