PointCNN Code Reading Notes

Reading the README

Contents of the code files

  • X-conv and the PointCNN architecture: pointcnn.py
  • Hyperparameters for the classification task: pointcnn_cls.py
  • Segmentation task: pointcnn_seg.py

Parameters of X-Conv and X-DeConv

  • Taking the code in shapenet_x8_2048_fps.py as an example:
xconv_param_name = ('K', 'D', 'P', 'C', 'links')
xconv_params = [dict(zip(xconv_param_name, xconv_param)) for xconv_param in
                [(8, 1, -1, 32 * x, []),
                 (12, 2, 768, 32 * x, []),
                 (16, 2, 384, 64 * x, []),
                 (16, 6, 128, 128 * x, [])]]

xdconv_param_name = ('K', 'D', 'pts_layer_idx', 'qrs_layer_idx')
xdconv_params = [dict(zip(xdconv_param_name, xdconv_param)) for xdconv_param in
                 [(16, 6, 3, 2),
                  (12, 6, 2, 1),
                  (8, 6, 1, 0),
                  (8, 4, 0, 0)]]

xconv_params

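  • K: number of nearest neighbors
  • D: dilation rate
  • P: number of representative points in the output (-1 means every input point is kept as an output point)
  • C: number of output channels (depth)
  • links: DenseNet-style connections (given as a tuple; an entry X means the layer's input also includes the output of the X-th layer counting back from it)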
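To make the parameter list concrete, here is a small sketch that expands the tuples into one dict per layer and works out how many representative points each layer outputs. The values `x = 8` and `sample_num = 2048` are my assumptions (suggested by the file name, not shown in the snippet), and `resolve_point_counts` is a hypothetical helper, not something from the repository.

```python
# Assumed values: x (channel multiplier) and sample_num are guesses based on
# the file name shapenet_x8_2048_fps.py, not copied from the repository.
x = 8
sample_num = 2048

xconv_param_name = ('K', 'D', 'P', 'C', 'links')
xconv_params = [dict(zip(xconv_param_name, xconv_param)) for xconv_param in
                [(8, 1, -1, 32 * x, []),
                 (12, 2, 768, 32 * x, []),
                 (16, 2, 384, 64 * x, []),
                 (16, 6, 128, 128 * x, [])]]


def resolve_point_counts(params, num_input_points):
    """Hypothetical helper: output points per layer, treating P == -1 as 'keep all'."""
    counts = []
    current = num_input_points
    for layer in params:
        current = current if layer['P'] == -1 else layer['P']
        counts.append(current)
    return counts


point_counts = resolve_point_counts(xconv_params, sample_num)
for idx, (layer, count) in enumerate(zip(xconv_params, point_counts)):
    print(f"layer {idx}: K={layer['K']} D={layer['D']} "
          f"points={count} channels={layer['C']}")
```

Under these assumptions the point count shrinks layer by layer while the channel count grows, which is the usual encoder pattern.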

xdconv_params

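  • K, D: same as above
  • pts_layer_idx: selects which X-conv layer's output is used as the input of this X-dconv layer
  • qrs_layer_idx: selects which X-conv layer's output is merged with the output of this X-dconv layer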
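As a rough sketch of how I read these two indices, the snippet below pairs each X-dconv layer with the number of points at its source layer (`pts_layer_idx`) and at its target layer (`qrs_layer_idx`). The per-layer point counts assume 2048 input points with P == -1 meaning "keep all", and `describe_upsampling` is a hypothetical helper for illustration only.

```python
# Assumed point counts per X-conv layer (2048 input points, P == -1 keeps all).
xconv_point_counts = [2048, 768, 384, 128]

xdconv_param_name = ('K', 'D', 'pts_layer_idx', 'qrs_layer_idx')
xdconv_params = [dict(zip(xdconv_param_name, xdconv_param)) for xdconv_param in
                 [(16, 6, 3, 2),
                  (12, 6, 2, 1),
                  (8, 6, 1, 0),
                  (8, 4, 0, 0)]]


def describe_upsampling(params, point_counts):
    """Hypothetical helper: (source points, target points) for each X-dconv layer."""
    return [(point_counts[p['pts_layer_idx']], point_counts[p['qrs_layer_idx']])
            for p in params]


pairs = describe_upsampling(xdconv_params, xconv_point_counts)
for idx, (src, dst) in enumerate(pairs):
    print(f"xdconv {idx}: {src} points -> {dst} points")
```

Read this way, the decoder walks back up the encoder: each X-dconv layer takes features from a coarser layer and propagates them onto the points of a finer one.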

pointcnn.py

Structure of X-conv

def xconv(pts, fts, qrs, tag, N, K, D, P, C, C_pts_fts, is_training, with_X_transformation, depth_multiplier, sorting_method=None, with_global=False):

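Function parameters

  • pts: points
  • fts: features
  • qrs: queries
  • K: number of nearest neighbors
  • D: dilation rate
  • P: number of points
  • C: number of channels (depth)
  • with_X_transformation: whether to apply the X-transformation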
_, indices_dilated = pf.knn_indices_general(qrs, pts, K * D, True)
indices = indices_dilated[:, :, ::D, :]
def knn_indices_general(queries, points, k, sort=True, unique=True):
    queries_shape = tf.shape(queries)
    batch_size = queries_shape[0]
    point_num = queries_shape[1]

    D = batch_distance_matrix_general(queries, points)
    if unique:
        prepare_for_unique_top_k(D, points)
    distances, point_indices = tf.nn.top_k(-D, k=k, sorted=sort)  # (N, P, K)
    batch_indices = tf.tile(tf.reshape(tf.range(batch_size), (-1, 1, 1, 1)), (1, point_num, k, 1))
    indices = tf.concat([batch_indices, tf.expand_dims(point_indices, axis=3)], axis=3)
    return -distances, indices
def batch_distance_matrix_general(A, B):
    r_A = tf.reduce_sum(A * A, axis=2, keep_dims=True)
    r_B = tf.reduce_sum(B * B, axis=2, keep_dims=True)
    m = tf.matmul(A, tf.transpose(B, perm=(0, 2, 1)))
    D = r_A - 2 * m + tf.transpose(r_B, perm=(0, 2, 1))
    return D
def prepare_for_unique_top_k(D, A):
    indices_duplicated = tf.py_func(find_duplicate_columns, [A], tf.int32)
    D += tf.reduce_max(D)*tf.cast(indices_duplicated, tf.float32)
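  • The tensor shapes change several times here, which I do not fully understand yet: indices_dilated has shape (N, P, K * D, 2), and keeping every D-th neighbor leaves indices with shape (N, P, K, 2)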
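To follow the shapes, it helped me to redo the neighbor search in NumPy. This is only a sketch of the same idea, not the repository code: `np.argsort` stands in for `tf.nn.top_k`, the `unique` handling is left out, and the sizes are made up.

```python
import numpy as np


def batch_distance_matrix(A, B):
    """Squared Euclidean distances between A (N, P, 3) and B (N, M, 3) -> (N, P, M)."""
    r_A = np.sum(A * A, axis=2, keepdims=True)            # (N, P, 1)
    r_B = np.sum(B * B, axis=2, keepdims=True)            # (N, M, 1)
    m = A @ np.transpose(B, (0, 2, 1))                    # (N, P, M)
    return r_A - 2 * m + np.transpose(r_B, (0, 2, 1))     # (N, P, M)


def knn_indices(queries, points, k):
    """(batch, point) index pairs of the k nearest points per query: (N, P, k, 2)."""
    dist = batch_distance_matrix(queries, points)
    point_indices = np.argsort(dist, axis=-1)[:, :, :k]   # (N, P, k), nearest first
    n, p = queries.shape[0], queries.shape[1]
    batch_indices = np.broadcast_to(np.arange(n).reshape(-1, 1, 1), (n, p, k))
    return np.stack([batch_indices, point_indices], axis=-1)


rng = np.random.default_rng(0)
N, M, P, K, D = 2, 50, 10, 4, 3                           # made-up sizes
pts = rng.normal(size=(N, M, 3))
qrs = pts[:, :P, :]                                       # queries are a subset of the points

indices_dilated = knn_indices(qrs, pts, K * D)            # (N, P, K * D, 2)
indices = indices_dilated[:, :, ::D, :]                   # (N, P, K, 2): every D-th neighbor
nn_pts = pts[indices[..., 0], indices[..., 1]]            # (N, P, K, 3), like tf.gather_nd
print(indices_dilated.shape, indices.shape, nn_pts.shape)
```

The last axis of `indices` holds a (batch index, point index) pair, which is why `tf.gather_nd` can pull out the K neighbor coordinates for every query in one call. Dilation simply searches K * D neighbors and keeps every D-th one, so the receptive field grows without increasing K.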
nn_pts = tf.gather_nd(pts, indices, name=tag + 'nn_pts')  # (N, P, K, 3)
nn_pts_center = tf.expand_dims(qrs, axis=2, name=tag + 'nn_pts_center')  # (N, P, 1, 3)
nn_pts_local = tf.subtract(nn_pts, nn_pts_center, name=tag + 'nn_pts_local')  # (N, P, K, 3)
  • Compute the matrix of coordinates relative to the query point (step 1)
  • nn_pts_local should correspond to P' in the paper's formula
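A tiny hand-made example of step 1, written in NumPy rather than TensorFlow. The coordinates are invented for illustration; the point is only that broadcasting over the K axis moves every neighborhood so that its query point sits at the origin.

```python
import numpy as np

# Toy data: 1 batch, 2 query points, 3 neighbors each, all hand-picked.
qrs = np.array([[[0.0, 0.0, 0.0],
                 [1.0, 1.0, 1.0]]])                                  # (N=1, P=2, 3)
nn_pts = np.array([[[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0]],
                    [[1.0, 1.0, 1.0], [2.0, 1.0, 1.0], [1.0, 1.0, 4.0]]]])  # (1, 2, 3, 3)

nn_pts_center = np.expand_dims(qrs, axis=2)        # (N, P, 1, 3)
nn_pts_local = nn_pts - nn_pts_center              # broadcast over K -> (N, P, K, 3)
print(nn_pts_local[0, 1])
```

Because only differences are kept, shifting the whole point cloud by a constant leaves `nn_pts_local` unchanged.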
nn_fts_from_pts_0 = pf.dense(nn_pts_local, C_pts_fts, tag + 'nn_fts_from_pts_0', is_training)
nn_fts_from_pts = pf.dense(nn_fts_from_pts_0, C_pts_fts, tag + 'nn_fts_from_pts', is_training)

if fts is None:
    nn_fts_input = nn_fts_from_pts
else:
    nn_fts_from_prev = tf.gather_nd(fts, indices, name=tag + 'nn_fts_from_prev')
    nn_fts_input = tf.concat([nn_fts_from_pts, nn_fts_from_prev], axis=-1, name=tag + 'nn_fts_input')

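Processing the features

  • Extract features from the point coordinates (step 2)
  • Concatenate the extracted features with the existing features (step 3)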
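Steps 2 and 3 can be sketched in NumPy as follows. The `dense` function here is my own stand-in, not `pf.dense`: it applies a fully connected layer on the last axis with ReLU, and leaves out whatever normalization and activation the real helper uses. Sizes and weights are random placeholders.

```python
import numpy as np


def dense(inputs, weights, bias):
    """Stand-in for pf.dense: per-point linear layer on the last axis, then ReLU."""
    return np.maximum(inputs @ weights + bias, 0.0)


rng = np.random.default_rng(0)
N, P, K, C_pts_fts, C_prev = 2, 5, 4, 8, 16               # made-up sizes
nn_pts_local = rng.normal(size=(N, P, K, 3))              # relative coordinates
nn_fts_from_prev = rng.normal(size=(N, P, K, C_prev))     # gathered features of the neighbors

w0, b0 = rng.normal(size=(3, C_pts_fts)), np.zeros(C_pts_fts)
w1, b1 = rng.normal(size=(C_pts_fts, C_pts_fts)), np.zeros(C_pts_fts)

# Step 2: lift each neighbor's 3 coordinates into C_pts_fts features.
nn_fts_from_pts_0 = dense(nn_pts_local, w0, b0)           # (N, P, K, C_pts_fts)
nn_fts_from_pts = dense(nn_fts_from_pts_0, w1, b1)        # (N, P, K, C_pts_fts)

# Step 3: concatenate with the features carried over from the previous layer.
nn_fts_input = np.concatenate([nn_fts_from_pts, nn_fts_from_prev], axis=-1)
print(nn_fts_input.shape)                                 # (N, P, K, C_pts_fts + C_prev)
```

When `fts` is None (the first layer on raw coordinates), the concatenation is skipped and only the lifted features are used.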
if with_X_transformation:
    ######################## X-transformation #########################
    X_0 = pf.conv2d(nn_pts_local, K * K, tag + 'X_0', is_training, (1, K))
    X_0_KK = tf.reshape(X_0, (N, P, K, K), name=tag + 'X_0_KK')
    X_1 = pf.depthwise_conv2d(X_0_KK, K, tag + 'X_1', is_training, (1, K))
    X_1_KK = tf.reshape(X_1, (N, P, K, K), name=tag + 'X_1_KK')
    X_2 = pf.depthwise_conv2d(X_1_KK, K, tag + 'X_2', is_training, (1, K), activation=None)
    X_2_KK = tf.reshape(X_2, (N, P, K, K), name=tag + 'X_2_KK')
    fts_X = tf.matmul(X_2_KK, nn_fts_input, name=tag + 'fts_X')
    ###################################################################
else:
    fts_X = nn_fts_input

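The X-transformation matrix

  • One convolution layer followed by two depthwise convolution layers (pf.depthwise_conv2d); I do not understand this part yet
  • The kernel size is (1, K); this lifts (N, P, K, 3) to (N, P, K, K) (step 4)
  • Multiply the X-transformation matrix with the feature matrix to obtain features that can be convolved directly with the kernel (step 5)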
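Writing steps 4 and 5 with `np.einsum` made the shapes clearer to me. This is a sketch under assumptions: the weights are random, ReLU stands in for the activation, batch normalization is left out, and the depthwise layer is modeled as "each of the K input channels gets its own K filters of width K", which is how I understand a depthwise convolution with depth multiplier K.

```python
import numpy as np

rng = np.random.default_rng(0)
N, P, K, C = 2, 5, 4, 6                                   # made-up sizes
nn_pts_local = rng.normal(size=(N, P, K, 3))
nn_fts_input = rng.normal(size=(N, P, K, C))

# Step 4a: a (1, K) convolution spans all K neighbors and all 3 coordinates at once
# and produces K * K numbers per query point: (N, P, K, 3) -> (N, P, K * K).
w0 = rng.normal(size=(K, 3, K * K))
X_0 = np.maximum(np.einsum('npkc,kcm->npm', nn_pts_local, w0), 0.0)
X_0_KK = X_0.reshape(N, P, K, K)


def depthwise(x, w):
    """Depthwise (1, K) conv with depth multiplier K: (N, P, K, K) -> (N, P, K, K)."""
    # x: (N, P, width=K, channels=K); w: (width=K, channels=K, multiplier=K)
    return np.einsum('npkc,kcm->npcm', x, w)


# Step 4b: two depthwise layers refine the K x K matrix; the last has no activation.
w1 = rng.normal(size=(K, K, K))
w2 = rng.normal(size=(K, K, K))
X_1_KK = np.maximum(depthwise(X_0_KK, w1), 0.0)
X_2_KK = depthwise(X_1_KK, w2)

# Step 5: one K x K matrix per query point re-weights and re-orders its K neighbors.
fts_X = X_2_KK @ nn_fts_input                             # (N, P, K, K) @ (N, P, K, C)
print(fts_X.shape)

# Intuition: if X were a pure permutation matrix, step 5 would only reorder the rows.
order = [2, 0, 3, 1]
perm = np.eye(K)[order]
reordered = perm @ nn_fts_input[0, 0]                     # rows of the (K, C) block, permuted
```

So the learned X matrix is a soft generalization of a permutation: it is computed from the neighbor coordinates alone and then applied to the neighbor features before the final convolution.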

PointCNN architecture

train_val_cls.py

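  • Data preprocessing (slicing, shuffling)
  • Number of training epochs, batch_size
  • Data augmentation
  • TensorBoard
  • Accuracy and loss
    • acc: overall accuracy
    • macc: mean per-class accuracy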
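The difference between acc and macc is easiest to see on an imbalanced toy example. The two functions below are my own minimal NumPy versions, not the code from train_val_cls.py, and the labels are invented.

```python
import numpy as np


def overall_accuracy(y_true, y_pred):
    """Fraction of all samples that are classified correctly."""
    return float(np.mean(y_true == y_pred))


def mean_class_accuracy(y_true, y_pred):
    """Average of the per-class accuracies, so every class counts equally."""
    classes = np.unique(y_true)
    per_class = [np.mean(y_pred[y_true == c] == c) for c in classes]
    return float(np.mean(per_class))


# Three samples of class 0, one of class 1; the model always predicts class 0.
y_true = np.array([0, 0, 0, 1])
y_pred = np.array([0, 0, 0, 0])

acc = overall_accuracy(y_true, y_pred)
macc = mean_class_accuracy(y_true, y_pred)
print(acc, macc)
```

Here acc is 0.75 but macc is only 0.5, because the class that is never predicted contributes an accuracy of 0 to the average.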
