[Paper Summary + Analysis + PyTorch Implementation] PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation




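Before processing point cloud data, we first need to examine the properties of point clouds. Point clouds have the following three properties: 1. they are unordered; 2. there is interaction among points; 3. they are invariant under transformations. These three properties are explained below.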

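• A novel deep network architecture designed to consume unordered point sets in 3D;

• A demonstration of how such a network can be trained to perform tasks such as 3D shape classification and shape segmentation;

• A thorough empirical and theoretical analysis of the stability and efficiency of the method;

• An illustration of the 3D features computed by selected neurons in the network, with an intuitive explanation of its performance.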

IV. The Architecture of PointNet


1. The max pooling layer, which acts as a symmetric function that aggregates information from all points;



V. Theoretical Analysis

  • Universal approximation: the theorem shows that the proposed network structure can approximate any continuous set function arbitrarily well

  • Bottleneck dimension and stability:

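For reference, the two results named in the list above can be written out as follows. This is a restatement from memory of the theorems in the PointNet paper, so treat the notation as a sketch and check the paper for the exact wording. Here $h$ is the shared per-point network, $\gamma$ is the network applied after max pooling, and $K$ is the output dimension of the max pooling layer.

```latex
% Universal approximation: f is a continuous set function
% with respect to the Hausdorff distance d_H.
\forall \epsilon > 0,\ \exists\, h \text{ continuous and } g(x_1,\dots,x_n) = \gamma \circ \mathrm{MAX}
\text{ such that for any } S \in \mathcal{X}:\quad
\left| f(S) - \gamma\!\left( \max_{x_i \in S} \{ h(x_i) \} \right) \right| < \epsilon

% Bottleneck dimension and stability:
% let u = \max_{x_i \in S} \{ h(x_i) \},\ u : \mathcal{X} \to \mathbb{R}^K,\ f = \gamma \circ u. Then
\text{(a)}\ \forall S,\ \exists\, \mathcal{C}_S, \mathcal{N}_S \subseteq \mathcal{X}:\quad
f(T) = f(S) \ \text{ if } \ \mathcal{C}_S \subseteq T \subseteq \mathcal{N}_S
\qquad
\text{(b)}\ |\mathcal{C}_S| \le K
```

Read this way, the output is unchanged as long as the critical set $\mathcal{C}_S$ is preserved, and that set contains at most $K$ points.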




Figure 1: PointNet architecture

The upper part of the figure is the Classification Network, and the lower part is the Segmentation Network.



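First, for the unordered property, there are three candidate strategies, illustrated in Figure 2. The paper adopts the symmetric function approach.

1. Sort the input into a canonical order (sorting). The drawback is that a stable ordering is hard to find in high-dimensional space.
2. Use an RNN to learn over all permutations of the input (sequential model). The drawback is that an RNN struggles to stay robust when the point set is large.
3. Use a symmetric function to aggregate the information from each point (symmetry function). This approach is simpler and easier to implement. The paper uses MAX() as the symmetric function, i.e. max pooling.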
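The effect of the symmetric function can be checked with a small experiment. The sketch below is only an illustration, not the paper's network: the toy tensor sizes and the names `per_point` and `global_feature` are assumptions made for this example. It applies a shared per-point layer followed by max pooling, then confirms that shuffling the points leaves the pooled feature unchanged.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy input: 2 point clouds, 3 coordinates, 5 points each.
points = torch.randn(2, 3, 5)

# Shared per-point layer: kernel size 1 processes every point independently.
per_point = torch.nn.Conv1d(3, 8, 1)


def global_feature(x):
    # Max over the point dimension is symmetric, so point order does not matter.
    return torch.max(per_point(x), dim=2)[0]


# Shuffle the points along the point dimension.
perm = torch.randperm(points.size(2))
shuffled = points[:, :, perm]

same = torch.allclose(global_feature(points), global_feature(shuffled), atol=1e-6)
print(same)
```

Any other symmetric reduction (sum or mean, for example) would pass the same check; max is simply the one the paper chooses.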


Figure 2: Three approaches to achieve order invariance

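Second, invariance under transformations is handled by adding a Joint Alignment Network. Its goal is to align every input to a canonical space before feature extraction, so that the representation learned from the point set is invariant to these transformations. The paper uses a T-net to predict an affine transformation matrix, and the T-net's structure resembles a mini PointNet. An alignment network is also applied in feature space. One caveat: the transformation matrix in feature space has a much higher dimension, which makes optimization harder, so a regularization term is introduced: (However, transformation matrix in the feature space has much higher dimension than the spatial transform matrix, which greatly increases the difficulty of optimization. We therefore add a regularization term to our softmax training loss. We constrain the feature transformation matrix to be close to orthogonal matrix)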
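A minimal sketch of such a regularizer follows, assuming the penalty takes the form ||I - A A^T||_F^2 for a predicted feature transform A. The function name `orthogonality_regularizer`, the batch averaging, and the test matrices are choices made for this illustration rather than details taken from the text above.

```python
import torch


def orthogonality_regularizer(trans):
    # trans: (B, k, k) batch of predicted feature transforms A.
    k = trans.size(1)
    identity = torch.eye(k, device=trans.device, dtype=trans.dtype).unsqueeze(0)
    # Squared Frobenius norm of (I - A A^T) per matrix, averaged over the batch.
    diff = identity - torch.bmm(trans, trans.transpose(2, 1))
    return (diff ** 2).sum(dim=(1, 2)).mean()


# An orthogonal matrix (here the identity) incurs no penalty.
eye_batch = torch.eye(64).unsqueeze(0).repeat(4, 1, 1)
loss_orthogonal = orthogonality_regularizer(eye_batch)

# A scaled identity is not orthogonal, so it is penalized.
loss_scaled = orthogonality_regularizer(2.0 * eye_batch)
```

In training, a term like this would be multiplied by a small weight (a hyperparameter) and added to the classification loss.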



Q2. Why are the Classification Network and the Segmentation Network designed this way?

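To answer this, we need some understanding of the two tasks. Intuitively, classification takes a set of points as input and outputs the category of the whole set, while segmentation takes a set of points as input and outputs a segmentation of the shape they form. In other words, classification predicts a global label for the point set, whereas segmentation predicts a label for each point (that is, it classifies every point), so both tasks are classification problems. For the Classification Network, once max pooling has produced the global feature, fully connected layers can produce the classification result. For the Segmentation Network, a label is needed for every point, so global information alone is not enough and local information is needed as well. Each point's feature is therefore concatenated with the global feature, and the result is passed through fully connected layers to produce the segmentation output.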
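The concatenation step can be sketched with plain tensors. The channel sizes 64 and 1024 match the implementation later in this post; the batch size, the number of points, and the random stand-in features are assumptions made for this illustration.

```python
import torch

batch, n_pts = 2, 100

# Stand-ins for intermediate activations of the network.
point_feat = torch.randn(batch, 64, n_pts)        # per-point (local) features
high_dim_feat = torch.randn(batch, 1024, n_pts)   # per-point features before pooling

# Max pooling over points gives one global feature vector per cloud.
global_feat = torch.max(high_dim_feat, dim=2)[0]           # (B, 1024)

# Copy the global feature to every point and concatenate it with the local features.
expanded = global_feat.unsqueeze(2).repeat(1, 1, n_pts)    # (B, 1024, N)
seg_input = torch.cat([expanded, point_feat], dim=1)       # (B, 1088, N)
```

The classification head consumes `global_feat` directly, while the segmentation head consumes `seg_input`, which is why its first layer takes 1088 input channels.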


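Here I quote the paper: "Intuitively, our network learns to summarize a shape by a sparse set of key points. In experiment section we see that the key points form the skeleton of an object." Put simply, what the network learns is a sparse set of key points, and the visualization in Figure 3 shows that these key points form the skeleton or the edges of the object. This also explains why the network is robust: even when 50% of the points are missing, it still produces good results.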
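One way to see where these key points come from is to record which point wins the max for each feature channel. The sketch below uses a single random `Conv1d` layer as a hypothetical stand-in for the shared per-point network, so the selected points are not meaningful in themselves; it only shows that the global feature is fully determined by the few points that achieve a maximum.

```python
import torch

torch.manual_seed(0)

n_pts = 200
points = torch.randn(1, 3, n_pts)
h = torch.nn.Conv1d(3, 16, 1)  # stand-in for the shared per-point network

with torch.no_grad():
    feats = h(points)                                # (1, 16, N)
    global_feat, argmax = torch.max(feats, dim=2)    # both (1, 16)

    # Points that achieve the maximum in at least one channel.
    critical_idx = torch.unique(argmax)

    # Keep only those points and recompute the global feature.
    subset = points[:, :, critical_idx]
    subset_feat = torch.max(h(subset), dim=2)[0]

unchanged = torch.allclose(global_feat, subset_feat, atol=1e-5)
print(critical_idx.numel(), unchanged)
```

With 16 channels, at most 16 of the 200 points can be selected, so dropping the rest does not change the pooled feature.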


Figure 3: Visualization of critical points




import torch
import torch.nn as nn
import torch.nn.functional as F


class STN3d(nn.Module):
    """Input T-Net: predicts a 3x3 transform for the raw point coordinates."""

    def __init__(self):
        super(STN3d, self).__init__()
        # Shared per-point layers (convolutions with kernel size 1)
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        # Fully connected layers
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.relu = nn.ReLU()
        # Batch normalization
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        # x: (batch, 3, n_pts)
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Max pooling over the point dimension
        x = torch.max(x, 2, keepdim=True)[0]
        # Flatten (batch, 1024, 1) to (batch, 1024)
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)
        # Add the identity so the network predicts an offset from the identity transform
        iden = torch.eye(3, dtype=x.dtype, device=x.device).view(1, 9).repeat(batchsize, 1)
        x = x + iden
        x = x.view(-1, 3, 3)
        return x


class STNkd(nn.Module):
    """Feature T-Net: predicts a k x k transform for per-point features."""

    def __init__(self, k=64):
        super(STNkd, self).__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k*k)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k

    def forward(self, x):
        # x: (batch, k, n_pts)
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # Add the k x k identity, as in STN3d
        iden = torch.eye(self.k, dtype=x.dtype, device=x.device).view(1, self.k*self.k).repeat(batchsize, 1)
        x = x + iden
        x = x.view(-1, self.k, self.k)
        return x


# Global feature extractor
class PointNetfeat(nn.Module):
    def __init__(self, global_feat = True, feature_transform = False):
        super(PointNetfeat, self).__init__()
        self.stn = STN3d()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        n_pts = x.size()[2]
        trans = self.stn(x)
        x = x.transpose(2, 1)
        x = torch.bmm(x, trans)
        x = x.transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))

        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = x.transpose(2,1)
            x = torch.bmm(x, trans_feat)
            x = x.transpose(2,1)
        else:
            trans_feat = None

        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)
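        if self.global_feat:
            # Classification: return only the global feature
            return x, trans, trans_feat
        else:
            # Segmentation: repeat the global feature for every point and
            # concatenate it with the per-point features (1024 + 64 = 1088 channels)
            x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
            return torch.cat([x, pointfeat], 1), trans, trans_feat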

4. Classification Network

class PointNetCls(nn.Module):
    def __init__(self, k=2, feature_transform=False):
        super(PointNetCls, self).__init__()
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.3)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, x):
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1), trans, trans_feat

5. Segmentation Network

class PointNetDenseCls(nn.Module):
    def __init__(self, k = 2, feature_transform=False):
        super(PointNetDenseCls, self).__init__()
        self.k = k
        self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
        self.conv1 = torch.nn.Conv1d(1088, 512, 1)
        self.conv2 = torch.nn.Conv1d(512, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 128, 1)
        self.conv4 = torch.nn.Conv1d(128, self.k, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

    def forward(self, x):
        batchsize = x.size()[0]
        n_pts = x.size()[2]
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.conv4(x)
        x = x.transpose(2,1).contiguous()
        x = F.log_softmax(x.view(-1,self.k), dim=-1)
        x = x.view(batchsize, n_pts, self.k)
        return x, trans, trans_feat
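The segmentation network above returns log-probabilities of shape (batch, n_pts, k). As a rough sketch of how such an output could be turned into a per-point loss and predictions, the example below uses random tensors in place of the real network output. The tensor sizes, the labels, and the choice of `F.nll_loss` are assumptions made for this illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch, n_pts, k = 2, 50, 4

# Stand-in for the network output: log-probabilities shaped (B, N, k).
log_probs = F.log_softmax(torch.randn(batch, n_pts, k), dim=-1)

# Hypothetical per-point part labels in [0, k).
target = torch.randint(0, k, (batch, n_pts))

# nll_loss expects (num_samples, k) and (num_samples,), so flatten batch and points.
loss = F.nll_loss(log_probs.view(-1, k), target.view(-1))

# Predicted label for every point.
pred = log_probs.argmax(dim=2)   # (B, N)
accuracy = (pred == target).float().mean()
```

Because the network already applies `log_softmax`, `nll_loss` is the matching loss here; `cross_entropy` would apply the softmax a second time.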
