PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space (Paper and Code Explained)


  • Paper
    • Abstract
    • 1. Introduction
    • 3. Method
      • 3.2 Hierarchical Point Set Feature Learning
        • Set Abstraction level
        • Sampling layer
        • Grouping layer
        • PointNet layer
      • 3.3 Robust Feature Learning under Non-Uniform Sampling Density
        • Multi-scale grouping (MSG)
        • Multi-resolution grouping (MRG)
      • 3.4 Point Feature Propagation for Set Segmentation
    • 4. Experiments
  • Code
    • 1. Utility Function Components
      • 1.1 Basic Functions
      • 1.2 Network Layer Functions
        • Sampling layer + Grouping layer
        • Set Abstraction level
        • Feature Propagation level
    • 2. Task Networks
      • 2.1 Classification
        • single scale group classification
        • multi scale group classification
      • 2.2 Part Segmentation
        • Single Scale Part Segmentation
        • Multi Scale Part Segmentation
      • 2.3 Semantic Segmentation
        • Single Scale Semantic Segmentation
        • Multi Scale Semantic Segmentation
  • If you spot any errors, corrections are welcome

Paper

This paper was published at NIPS (Neural Information Processing Systems) 2017. Paper link:

PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space

As the title suggests, this is an improvement over PointNet. The improvements are as follows:

1. The basic idea of PointNet is to learn a spatial encoding of each point and then aggregate all individual point features into a global point cloud signature. By design, the local features PointNet captures are point-wise; it cannot capture local structures induced by the metric. PointNet++ first builds local neighborhoods with a Sampling layer and a Grouping layer, and then uses a PointNet layer to encode (extract) local features. This is the Hierarchical Point Set Feature Learning described in the paper.
2. PointNet assumes a uniform sampling density, while PointNet++ can handle non-uniform sampling. This is the Robust Feature Learning under Non-Uniform Sampling Density described in the paper.

For PointNet, see:

PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation (Paper and Code Explained)

Abstract

By design, PointNet does not capture local structures induced by the metric space, which limits its ability to recognize fine-grained patterns and to generalize to complex scenes.

Hierarchical Point Set Feature Learning: the paper introduces a hierarchical neural network that applies PointNet recursively on nested partitions of the input point set. By exploiting metric space distances, PointNet++ learns local features at increasing contextual scales.

Robust Feature Learning under Non-Uniform Sampling Density: point sets are usually sampled with varying densities, which severely degrades the performance of networks trained on uniform density. The paper proposes new learning layers that adaptively combine features from multiple scales.

1. Introduction

(1) Issues to consider when building a point cloud network: invariance to the ordering of the points (the point set is unordered); local neighborhoods defined by a distance metric (local features); non-uniform sampling of the point set.

(2) The general idea of PointNet++: first, partition the point cloud into overlapping local regions according to a distance metric. Then, similar to a CNN (Convolutional Neural Network), extract local features that capture fine geometric structures from small neighborhoods; these local features are further grouped into larger units and processed to produce higher-level features. This step is repeated until the features of the whole point set are obtained.

(3) Two main issues in designing PointNet++: 1) how to partition the point set into overlapping local regions according to the distance metric; 2) how to abstract local point sets or local features with a local feature learner.

(4) Local feature learner: PointNet is chosen.

(5) Partitioning the point set into overlapping regions: each partition is defined as a neighborhood ball in Euclidean space, parameterized only by a centroid location and a scale (radius). The centroids are selected by the farthest point sampling (FPS) algorithm; the paper does not explain how the scale is chosen, but the network architectures in the supplementary material give the scale values directly.

(6) Significant contribution: PointNet++ exploits neighborhood information at multiple scales.

3. Method

The PointNet++ network architecture is shown in the figure below:
[Figure: PointNet++ network architecture]
PointNet++ is essentially an encoder-decoder structure. The encoder is a downsampling process: through hierarchical point set feature learning, i.e. a stack of set abstraction levels, it progressively produces features of larger and larger regions. The decoder differs between the classification and segmentation applications.

3.2 Hierarchical Point Set Feature Learning

(1) PointNet uses a single max pooling operation to aggregate the whole point cloud; PointNet++ builds a hierarchical grouping of points and progressively abstracts larger and larger local regions along the hierarchy.

(2) The hierarchical structure of PointNet++ is composed of a series of set abstraction levels, as shown in the architecture figure.

Set Abstraction level

A set abstraction level consists of three key layers: a Sampling layer, a Grouping layer, and a PointNet layer. It takes an $N \times (d + C)$ matrix as input, where $N$ is the number of points, $d$ is the dimension of the coordinates, and $C$ is the dimension of the additional point features; the output is $N' \times (d + C')$, where $N'$ is the number of subsampled points and $C'$ is the dimension of the new feature vectors.

Sampling layer

Input: a set of points (a point set).

Output: a subset of the input points.

Purpose: select a subset of the input points; the selected points define the centroids of the local regions.

Method: given the input point set $\{x_1, x_2, \cdots, x_n\}$, iterative farthest point sampling is used to select a subset $\{x_{i_1}, x_{i_2}, \cdots, x_{i_m}\}$ such that $x_{i_j}$ is the point farthest from $\{x_{i_1}, x_{i_2}, \cdots, x_{i_{j-1}}\}$. Compared with random sampling, farthest point sampling gives better coverage of the whole point set for the same number of centroids.

  • The farthest point sampling algorithm is as follows:
    Let $A$ be the initial point set and $B$ the set of the $m$ points to be selected.
    1. Initialize $B = \emptyset$; randomly pick a point from $A$ as the initial point $x_1$ and add it to $B$;
    2. For each point $x$ in $A \setminus B$, compute its distance to every point in $B$ and take the minimum as the distance from $x$ to $B$, i.e. $d(x, B) = \min_{y \in B} d(x, y)$; then add to $B$ the point of $A \setminus B$ farthest from $B$, i.e. $B = B \cup \{\arg\max_{x \in A \setminus B} \min_{y \in B} d(x, y)\}$;
    3. Repeat step 2 until $|B| = m$.

Note: from the algorithm description above, this layer has no learnable parameters (weights); the input points may carry features (only the point positions are used to compute distances, but other feature dimensions can be attached), and the feature dimension is neither increased nor reduced. This is also why the Sampling layer and the Grouping layer are merged together in the network architecture.
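
To make the selection step concrete, below is a minimal, unbatched PyTorch sketch of FPS (the name fps_sketch is only for illustration; the batched farthest_point_sample function in the Code section is the implementation actually used):

import torch

def fps_sketch(points, m):
    """Minimal farthest point sampling sketch. points: [N, 3]; returns m indices."""
    N = points.shape[0]
    selected = [torch.randint(N, (1,)).item()]      # step 1: random initial point
    dist_to_B = torch.full((N,), float('inf'))      # d(x, B) for every point x
    for _ in range(m - 1):
        # step 2: update d(x, B) with the newest member of B, then pick the farthest point
        d_new = ((points - points[selected[-1]]) ** 2).sum(dim=1)
        dist_to_B = torch.min(dist_to_B, d_new)
        selected.append(int(torch.argmax(dist_to_B)))
    return selected                                 # step 3: stop when |B| = m

# toy example: 6 points on a line; FPS spreads the picks over the whole set
pts = torch.tensor([[0., 0, 0], [1, 0, 0], [2, 0, 0], [3, 0, 0], [4, 0, 0], [5, 0, 0]])
print(fps_sketch(pts, 3))   # e.g. [2, 5, 0]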

Grouping layer

Input: a point set of size $N \times (d + C)$ and the centroid coordinates of size $N' \times d$, where $N'$ is the number of centroids produced by the Sampling layer.

Output: groups of point sets of size $N' \times K \times (d + C)$ ($N'$ groups, each containing $K$ points).

Method: to find the $K$ points nearest to each centroid, ball query is used rather than $k$NN ($k$ nearest neighbors). Compared with $k$NN, the local neighborhood of ball query guarantees a fixed region scale, which makes local region features more generalizable across space. In other words, ball query is better suited to tasks that require recognizing local/fine-grained patterns.

Ball query algorithm: take up to $K$ points inside the ball; if fewer than $K$ points fall inside, pad the group by repeating the first point.

Scale: the paper gives no criterion for choosing the scale, but the supplementary material lists the concrete values directly.

PointNet layer

Input: the point sets of the $N'$ local regions, with data of size $N' \times K \times (d + C)$.

Output: data of size $N' \times (d + C')$.

Purpose: each local region in the output is abstracted by its centroid and a local feature that encodes the centroid's neighborhood. In other words, a feature is extracted for each local region, built from the centroid position and an encoding of the remaining features.

Method: first, the coordinates of the points in a local region are translated into a local frame relative to the centroid (relative coordinates): $x_i^{(j)} = x_i^{(j)} - \hat{x}^{(j)}$ for $i = 1, 2, \cdots, K$ and $j = 1, 2, \cdots, d$, where $\hat{x}$ is the coordinate of the centroid. Then PointNet is used as the basic building block for local pattern learning. Using relative coordinates together with point features captures the point-to-point relations within the local region.

3.3 Robust Feature Learning under Non-Uniform Sampling Density

When the point cloud is non-uniformly distributed, i.e. the sampling density varies, using the same radius for every sub-region when grouping can leave some regions with too few sampled points.

To deal with varying sampling density, the paper proposes density adaptive PointNet layers, which learn to combine features from regions of different scales when the input sampling density changes. Hierarchical Point Set Feature Learning plus the density adaptive PointNet layers constitutes PointNet++.

In Section 3.2, each set abstraction level performs grouping and feature extraction at a single scale. In PointNet++, each set abstraction level extracts local patterns at multiple scales and combines them intelligently according to the local point density. In other words, Section 3.2 describes the single-scale processing; Section 3.3 runs several scales at once, processes each scale as in Section 3.2, and finally combines the features from these scales.

For grouping local regions and combining features from different scales, the paper proposes the following two types of density adaptive layers, illustrated below:
[Figure: (a) multi-scale grouping (MSG); (b) multi-resolution grouping (MRG)]

Multi-scale grouping (MSG)

A simple but effective way to capture multi-scale patterns is to apply grouping layers with different scales and then extract features at each scale with PointNets. The features from different scales are concatenated into a multi-scale feature (a plain concatenation with no extra processing, equivalent to concatenating vectors). That is, at each level several scales are given, features are extracted at every scale following the set abstraction procedure, and the features of all scales are concatenated to form the multi-scale feature, which is then used as the feature of the centroid.

Multi-resolution grouping (MRG)

MSG is computationally expensive because it runs a local PointNet over large-scale neighborhoods for every centroid. In particular, since the number of centroids is usually quite large at the lowest level, the time cost is significant. MRG instead concatenates feature vectors obtained at different resolutions. As shown in panel (b) of the figure, the left vector is obtained by applying the set abstraction level to summarize the features of each sub-region at the lower level, and the right vector is obtained by applying a single PointNet directly to the raw points of the local region; the two vectors are then concatenated. The original text: In Fig. 3 (b), features of a region at some level $L_i$ is a concatenation of two vectors. One vector (left in figure) is obtained by summarizing the features at each subregion from the lower level $L_{i-1}$ using the set abstraction level. The other vector (right) is the feature that is obtained by directly processing all raw points in the local region using a single PointNet.

3.4 Point Feature Propagation for Set Segmentation

In the set abstraction layers, the original point set is subsampled. For segmentation, however, we need features for all the original points. A simple idea is to keep every point as a centroid in all set abstraction layers, but this incurs a huge computational cost. The paper takes another approach: propagating features from the subsampled points back to the original points.

The paper adopts a hierarchical propagation strategy with distance-based interpolation and across-level skip links. In a feature propagation level, features of the $N_l \times (d + C)$ points are propagated to the $N_{l-1}$ points, where $N_{l-1}$ and $N_l$ are the numbers of input and output points of set abstraction level $l$, respectively.

Propagation method: for each point $x$ among the $N_{l-1}$ input points of level $l$, find its $k$ nearest points among the $N_l$ output points; the propagated feature $f$ is then an inverse-distance-weighted average of their features:

$$f^{(j)}(x) = \frac{\sum_{i=1}^{k} w_i(x)\, f_i^{(j)}}{\sum_{i=1}^{k} w_i(x)}, \qquad w_i(x) = \frac{1}{d(x, x_i)^p}$$

where the paper uses $p = 2$ and $k = 3$ by default.
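
A minimal single-point sketch of this interpolation with $k = 3$ and $p = 2$ (the helper name interpolate_one is only for illustration; the batched version lives in the PointNetFeaturePropagation module in the Code section):

import torch

def interpolate_one(x, known_xyz, known_feat, k=3, p=2, eps=1e-8):
    """x: [3] query position; known_xyz: [S, 3]; known_feat: [S, C]; returns [C]."""
    d = ((known_xyz - x) ** 2).sum(dim=1).sqrt()   # Euclidean distances d(x, x_i)
    d, idx = d.sort()
    d, idx = d[:k], idx[:k]                        # the k nearest of the N_l points
    w = 1.0 / (d ** p + eps)                       # w_i(x) = 1 / d(x, x_i)^p
    w = w / w.sum()                                # normalize, as in the formula above
    return (known_feat[idx] * w.unsqueeze(1)).sum(dim=0)   # weighted feature average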

4. Experiments

Datasets: four datasets in total, ranging from 2D objects (MNIST) and 3D objects (ModelNet40 rigid objects, SHREC15 non-rigid objects) to real 3D scenes (ScanNet).

  • MNIST: Images of handwritten digits with 60k training and 10k testing samples.
  • ModelNet40: CAD models of 40 categories (mostly man-made). We use the official split with 9,843 shapes for training and 2,468 for testing.
  • SHREC15: 1200 shapes from 50 categories. Each category contains 24 shapes which are mostly organic ones with various poses such as horses, cats, etc. We use five fold cross validation to acquire classification accuracy on this dataset.
  • ScanNet: 1513 scanned and reconstructed indoor scenes. We follow the experiment setting in and use 1201 scenes for training, 312 scenes for test.

Note: 1. As with PointNet, every input sample still needs to contain the same number of sampled points. 2. As with PointNet, all point sets are normalized to be zero mean and within a unit ball.

Code

Link to the authors' official code, which uses TensorFlow as the deep learning framework:

PointNet++ code (TensorFlow)

Link to a third-party reimplementation that uses PyTorch:

PointNet++ code (PyTorch)

Note: the PyTorch repository above actually contains the code of both PointNet and PointNet++. PointNet files are prefixed with pointnet_, and PointNet++ files are prefixed with pointnet2_.

Since the PointNet walkthrough was based on the TensorFlow code, PointNet++ is explained here using the PyTorch implementation instead.

Compared with PointNet, PointNet++ adds several important utility function components, so these are introduced first.

1. Utility Function Components

These utility functions are all located in pointnet2_utils.py under the ./models/ directory.

1.1 Basic Functions

# All point sets are normalized to be zero mean and within a unit ball
def pc_normalize(pc):
    l = pc.shape[0]
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
    pc = pc / m
    return pc

# Compute the pairwise squared Euclidean distances between two point sets
def square_distance(src, dst):
    """
    Calculate Euclid distance between each two points.

    src^T * dst = xn * xm + yn * ym + zn * zm;
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist

# Given per-sample indices into the input point set, gather the corresponding points
def index_points(points, idx):
    """

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S]
    Return:
        new_points:, indexed points data, [B, S, C]
    """
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
    new_points = points[batch_indices, idx, :]
    return new_points

# Farthest point sampling
def farthest_point_sample(xyz, npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids

# Ball query
def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    # sqrdists: [B, S, N], squared Euclidean distances between the S centroids (new_xyz) and all points (xyz)
    sqrdists = square_distance(new_xyz, xyz)
    # Set group_idx to N for every point whose squared distance exceeds radius^2; other entries keep their original index
    group_idx[sqrdists > radius ** 2] = N
    # Sort ascending: entries set to N (outside the ball) are the largest, so the first nsample entries are points inside the ball
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    # The first nsample entries may still contain N (i.e. the ball holds fewer than nsample points);
    # such entries are discarded and replaced by the first (valid) point of the group
    # group_first: the index of the first point of each group, expanded to [B, S, nsample] for the replacement below
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    # Find the entries of group_idx that are equal to N
    mask = group_idx == N
    # Replace them with the index of the first point
    group_idx[mask] = group_first[mask]
    return group_idx
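
A quick shape check of these utilities on random data (illustrative only; the tensor sizes are arbitrary and the functions above are assumed to be in scope):

import torch

B, N, S = 2, 1024, 128
xyz = torch.rand(B, N, 3)                         # a random batch of point clouds

fps_idx = farthest_point_sample(xyz, S)           # [B, S] indices of the centroids
new_xyz = index_points(xyz, fps_idx)              # [B, S, 3] centroid coordinates
print(square_distance(new_xyz, xyz).shape)        # torch.Size([2, 128, 1024])
group_idx = query_ball_point(0.2, 32, xyz, new_xyz)
print(group_idx.shape)                            # torch.Size([2, 128, 32])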

1.2 Network Layer Functions

Sampling layer + Grouping layer

The Sampling layer + Grouping layer splits the point cloud into local groups. There are two functions below, sample_and_group and sample_and_group_all; the difference is that the latter treats the whole point cloud as a single group.

def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Input:
        npoint: number of points to keep in farthest point sampling
        radius: radius (scale) of the ball query
        nsample: number of points sampled in each ball query
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, npoint, 3]
        new_points: sampled points data, [B, npoint, nsample, 3+D]
    """
    B, N, C = xyz.shape
    S = npoint
    # Farthest point sampling: get the indices of the sampled points
    fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint]
    # Gather the coordinates of the FPS-sampled points (the centroids)
    new_xyz = index_points(xyz, fps_idx)
    # Ball query: for each centroid, get the indices of nsample points in its ball neighborhood
    idx = query_ball_point(radius, nsample, xyz, new_xyz)
    # Gather the coordinates of the grouped points
    grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
    # Subtract the centroid coordinates from grouped_xyz, i.e. convert to coordinates relative to the centroid
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)

    # If the points carry extra feature channels, concatenate them with the relative coordinates; otherwise return the relative coordinates only
    if points is not None:
        grouped_points = index_points(points, idx)
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
    else:
        new_points = grouped_xyz_norm
    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    else:
        return new_xyz, new_points


def sample_and_group_all(xyz, points):
    """
    Input:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, 1, 3]
        new_points: sampled points data, [B, 1, N, 3+D]
    """
    device = xyz.device
    B, N, C = xyz.shape
    new_xyz = torch.zeros(B, 1, C).to(device)
    grouped_xyz = xyz.view(B, 1, N, C)
    if points is not None:
        new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
    else:
        new_points = grouped_xyz
    return new_xyz, new_points
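
For reference, a small usage sketch showing the expected shapes (random data, arbitrary hyper-parameters, functions above assumed in scope):

import torch

B, N, D = 2, 1024, 6
xyz = torch.rand(B, N, 3)                 # point positions
points = torch.rand(B, N, D)              # extra per-point features (e.g. normals + color)

new_xyz, new_points = sample_and_group(npoint=512, radius=0.2, nsample=32,
                                       xyz=xyz, points=points)
print(new_xyz.shape)      # torch.Size([2, 512, 3])       centroid positions
print(new_points.shape)   # torch.Size([2, 512, 32, 9])   relative xyz (3) + features (6)

all_xyz, all_points = sample_and_group_all(xyz, points)
print(all_xyz.shape)      # torch.Size([2, 1, 3])
print(all_points.shape)   # torch.Size([2, 1, 1024, 9])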

Set Abstraction level

The Set Abstraction level comes in a plain (single-scale) version and a multi-scale version, named PointNetSetAbstraction and PointNetSetAbstractionMsg respectively.

# The plain (single-scale) version of the Set Abstraction level
class PointNetSetAbstraction(nn.Module):
    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        '''
        Input:
            npoint: number of points sampled by FPS
            radius: radius of the ball query
            nsample: number of points in each ball query
            in_channel: the dimension of the input channel
            mlp: a list of mlp output channels, such as [64, 64, 128]
            group_all: bool, whether to group all points into a single group
        '''
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        self.group_all = group_all

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        # Sampling + grouping to form the local groups
        if self.group_all:
            new_xyz, new_points = sample_and_group_all(xyz, points)
        else:
            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # new_xyz: sampled points position data, [B, npoint, C]
        # new_points: sampled points data, [B, npoint, nsample, C+D]
        new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
        # The 1x1 2D convolution treats the npoint groups as one spatial axis and acts point-wise over the [C+D, nsample] dimensions,
        # which is equivalent to applying a shared MLP (a 1D convolution) to the C+D channels of every single point
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points =  F.relu(bn(conv(new_points)))

        # Max pooling over each group yields the aggregated feature of that local region
        new_points = torch.max(new_points, 2)[0]
        new_xyz = new_xyz.permute(0, 2, 1)
        return new_xyz, new_points

# The multi-scale version of the Set Abstraction level; here radius, nsample and mlp are lists rather than single values
class PointNetSetAbstractionMsg(nn.Module):
    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
        '''
        Input:
            npoint: number of points sampled by FPS
            radius_list: list of radii for the ball queries
            nsample_list: list of the numbers of points for each ball query
            in_channel: the dimension of the input channel
            mlp_list: a list of mlp channel lists, such as [[32, 32, 64], [64, 64, 128]]
        '''
        super(PointNetSetAbstractionMsg, self).__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        for i in range(len(mlp_list)):
            convs = nn.ModuleList()
            bns = nn.ModuleList()
            last_channel = in_channel + 3
            for out_channel in mlp_list[i]:
                convs.append(nn.Conv2d(last_channel, out_channel, 1))
                bns.append(nn.BatchNorm2d(out_channel))
                last_channel = out_channel
            self.conv_blocks.append(convs)
            self.bn_blocks.append(bns)

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        B, N, C = xyz.shape
        S = self.npoint
        new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
        new_points_list = []
        for i, radius in enumerate(self.radius_list):
            K = self.nsample_list[i]
            group_idx = query_ball_point(radius, K, xyz, new_xyz)
            grouped_xyz = index_points(xyz, group_idx)
            grouped_xyz -= new_xyz.view(B, S, 1, C)
            if points is not None:
                grouped_points = index_points(points, group_idx)
                grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
            else:
                grouped_points = grouped_xyz

            grouped_points = grouped_points.permute(0, 3, 2, 1)  # [B, D, K, S]
            for j in range(len(self.conv_blocks[i])):
                conv = self.conv_blocks[i][j]
                bn = self.bn_blocks[i][j]
                grouped_points =  F.relu(bn(conv(grouped_points)))
            new_points = torch.max(grouped_points, 2)[0]  # [B, D', S]
            new_points_list.append(new_points)

        new_xyz = new_xyz.permute(0, 2, 1)
        new_points_concat = torch.cat(new_points_list, dim=1)
        return new_xyz, new_points_concat
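
A short shape check for both versions on random input (the channel settings mirror the classification models in Section 2.1; the classes above are assumed in scope):

import torch

B, N = 2, 1024
xyz = torch.rand(B, 3, N)                  # positions, channel-first as the modules expect
norm = torch.rand(B, 3, N)                 # e.g. normals as extra point features

sa_ssg = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32,
                                in_channel=3 + 3, mlp=[64, 64, 128], group_all=False)
new_xyz, new_points = sa_ssg(xyz, norm)
print(new_xyz.shape, new_points.shape)     # [2, 3, 512], [2, 128, 512]

sa_msg = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [16, 32, 128], 3,
                                   [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
new_xyz, new_points = sa_msg(xyz, norm)
print(new_xyz.shape, new_points.shape)     # [2, 3, 512], [2, 320, 512]  (64+128+128 channels)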

Feature Propagation level

class PointNetFeaturePropagation(nn.Module):
    def __init__(self, in_channel, mlp):
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, points1, points2):
        """
        Input:
            xyz1: input points position data, [B, C, N]
            xyz2: sampled input points position data, [B, C, S]
            points1: input points data, [B, D, N]
            points2: input points data, [B, D, S]
        Return:
            new_points: upsampled points data, [B, D', N]
        """
        xyz1 = xyz1.permute(0, 2, 1)
        xyz2 = xyz2.permute(0, 2, 1)

        points2 = points2.permute(0, 2, 1)
        B, N, C = xyz1.shape
        _, S, _ = xyz2.shape

        if S == 1:
            interpolated_points = points2.repeat(1, N, 1)
        else:
            # Find the three nearest points
            dists = square_distance(xyz1, xyz2)
            dists, idx = dists.sort(dim=-1)
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]

            # Inverse-distance-weighted interpolation (weights are the reciprocals of the distances)
            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)

        if points1 is not None:
            points1 = points1.permute(0, 2, 1)
            new_points = torch.cat([points1, interpolated_points], dim=-1)
        else:
            new_points = interpolated_points

        new_points = new_points.permute(0, 2, 1)
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))
        return new_points
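
A shape sketch for feature propagation, upsampling features from 128 subsampled points back to 1024 points (channel sizes are arbitrary, chosen so that in_channel = 64 + 256 matches the concatenation inside forward):

import torch

B, N, S = 2, 1024, 128
xyz1 = torch.rand(B, 3, N)                 # dense positions (upsampling target)
xyz2 = torch.rand(B, 3, S)                 # sparse positions (feature source)
points1 = torch.rand(B, 64, N)             # skip-link features on the dense points
points2 = torch.rand(B, 256, S)            # features on the sparse points

fp = PointNetFeaturePropagation(in_channel=64 + 256, mlp=[256, 128])
new_points = fp(xyz1, xyz2, points1, points2)
print(new_points.shape)                    # torch.Size([2, 128, 1024])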

2. Task Networks

The task networks below are straightforward, so only the code is listed without a detailed walkthrough.

2.1 Classification

single scale group classification

class get_model(nn.Module):
    def __init__(self,num_class,normal_channel=True):
        super(get_model, self).__init__()
        # Whether normals are provided as additional input features
        in_channel = 6 if normal_channel else 3
        self.normal_channel = normal_channel
        # Single Scale Set Abstraction level
        self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32, in_channel=in_channel, mlp=[64, 64, 128], group_all=False)
        self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False)
        self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True)
        # mlp
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.4)
        self.fc3 = nn.Linear(256, num_class)

    def forward(self, xyz):
        B, _, _ = xyz.shape
        if self.normal_channel:
            norm = xyz[:, 3:, :]
            xyz = xyz[:, :3, :]
        else:
            norm = None
        # Apply the single-scale Set Abstraction levels
        l1_xyz, l1_points = self.sa1(xyz, norm)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        x = l3_points.view(B, 1024)
        # Apply the MLP head
        x = self.drop1(F.relu(self.bn1(self.fc1(x))))
        x = self.drop2(F.relu(self.bn2(self.fc2(x))))
        x = self.fc3(x)
        x = F.log_softmax(x, -1)


        return x, l3_points
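
A minimal forward-pass sketch for this single-scale classifier (random input, assuming normal_channel=True so each point carries xyz plus a normal):

import torch

model = get_model(num_class=40, normal_channel=True)
model.eval()                               # eval mode so BatchNorm uses running statistics
xyz = torch.rand(8, 6, 1024)               # [B, 3 coordinate + 3 normal channels, N]
with torch.no_grad():
    log_probs, l3_points = model(xyz)
print(log_probs.shape)                     # torch.Size([8, 40])  per-class log-probabilities
print(l3_points.shape)                     # torch.Size([8, 1024, 1])  global feature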

multi scale group classification

class get_model(nn.Module):
    def __init__(self,num_class,normal_channel=True):
        super(get_model, self).__init__()
        in_channel = 3 if normal_channel else 0
        self.normal_channel = normal_channel
        # Multi Scale Set Abstraction level
        self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [16, 32, 128], in_channel,[[32, 32, 64], [64, 64, 128], [64, 96, 128]])
        self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4, 0.8], [32, 64, 128], 320,[[64, 64, 128], [128, 128, 256], [128, 128, 256]])
        self.sa3 = PointNetSetAbstraction(None, None, None, 640 + 3, [256, 512, 1024], True)
        # mlp 
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(256, num_class)

    def forward(self, xyz):
        B, _, _ = xyz.shape
        if self.normal_channel:
            norm = xyz[:, 3:, :]
            xyz = xyz[:, :3, :]
        else:
            norm = None
        # Apply the multi-scale Set Abstraction levels
        l1_xyz, l1_points = self.sa1(xyz, norm)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        x = l3_points.view(B, 1024)
        # Apply the MLP head
        x = self.drop1(F.relu(self.bn1(self.fc1(x))))
        x = self.drop2(F.relu(self.bn2(self.fc2(x))))
        x = self.fc3(x)
        x = F.log_softmax(x, -1)


        return x,l3_points

2.2 Part Segmentation

Single Scale Part Segmentation

class get_model(nn.Module):
    def __init__(self, num_classes, normal_channel=False):
        super(get_model, self).__init__()
        if normal_channel:
            additional_channel = 3
        else:
            additional_channel = 0
        self.normal_channel = normal_channel
        # Single Scale Set Abstraction level
        self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32, in_channel=6+additional_channel, mlp=[64, 64, 128], group_all=False)
        self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False)
        self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True)
        # Feature Propagation level
        self.fp3 = PointNetFeaturePropagation(in_channel=1280, mlp=[256, 256])
        self.fp2 = PointNetFeaturePropagation(in_channel=384, mlp=[256, 128])
        self.fp1 = PointNetFeaturePropagation(in_channel=128+16+6+additional_channel, mlp=[128, 128, 128])
        # Fully Connected layer
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, num_classes, 1)

    def forward(self, xyz, cls_label):
        # Single Scale Set Abstraction layers
        B,C,N = xyz.shape
        if self.normal_channel:
            l0_points = xyz
            l0_xyz = xyz[:,:3,:]
        else:
            l0_points = xyz
            l0_xyz = xyz
        l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        # Feature Propagation layers
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        cls_label_one_hot = cls_label.view(B,16,1).repeat(1,1,N)
        l0_points = self.fp1(l0_xyz, l1_xyz, torch.cat([cls_label_one_hot,l0_xyz,l0_points],1), l1_points)
        # FC layers
        feat =  F.relu(self.bn1(self.conv1(l0_points)))
        x = self.drop1(feat)
        x = self.conv2(x)
        x = F.log_softmax(x, dim=1)
        x = x.permute(0, 2, 1)
        return x, l3_points
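
A forward-pass sketch for this model; cls_label is the one-hot object category (16 ShapeNet object categories), num_classes is the number of part labels (e.g. 50 for ShapeNet Part), and with normal_channel=False the input carries only xyz, which the model also reuses as the initial point feature:

import torch

model = get_model(num_classes=50, normal_channel=False)
model.eval()
xyz = torch.rand(8, 3, 2048)                       # [B, 3, N] coordinates only
cls_label = torch.zeros(8, 16)                     # one-hot object category per shape
cls_label[:, 0] = 1
with torch.no_grad():
    seg_log_probs, l3_points = model(xyz, cls_label)
print(seg_log_probs.shape)                         # torch.Size([8, 2048, 50]) per-point log-probs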

Multi Scale Part Segmentation

class get_model(nn.Module):
    def __init__(self, num_classes, normal_channel=False):
        super(get_model, self).__init__()
        if normal_channel:
            additional_channel = 3
        else:
            additional_channel = 0
        self.normal_channel = normal_channel
        # Multi Scale Set Abstraction level
        self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [32, 64, 128], 3+additional_channel, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
        self.sa2 = PointNetSetAbstractionMsg(128, [0.4,0.8], [64, 128], 128+128+64, [[128, 128, 256], [128, 196, 256]])
        self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=512 + 3, mlp=[256, 512, 1024], group_all=True)
        self.fp3 = PointNetFeaturePropagation(in_channel=1536, mlp=[256, 256])
        self.fp2 = PointNetFeaturePropagation(in_channel=576, mlp=[256, 128])
        self.fp1 = PointNetFeaturePropagation(in_channel=150+additional_channel, mlp=[128, 128])
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, num_classes, 1)

    def forward(self, xyz, cls_label):
        # Set Abstraction layers (the first two are multi-scale)
        B,C,N = xyz.shape
        if self.normal_channel:
            l0_points = xyz
            l0_xyz = xyz[:,:3,:]
        else:
            l0_points = xyz
            l0_xyz = xyz
        l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        # Feature Propagation layers
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        cls_label_one_hot = cls_label.view(B,16,1).repeat(1,1,N)
        l0_points = self.fp1(l0_xyz, l1_xyz, torch.cat([cls_label_one_hot,l0_xyz,l0_points],1), l1_points)
        # FC layers
        feat = F.relu(self.bn1(self.conv1(l0_points)))
        x = self.drop1(feat)
        x = self.conv2(x)
        x = F.log_softmax(x, dim=1)
        x = x.permute(0, 2, 1)
        return x, l3_points

2.3 Semantic Segmentation

Single Scale Semantic Segmentation

class get_model(nn.Module):
    def __init__(self, num_classes):
        super(get_model, self).__init__()
        self.sa1 = PointNetSetAbstraction(1024, 0.1, 32, 9 + 3, [32, 32, 64], False)
        self.sa2 = PointNetSetAbstraction(256, 0.2, 32, 64 + 3, [64, 64, 128], False)
        self.sa3 = PointNetSetAbstraction(64, 0.4, 32, 128 + 3, [128, 128, 256], False)
        self.sa4 = PointNetSetAbstraction(16, 0.8, 32, 256 + 3, [256, 256, 512], False)
        self.fp4 = PointNetFeaturePropagation(768, [256, 256])
        self.fp3 = PointNetFeaturePropagation(384, [256, 256])
        self.fp2 = PointNetFeaturePropagation(320, [256, 128])
        self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128])
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, num_classes, 1)

    def forward(self, xyz):
        l0_points = xyz
        l0_xyz = xyz[:,:3,:]

        l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        l4_xyz, l4_points = self.sa4(l3_xyz, l3_points)

        l3_points = self.fp4(l3_xyz, l4_xyz, l3_points, l4_points)
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points)

        x = self.drop1(F.relu(self.bn1(self.conv1(l0_points))))
        x = self.conv2(x)
        x = F.log_softmax(x, dim=1)
        x = x.permute(0, 2, 1)
        return x, l4_points

Multi Scale Semantic Segmentation

class get_model(nn.Module):
    def __init__(self, num_classes):
        super(get_model, self).__init__()

        self.sa1 = PointNetSetAbstractionMsg(1024, [0.05, 0.1], [16, 32], 9, [[16, 16, 32], [32, 32, 64]])
        self.sa2 = PointNetSetAbstractionMsg(256, [0.1, 0.2], [16, 32], 32+64, [[64, 64, 128], [64, 96, 128]])
        self.sa3 = PointNetSetAbstractionMsg(64, [0.2, 0.4], [16, 32], 128+128, [[128, 196, 256], [128, 196, 256]])
        self.sa4 = PointNetSetAbstractionMsg(16, [0.4, 0.8], [16, 32], 256+256, [[256, 256, 512], [256, 384, 512]])
        self.fp4 = PointNetFeaturePropagation(512+512+256+256, [256, 256])
        self.fp3 = PointNetFeaturePropagation(128+128+256, [256, 256])
        self.fp2 = PointNetFeaturePropagation(32+64+256, [256, 128])
        self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128])
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, num_classes, 1)

    def forward(self, xyz):
        l0_points = xyz
        l0_xyz = xyz[:,:3,:]

        l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        l4_xyz, l4_points = self.sa4(l3_xyz, l3_points)

        l3_points = self.fp4(l3_xyz, l4_xyz, l3_points, l4_points)
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points)

        x = self.drop1(F.relu(self.bn1(self.conv1(l0_points))))
        x = self.conv2(x)
        x = F.log_softmax(x, dim=1)
        x = x.permute(0, 2, 1)
        return x, l4_points
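
A forward-pass sketch for the semantic segmentation models; the first set abstraction level expects 9 input channels per point (typically xyz, rgb and normalized coordinates in the repository's S3DIS pipeline), and num_classes would be 13 for S3DIS:

import torch

model = get_model(num_classes=13)                  # e.g. 13 semantic classes for S3DIS
model.eval()
xyz = torch.rand(4, 9, 4096)                       # [B, 9, N]: coordinates plus 6 extra channels
with torch.no_grad():
    seg_log_probs, l4_points = model(xyz)
print(seg_log_probs.shape)                         # torch.Size([4, 4096, 13]) per-point log-probs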

If you spot any errors, corrections are welcome.
