Understanding the PointNet++ Network Architecture Through Code

Preface

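The global feature extracted by PointNet works well for classification. However, because the network max-pools all points into a single global feature, it never learns the relationships between neighboring points, so its output lacks the local structure of the point cloud. As a result, PointNet performs only moderately on scene segmentation. In point cloud classification and object part segmentation, this problem can be partly mitigated by centering the object's coordinate frame, but that is hard to do for scene segmentation.
Original paper: https://arxiv.org/abs/1706.02413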

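Building on this, the authors proposed PointNet++, which extracts point cloud features hierarchically over multiple levels. The network architecture is shown below:
[Figure 1: PointNet++ network architecture. Image source: https://arxiv.org/abs/1706.02413]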

Basic building blocks of the network

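The design shown in the figure works as follows. A conventional CNN uses convolution kernels as local receptive fields, shares kernel weights within each layer, and after several layers of feature learning its output encodes local features of the image. PointNet++ borrows this idea from CNNs and learns features hierarchically: within small regions it applies point sampling + grouping + local feature extraction with PointNet (S+G+P). A structure containing these three parts is called a Set Abstraction (SA) layer: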

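  • Sampling: pick an initial point at random, then repeatedly apply FPS (farthest point sampling) until the target number of points is reached;
  • Grouping: around each sampled point (centroid), use Ball Query to draw a ball of radius R and take the points inside it as one group;
  • PointNet: apply PointNet to each group produced by Sampling + Grouping, extracting one global feature per local region.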
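To make the Sampling and Grouping steps concrete, here is a minimal pure-Python sketch of FPS and ball query on a list of (x, y, z) tuples. It is not the repository's CUDA implementation (farthest_point_sample / query_ball_point under tf_ops); the function names are my own, and padding a short group by repeating its first hit is my assumption about how sparse regions are handled.

```python
import math
import random


def farthest_point_sample(points, npoint, start=None):
    """Indices of npoint points chosen by farthest point sampling (FPS).

    points: list of (x, y, z) tuples; start: index of the first centroid,
    chosen at random when None.
    """
    n = len(points)
    if start is None:
        start = random.randrange(n)
    selected = [start]
    # dist[i]: squared distance from point i to its nearest selected centroid so far
    dist = [math.inf] * n
    current = start
    for _ in range(npoint - 1):
        cx, cy, cz = points[current]
        for i, (x, y, z) in enumerate(points):
            dist[i] = min(dist[i], (x - cx) ** 2 + (y - cy) ** 2 + (z - cz) ** 2)
        # next centroid: the point farthest from every centroid selected so far
        current = max(range(n), key=dist.__getitem__)
        selected.append(current)
    return selected


def ball_query(points, centroid_idx, radius, nsample):
    """For each centroid, up to nsample indices of points strictly inside the ball."""
    groups = []
    for c in centroid_idx:
        cx, cy, cz = points[c]
        hits = [i for i, (x, y, z) in enumerate(points)
                if (x - cx) ** 2 + (y - cy) ** 2 + (z - cz) ** 2 < radius ** 2][:nsample]
        # sparse region: pad by repeating the first hit so every group has nsample indices
        hits += [hits[0]] * (nsample - len(hits))
        groups.append(hits)
    return groups


random.seed(0)
cloud = [(random.random(), random.random(), random.random()) for _ in range(200)]
centers = farthest_point_sample(cloud, npoint=16, start=0)
groups = ball_query(cloud, centers, radius=0.2, nsample=8)
print(len(centers), len(groups), len(groups[0]))  # 16 16 8
```

Because every group is padded to exactly nsample indices, the grouped result can be stored as a dense (npoint, nsample) array, which is what lets the real implementation keep fixed tensor shapes.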

Taking a 2D point set as an example, the three steps of a complete SA (Set Abstraction) layer look like this:
[Figures 2 and 3: the three SA steps illustrated on a 2D point set. Image source: https://arxiv.org/abs/1706.02413]

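The centroids of each new layer are a subset of the points output by the previous layer, and the number of centroids equals the number of groups. As the layers go deeper, the number of centroids decreases and each centroid summarizes the local structure of an increasingly large region.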

Handling non-uniform point clouds

When the point density is non-uniform, using the same ball radius for every region leaves some sparse regions with too few points.

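The paper proposes two solutions: **multi-scale grouping (MSG)** and **multi-resolution grouping (MRG)**.
[Figure 4: left, multi-scale grouping (MSG); right, multi-resolution grouping (MRG). Image source: https://arxiv.org/abs/1706.02413]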

In short, the two grouping methods are:

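  • **Multi-scale grouping (MSG):** for each selected centroid, group points with several different radii, extract a feature from each region with its own PointNet, and concatenate these features as the centroid's feature. In my view this produces substantial feature overlap, which may help preserve and emphasize more of the key local features. However, the weights learned for different scales cannot be shared, and the computation grows considerably, since a PointNet has to run on large neighborhoods around every centroid.
  • **Multi-resolution grouping (MRG):** concatenate features extracted at different levels (resolutions). In the right panel of the figure above, the final feature has two parts: the left vector summarizes the features of the subregions computed by the lower-level SA layer, and the right vector is obtained by applying a PointNet directly to all raw points in the local region. The idea resembles a skip connection in feature extraction. When a local region is sparse, the left vector is less reliable than the right one, because its subregions contain even fewer points and suffer more from sampling deficiency, so the right vector should be weighted more heavily. When a region is dense, the left vector captures finer details. This mitigates the problem of extracting features directly from sparse points, and it is more efficient than MSG because it avoids running feature extraction on large neighborhoods at the lowest level.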
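A toy sketch of the MSG idea for a single centroid: group at several radii, summarize each group, and concatenate the results. The helper names (group_within, toy_pointnet, msg_feature) are hypothetical, and toy_pointnet is only a stand-in for a real mini-PointNet: it skips the learned MLP and just max-pools the centroid-relative coordinates, so each scale contributes 3 channels.

```python
def group_within(points, center, radius):
    """Centroid-relative coordinates of all points strictly inside the ball."""
    cx, cy, cz = center
    return [(x - cx, y - cy, z - cz) for (x, y, z) in points
            if (x - cx) ** 2 + (y - cy) ** 2 + (z - cz) ** 2 < radius ** 2]


def toy_pointnet(group):
    """Stand-in for a mini-PointNet: identity 'MLP' plus per-channel max pooling."""
    return [max(p[ch] for p in group) for ch in range(3)]


def msg_feature(points, center, radius_list):
    """Multi-scale grouping: one pooled feature per radius, concatenated."""
    feature = []
    for r in radius_list:
        feature.extend(toy_pointnet(group_within(points, center, r)))
    return feature


pts = [(0.0, 0.0, 0.0), (0.05, 0.0, 0.0), (0.0, 0.3, 0.0), (0.0, 0.0, 0.7)]
print(msg_feature(pts, pts[0], [0.1, 0.4, 0.8]))
# [0.05, 0.0, 0.0, 0.05, 0.3, 0.0, 0.05, 0.3, 0.7]
```

The small radius only sees the nearest neighbor, while the larger radii pick up the points along y and z; the concatenated vector therefore carries information from all three scales, which is the point of MSG.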

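The authors also train with random input dropout (DP), which randomly drops input points during training. Like ordinary dropout it reduces overfitting and makes the model more robust; here it specifically exposes the network to varying point densities. The results show that it also clearly improves classification. The authors provide a comparison figure:
[Figure 5: comparison of models with and without random input dropout (DP)]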
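Random input dropout as described in the paper can be sketched as follows: for each training point cloud, draw a dropout ratio θ uniformly from [0, p] (the paper uses p = 0.95), then drop each point independently with probability θ. This is an illustrative sketch, not the repository's data-augmentation code; always keeping at least one point is my own assumption, and a batched TensorFlow pipeline would additionally need the output padded back to a fixed size.

```python
import random


def random_input_dropout(points, max_ratio=0.95, rng=random):
    """Drop each point with probability theta, where theta ~ U[0, max_ratio]."""
    theta = rng.uniform(0.0, max_ratio)
    kept = [p for p in points if rng.random() >= theta]
    # assumption: never return an empty cloud
    return kept if kept else [points[0]]


rng = random.Random(42)
cloud = [(float(i), 0.0, 0.0) for i in range(1024)]
for _ in range(3):
    print(len(random_input_dropout(cloud, rng=rng)))  # a different point count each time
```

Since θ itself is random, some training clouds stay nearly complete while others become very sparse, which is what teaches the network to cope with non-uniform density.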
Abbreviations used in this article:

  • SA: set abstraction module
  • FC: fully connected layers
  • FP: feature propagation module (skip connections, several fully connected layers)

Code implementation of the SA module

  • utils/pointnet_util.py: implementation of sampling and grouping.
def sample_and_group(npoint, radius, nsample, xyz, points, knn=False, use_xyz=True):
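    '''
    Input:
        npoint: int32, number of centroids (number of groups)
        radius: float32, ball query radius
        nsample: int32, number of points sampled in each local region
        xyz: (batch_size, ndataset, 3) TF tensor, e.g. (32, 1024, 3) as the initial input for classification
        points: (batch_size, ndataset, channel) TF tensor; if None, the grouped xyz is used as the feature
        knn: bool, if True use KNN grouping, otherwise use ball query (radius search)
        use_xyz: bool, if True concat the local point features with xyz, otherwise not; default True

    Output:
        new_xyz: (batch_size, npoint, 3) TF tensor
        new_points: (batch_size, npoint, nsample, 3+channel) TF tensor, concatenated point features
        idx: (batch_size, npoint, nsample) TF tensor, indices of the points in each local region
        grouped_xyz: (batch_size, npoint, nsample, 3) TF tensor, coordinates relative to the centroid (centroid subtracted)
    Note: the repository does not ship the compiled .so libraries under tf_ops/grouping and tf_ops/sampling;
    you may need to compile them before running the corresponding scripts.
    '''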
    # 1. Sample centroids with FPS and group the raw input points
    new_xyz = gather_point(xyz, farthest_point_sample(npoint, xyz)) # (batch_size, npoint, 3)
    if knn:
        _,idx = knn_point(nsample, xyz, new_xyz)
    else:
        idx, pts_cnt = query_ball_point(radius, nsample, xyz, new_xyz)
    grouped_xyz = group_point(xyz, idx) # (batch_size, npoint, nsample, 3)
    grouped_xyz -= tf.tile(tf.expand_dims(new_xyz, 2), [1,1,nsample,1]) # translation normalization: subtract the centroid to get region-local coordinates
    
    # 2. Group the higher-level features with the same indices
    if points is not None:
        grouped_points = group_point(points, idx) # (batch_size, npoint, nsample, channel)
        if use_xyz:
            new_points = tf.concat([grouped_xyz, grouped_points], axis=-1) # (batch_size, npoint, nsample, 3+channel)
        else:
            new_points = grouped_points
    else:
        new_points = grouped_xyz

    return new_xyz, new_points, idx, grouped_xyz

# In the last SA layer, all points are grouped into a single region
def sample_and_group_all(xyz, points, use_xyz=True):
    '''
    Takes only three inputs; otherwise works like sample_and_group.
    Inputs:
        xyz: (batch_size, ndataset, 3) TF tensor
        points: (batch_size, ndataset, channel) TF tensor
        use_xyz: bool
    Outputs:
        new_xyz: (batch_size, 1, 3) as (0,0,0)
        new_points: (batch_size, 1, ndataset, 3+channel) TF tensor
    Note:
       Equivalent to sample_and_group(npoint=1, radius=inf) with (0,0,0) as the centroid
    '''
    batch_size = xyz.get_shape()[0].value
    nsample = xyz.get_shape()[1].value
    new_xyz = tf.constant(np.tile(np.array([0,0,0]).reshape((1,1,3)), (batch_size,1,1)),dtype=tf.float32) # (batch_size, 1, 3)
    idx = tf.constant(np.tile(np.array(range(nsample)).reshape((1,1,nsample)), (batch_size,1,1)))
    grouped_xyz = tf.reshape(xyz, (batch_size, 1, nsample, 3)) # (batch_size, npoint=1, nsample, 3)
    if points is not None:
        if use_xyz:
            new_points = tf.concat([xyz, points], axis=2) # (batch_size, 16, 259)
        else:
            new_points = points
        new_points = tf.expand_dims(new_points, 1) # (batch_size, 1, 16, 259)
    else:
        new_points = grouped_xyz
    return new_xyz, new_points, idx, grouped_xyz


def pointnet_sa_module(xyz, points, npoint, radius, nsample, mlp, mlp2, group_all, is_training, bn_decay, scope, bn=True, pooling='max', knn=False, use_xyz=True, use_nchw=False):
    ''' PointNet Set Abstraction (SA) Module
        Input:
            xyz: (batch_size, ndataset, 3) TF tensor
            points: (batch_size, ndataset, channel) TF tensor
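            npoint: int32 -- number of points sampled by FPS (number of centroids / groups)
            radius: float32 -- search radius of each local region
            nsample: int32 -- number of points sampled in each local region
            mlp: list of int32 -- output sizes of the MLP applied to each point
            mlp2: list of int32 -- output sizes of the MLP applied to each region
            group_all: bool -- if True, group all points into one region, overriding npoint, radius and nsample
            use_xyz: bool, if True concat the local point features with xyz, otherwise not
            use_nchw: bool, if True use the NCHW data format for conv2d, which the authors note is usually faster than NHWC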
        Return:
            new_xyz: (batch_size, npoint, 3) TF tensor
            new_points: (batch_size, npoint, mlp[-1] or mlp2[-1]) TF tensor
            idx: (batch_size, npoint, nsample) int32 -- 区域索引
    '''
    data_format = 'NCHW' if use_nchw else 'NHWC'
    with tf.variable_scope(scope) as sc:
        # Sample and Grouping
        if group_all:
            nsample = xyz.get_shape()[1].value
            new_xyz, new_points, idx, grouped_xyz = sample_and_group_all(xyz, points, use_xyz)
        else:
            new_xyz, new_points, idx, grouped_xyz = sample_and_group(npoint, radius, nsample, xyz, points, knn, use_xyz)

        # Point Feature Embedding
        if use_nchw: new_points = tf.transpose(new_points, [0,3,1,2]) # NHWC -> NCHW
        for i, num_out_channel in enumerate(mlp):
            new_points = tf_util.conv2d(new_points, num_out_channel, [1,1],
                                        padding='VALID', stride=[1,1],
                                        bn=bn, is_training=is_training,
                                        scope='conv%d'%(i), bn_decay=bn_decay,
                                        data_format=data_format) 
        if use_nchw: new_points = tf.transpose(new_points, [0,2,3,1])#nchw->nhwc
    """
    省略 some code(区域max pooling)
    """
    
    
# SA module with multi-scale grouping (MSG) for sparse / non-uniform point clouds
def pointnet_sa_module_msg(xyz, points, npoint, radius_list, nsample_list, mlp_list, is_training, bn_decay, scope, bn=True, use_xyz=True, use_nchw=False):
    ''' PointNet Set Abstraction (SA) module with Multi-Scale Grouping (MSG)
        Input:
            xyz: (batch_size, ndataset, 3) TF tensor
            points: (batch_size, ndataset, channel) TF tensor
            npoint: int32 -- #points sampled in farthest point sampling
            radius_list: list of float32 -- search radius of the local region, one per scale
            nsample_list: list of int32 -- how many points in each local region, one per scale
            mlp_list: list of list of int32 -- output size for MLP on each point, one list per scale
            use_xyz: bool, if True concat XYZ with local point features, otherwise just use point features
            use_nchw: bool, if True, use NCHW data format for conv2d, which is usually faster than NHWC format
        Return:
            new_xyz: (batch_size, npoint, 3) TF tensor
            new_points: (batch_size, npoint, \sum_k{mlp[k][-1]}) TF tensor
    '''
    data_format = 'NCHW' if use_nchw else 'NHWC'
    with tf.variable_scope(scope) as sc:
        new_xyz = gather_point(xyz, farthest_point_sample(npoint, xyz))
        new_points_list = []
        for i in range(len(radius_list)):
            radius = radius_list[i]
            nsample = nsample_list[i]
            idx, pts_cnt = query_ball_point(radius, nsample, xyz, new_xyz)
            grouped_xyz = group_point(xyz, idx)
            grouped_xyz -= tf.tile(tf.expand_dims(new_xyz, 2), [1,1,nsample,1])
            if points is not None:
                grouped_points = group_point(points, idx)
                if use_xyz:
                    grouped_points = tf.concat([grouped_points, grouped_xyz], axis=-1)
            else:
                grouped_points = grouped_xyz
            if use_nchw: grouped_points = tf.transpose(grouped_points, [0,3,1,2])
            for j,num_out_channel in enumerate(mlp_list[i]):
                grouped_points = tf_util.conv2d(grouped_points, num_out_channel, [1,1],
                                                padding='VALID', stride=[1,1], bn=bn, is_training=is_training,
                                                scope='conv%d_%d'%(i,j), bn_decay=bn_decay,
                                                data_format=data_format)
            if use_nchw: grouped_points = tf.transpose(grouped_points, [0,2,3,1])
            new_points = tf.reduce_max(grouped_points, axis=[2])
            new_points_list.append(new_points)
        new_points_concat = tf.concat(new_points_list, axis=-1)
        return new_xyz, new_points_concat


def pointnet_fp_module(xyz1, xyz2, points1, points2, mlp, is_training, bn_decay, scope, bn=True):
    ''' PointNet Feature Propagation (FP) Module
        FP layer: refines the features merged from interpolation and the skip connection
        Input:
            xyz1: (batch_size, ndataset1, 3) TF tensor
            xyz2: (batch_size, ndataset2, 3) TF tensor, sparser than xyz1
            points1: (batch_size, ndataset1, nchannel1) TF tensor
            points2: (batch_size, ndataset2, nchannel2) TF tensor
            mlp: list of int32 -- output feature sizes of the MLP applied to each point
        Return:
            new_points: (batch_size, ndataset1, mlp[-1]) TF tensor
            Note: this part uses the interpolation ops. The repository ships tf_ops/3d_interpolation/tf_interpolate_so.so,
            which can be used without recompiling, unlike the grouping and sampling ops.
    '''
    with tf.variable_scope(scope) as sc:
        dist, idx = three_nn(xyz1, xyz2)
        dist = tf.maximum(dist, 1e-10)
        norm = tf.reduce_sum((1.0/dist),axis=2,keep_dims=True)
        norm = tf.tile(norm,[1,1,3])
        weight = (1.0/dist) / norm
        interpolated_points = three_interpolate(points2, idx, weight)

        if points1 is not None:
            new_points1 = tf.concat(axis=2, values=[interpolated_points, points1]) # B,ndataset1,nchannel1+nchannel2
        else:
            new_points1 = interpolated_points
        new_points1 = tf.expand_dims(new_points1, 2)
        for i, num_out_channel in enumerate(mlp):
            new_points1 = tf_util.conv2d(new_points1, num_out_channel, [1,1],
                                         padding='VALID', stride=[1,1],
                                         bn=bn, is_training=is_training,
                                         scope='conv_%d'%(i), bn_decay=bn_decay)
        new_points1 = tf.squeeze(new_points1, [2]) # B,ndataset1,mlp[-1]
        return new_points1
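The interpolation at the start of pointnet_fp_module (three_nn + three_interpolate) can be sketched in pure Python: each point in the dense set xyz1 takes a weighted average of the features of its 3 nearest neighbors in the sparse set xyz2, with weights normalized as in `weight = (1.0/dist) / norm` above. The function name is hypothetical, and I assume that dist holds squared distances, which matches the paper's inverse-distance weighting with p = 2.

```python
def three_nn_interpolate(xyz1, xyz2, feats2, eps=1e-10):
    """Interpolate features for each point of xyz1 from its 3 nearest points in xyz2.

    Weights are inverse squared distances (clamped at eps), normalized to sum to 1.
    """
    out = []
    channels = len(feats2[0])
    for p in xyz1:
        # (squared distance, index) of the 3 nearest sparse points
        nearest = sorted((sum((a - b) ** 2 for a, b in zip(p, q)), j)
                         for j, q in enumerate(xyz2))[:3]
        w = [1.0 / max(d, eps) for d, _ in nearest]
        total = sum(w)
        out.append([sum(wi * feats2[j][ch] for wi, (_, j) in zip(w, nearest)) / total
                    for ch in range(channels)])
    return out


sparse = [(1, 0, 0), (-1, 0, 0), (0, 1, 0)]
print(three_nn_interpolate([(0, 0, 0)], sparse, [[3.0], [6.0], [9.0]]))  # [[6.0]]
```

A query point equidistant from its three neighbors gets their plain average; a query that coincides with a sparse point essentially copies that point's feature. In the FP module the interpolated features are then concatenated with the skip-connection features points1 and passed through the 1x1-conv MLP.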

That covers the SA and FP code. Next, let's walk through the classification code.

Single-scale grouping (SSG) classification network

Taking the most basic single-scale grouping design as an example, let's follow how the model is built in code.

  • models/pointnet2_cls_ssg.py
def get_model(point_cloud, is_training, bn_decay=None):
    """ Classification PointNet, input is BxNx3, output Bx40 """
    batch_size = point_cloud.get_shape()[0].value
    num_point = point_cloud.get_shape()[1].value
    end_points = {}
    l0_xyz = point_cloud
    l0_points = None
    end_points['l0_xyz'] = l0_xyz
    # Set abstraction layers
    # Note: When using NCHW for layer 2, we see increased GPU memory usage (in TF1.4).
    # So we only use NCHW for layer 1 until this issue can be resolved.
    """
    调用三次SA模块+三次全连接层+两次dropout=0.5,和PointNet一样,除最后一层外,在所有的全连接层后都会进行批量归一化操作+ReLU操作:
    SA(512, 0.2, [64, 64, 128]) → SA(128, 0.4, [128, 128, 256]) → SA([256, 512, 1024]) →
FC(512, 0.5) → FC(256, 0.5) → FC(K)
    """
    l1_xyz, l1_points, l1_indices = pointnet_sa_module(l0_xyz, l0_points, npoint=512, radius=0.2, nsample=32, mlp=[64,64,128], mlp2=None, group_all=False, is_training=is_training, bn_decay=bn_decay, scope='layer1', use_nchw=True)
    l2_xyz, l2_points, l2_indices = pointnet_sa_module(l1_xyz, l1_points, npoint=128, radius=0.4, nsample=64, mlp=[128,128,256], mlp2=None, group_all=False, is_training=is_training, bn_decay=bn_decay, scope='layer2')
    l3_xyz, l3_points, l3_indices = pointnet_sa_module(l2_xyz, l2_points, npoint=None, radius=None, nsample=None, mlp=[256,512,1024], mlp2=None, group_all=True, is_training=is_training, bn_decay=bn_decay, scope='layer3')
    # Fully connected layers
    net = tf.reshape(l3_points, [batch_size, -1])
    net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training, scope='fc1', bn_decay=bn_decay)
    net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, scope='dp1')
    net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training, scope='fc2', bn_decay=bn_decay)
    net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, scope='dp2')
    net = tf_util.fully_connected(net, 40, activation_fn=None, scope='fc3')

    return net, end_points
    
    """
    对于多尺度的分类网络模型(MSG)对应于pointnet2_cls_msg.py,这里的半径和mlp维度都分别转变为向量和数组表示形式,整体的计算过程如下:    
SA(512, [0.1, 0.2, 0.4], [[32, 32, 64], [64, 64, 128], [64, 96, 128]]) →
SA(128, [0.2, 0.4, 0.8], [[64, 64, 128], [128, 128, 256], [128, 128, 256]]) →
SA([256, 512, 1024]) → F C(512, 0.5) → F C(256, 0.5) → F C(K)
    对于多分辨率分类模型(MRG),作者在附录中只是给出了设计的步骤,实现源码没有给出
    """

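The paper compares classification results on the ModelNet40 dataset:
[Figure 6: classification accuracy comparison on ModelNet40]
Compared with PointNet, PointNet++ gives a modest improvement here.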

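Segmentation will be summarized in a separate post; here is the paper's segmentation comparison:
[Figure 7: scene segmentation accuracy comparison]
The results show that for scene segmentation, accuracy ranks as MSG+DP > MRG+DP > SSG > PointNet.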

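I won't go through the rest of the source code in detail; based on my own understanding, the structure and functions of the code are summarized below:
[Figure 8: structure and functions of the PointNet++ source code]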

Conclusion

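This article summarized, at the code level, the design of PointNet++ and its classification implementation. The key points are how PointNet++ uses set abstraction (SA) to learn features of local structure, and how it improves segmentation accuracy through skip connections and multi-scale grouping with random input dropout (MSG+DP). Note that PointNet++ uses PointNet for feature extraction, yet without additional design choices its final results are not as robust as the original network's; the authors also no longer use a T-net to align the point cloud.
There are surely gaps in my understanding, so comments and corrections are very welcome. Next I will continue studying related papers on this basis.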

Reference source code:
1. Official implementation by the paper's authors:
https://github.com/charlesq34/pointnet2
2. PyTorch implementations:
https://github.com/erikwijmans/Pointnet2_PyTorch
https://github.com/yanx27/Pointnet_Pointnet2_pytorch
