PointNet++ Code Walkthrough (3): The query_ball_point Function

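The query_ball_point function corresponds to the grouping layer, which uses ball query to build N' local regions. According to the paper, two hyperparameters are involved: the number of points per region, K, and the ball radius. The radius is the dominant one: points are searched for inside a ball of that radius, and K only acts as an upper bound on how many are kept. Both values are chosen by hand.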

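query_ball_point finds the points that lie inside a spherical neighborhood. Among its inputs, radius is the radius of the neighborhood, nsample is the number of points to sample in each neighborhood, new_xyz holds the centers of the S neighborhoods (produced earlier by farthest point sampling), and xyz is the full point cloud. The output holds, for every sample in the batch and every neighborhood, the indices of its nsample points, with shape [B, S, nsample].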
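To make the intended behavior concrete before reading the vectorized code, here is a minimal loop-based sketch in plain Python for a single point cloud (no batch dimension). The name ball_query_reference and the toy coordinates are made up for illustration and are not part of the PointNet++ code. It assumes every ball contains at least one point, which holds when the centers are themselves taken from the point cloud.

```python
def ball_query_reference(radius, nsample, points, centers):
    """Loop-based reference for one point cloud (hypothetical helper).

    points:  list of (x, y, z) tuples, length N
    centers: list of (x, y, z) tuples, length S
    Returns S index lists, each of length nsample.
    """
    result = []
    for cx, cy, cz in centers:
        # Indices of all points inside the ball, in ascending index order.
        inside = [
            i for i, (x, y, z) in enumerate(points)
            if (x - cx) ** 2 + (y - cy) ** 2 + (z - cz) ** 2 <= radius ** 2
        ]
        # Keep at most nsample of them (lowest indices first, not nearest first).
        picked = inside[:nsample]
        # Pad with the first index when the ball holds fewer than nsample points.
        picked += [picked[0]] * (nsample - len(picked))
        result.append(picked)
    return result


points = [(0.0, 0.0, 0.0), (0.1, 0.0, 0.0), (0.2, 0.0, 0.0),
          (1.0, 0.0, 0.0), (1.1, 0.0, 0.0)]
centers = [points[0], points[3]]
groups = ball_query_reference(radius=0.25, nsample=3, points=points, centers=centers)
print(groups)  # [[0, 1, 2], [3, 4, 3]]
```

Note that the second ball contains only two points, so its last slot repeats index 3. Also note that the points kept are the ones with the smallest indices inside the ball, not the ones closest to the center.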

def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3], where S is the number of center points
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device)\
        .view(1, 1, N).repeat([B, S, 1])
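    # sqrdists: [B, S, N], the squared Euclidean distance between every center and every point
    sqrdists = square_distance(new_xyz, xyz)
    # Wherever the squared distance exceeds radius^2, overwrite the index with N; keep the others
    group_idx[sqrdists > radius **2] = N
    # Sort ascending. Out-of-ball entries are N, the largest value, so they end up last,
    # and the first nsample entries are taken from the points inside the ball
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    # The first nsample entries may still contain N (the ball holds fewer than nsample points).
    # Those entries are discarded and replaced by the first point of the group
    # group_first: [B, S, nsample], the first index of each group repeated nsample times, ready for the replacement
    # view is needed because group_idx[:, :, 0] is a 2-D tensor and must be reshaped to 3-D
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    # Mark the entries of group_idx that equal N; mask is a 3-D tensor of shape [B, S, nsample]
    mask = group_idx == N
    # Replace those entries with the first index of their group
    group_idx[mask] = group_first[mask]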
    return group_idx
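The function can be tried end to end on a toy point cloud. The sketch below is only an illustration under assumptions: square_distance is not shown in this article, so a simple broadcasting stand-in is defined here, and it may differ from the helper in the actual repository. The coordinates are made up, and the final gather via advanced indexing is just one way to turn the indices back into points.

```python
import torch

def square_distance(src, dst):
    # Stand-in (assumption): pairwise squared Euclidean distance,
    # src [B, S, 3], dst [B, N, 3] -> [B, S, N].
    diff = src.unsqueeze(2) - dst.unsqueeze(1)  # [B, S, N, 3]
    return (diff ** 2).sum(dim=-1)

def query_ball_point(radius, nsample, xyz, new_xyz):
    # Same logic as the function above, comments omitted for brevity.
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    group_idx[sqrdists > radius ** 2] = N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx

# Toy data: two clusters on the x axis, B=1, N=5.
xyz = torch.tensor([[[0.0, 0.0, 0.0],
                     [0.1, 0.0, 0.0],
                     [0.2, 0.0, 0.0],
                     [1.0, 0.0, 0.0],
                     [1.1, 0.0, 0.0]]])
new_xyz = xyz[:, [0, 3], :]                      # centers: points 0 and 3, S=2
idx = query_ball_point(0.25, 3, xyz, new_xyz)
print(idx)                                       # tensor([[[0, 1, 2], [3, 4, 3]]])

# Gather the grouped coordinates: [B, S, nsample, 3].
batch = torch.arange(xyz.shape[0]).view(-1, 1, 1)
grouped = xyz[batch, idx]
print(grouped.shape)                             # torch.Size([1, 2, 3, 3])
```

The second center has only two neighbors within the radius, so the third slot of its group is filled with the group's first index, 3.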

1. Understanding group_idx:

    group_idx = torch.arange(N, dtype=torch.long).to(device)\
        .view(1, 1, N).repeat([B, S, 1])

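N is the total number of points in one sample. torch.arange(N) produces tensor([0, 1, ..., N-1]); .to(device) moves that tensor to the device xyz lives on; .view(1, 1, N) reshapes it to tensor([[[0, 1, ..., N-1]]]), that is, a single row with N columns; and .repeat([B, S, 1]) copies it B times along dimension 0 (where there was originally 1) and S times along dimension 1. You can picture B samples in the batch, each with S rows and N columns, so group_idx ends up with shape [B, S, N]. In code: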

import torch
N=5
B=3
S=2
group_idx0 = torch.arange(N, dtype=torch.long)
group_idx1=group_idx0.view(1, 1, N)
group_idx2=group_idx1.repeat([B, S, 1])
print("g0:",group_idx0)
print("g1:",group_idx1)
print("g2:",group_idx2)
# Output:
g0: tensor([0, 1, 2, 3, 4])
g1: tensor([[[0, 1, 2, 3, 4]]])
g2: tensor([[[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]],

        [[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]],

        [[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]]])

2. Understanding group_idx.sort:

torch.sort(input, dim=-1, descending=False, out=None): dim=-1 refers to the last dimension, which in the source code above is dim=2.


a=torch.randn(2,3,4)
print("a",a)
print("dim=0",torch.sort(a,0))
print("dim=1",torch.sort(a,1))
print("dim=2",torch.sort(a,2))
print("dim=-1",torch.sort(a,-1))

# Output (torch.sort returns both the sorted values and their indices; only the values are shown)
a tensor([[[ 0.1644, -0.9524, -0.0522, -1.7683],
         [-0.0426, -1.3940, -0.9358, -2.5367],
         [ 0.6171,  0.2587,  1.6798,  0.3828]],

        [[ 1.0571, -0.2126, -0.1489,  0.5902],
         [ 0.1673, -0.5937, -0.3240,  1.1439],
         [-0.4273, -0.4449, -0.8735, -0.6969]]])
dim=0 (tensor([[[ 0.1644, -0.9524, -0.1489, -1.7683],
         [-0.0426, -1.3940, -0.9358, -2.5367],
         [-0.4273, -0.4449, -0.8735, -0.6969]],

        [[ 1.0571, -0.2126, -0.0522,  0.5902],
         [ 0.1673, -0.5937, -0.3240,  1.1439],
         [ 0.6171,  0.2587,  1.6798,  0.3828]]]))

dim=1 (tensor([[[-0.0426, -1.3940, -0.9358, -2.5367],
         [ 0.1644, -0.9524, -0.0522, -1.7683],
         [ 0.6171,  0.2587,  1.6798,  0.3828]],

        [[-0.4273, -0.5937, -0.8735, -0.6969],
         [ 0.1673, -0.4449, -0.3240,  0.5902],
         [ 1.0571, -0.2126, -0.1489,  1.1439]]])

dim=2 (tensor([[[-1.7683, -0.9524, -0.0522,  0.1644],
         [-2.5367, -1.3940, -0.9358, -0.0426],
         [ 0.2587,  0.3828,  0.6171,  1.6798]],

        [[-0.2126, -0.1489,  0.5902,  1.0571],
         [-0.5937, -0.3240,  0.1673,  1.1439],
         [-0.8735, -0.6969, -0.4449, -0.4273]]])

dim=-1 (tensor([[[-1.7683, -0.9524, -0.0522,  0.1644],
         [-2.5367, -1.3940, -0.9358, -0.0426],
         [ 0.2587,  0.3828,  0.6171,  1.6798]],

        [[-0.2126, -0.1489,  0.5902,  1.0571],
         [-0.5937, -0.3240,  0.1673,  1.1439],
         [-0.8735, -0.6969, -0.4449, -0.4273]]])

After group_idx.sort(dim=-1)[0][:, :, :nsample], where [0] selects the sorted values rather than the sort indices, group_idx has shape [B, S, nsample].
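The random tensors above show how dim works, but not why sorting helps in query_ball_point. The following small sketch, with a hand-written mask standing in for sqrdists > radius ** 2 (an assumption made purely for illustration), shows how the sentinel value N is pushed to the end so that slicing keeps only in-ball indices.

```python
import torch

N = 6
# One batch, one center: indices 0..5.
group_idx = torch.arange(N).view(1, 1, N).clone()
# Hypothetical result of `sqrdists > radius ** 2`: points 0, 2 and 5 are outside the ball.
outside = torch.tensor([[[True, False, True, False, False, True]]])

group_idx[outside] = N                      # [6, 1, 6, 3, 4, 6]
sorted_idx = group_idx.sort(dim=-1)[0]      # [1, 3, 4, 6, 6, 6]

print(sorted_idx[:, :, :3])   # tensor([[[1, 3, 4]]])     nsample=3: only valid indices
print(sorted_idx[:, :, :4])   # tensor([[[1, 3, 4, 6]]])  nsample=4: one sentinel leaks in
```

With nsample=4 the slice still contains the sentinel 6, which is not a valid point index. That is exactly the case the mask-based replacement in the next section handles.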

3. Understanding group_idx[mask] = group_first[mask]:

import torch
N=5
B=3
S=2
group_idx0 = torch.arange(N, dtype=torch.long)
group_idx1=group_idx0.view(1, 1, N)
group_idx2=group_idx1.repeat([B, S, 1])

mask= group_idx2 == 3
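print("mask:", mask)
print("group_idx2[mask]:", group_idx2[mask])
group_idx2[mask] = 10
print("group_idx2:", group_idx2)

# Output (from an older PyTorch release; since PyTorch 1.2 the mask is a torch.bool tensor and prints True/False):
mask: tensor([[[0, 0, 0, 1, 0],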
         [0, 0, 0, 1, 0]],

        [[0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0]],

        [[0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0]]], dtype=torch.uint8)

group_idx2[mask]: tensor([3, 3, 3, 3, 3, 3])

group_idx2: tensor([[[ 0,  1,  2, 10,  4],
         [ 0,  1,  2, 10,  4]],

        [[ 0,  1,  2, 10,  4],
         [ 0,  1,  2, 10,  4]],

        [[ 0,  1,  2, 10,  4],
         [ 0,  1,  2, 10,  4]]])

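From this we can conclude: mask must be a boolean tensor (a ByteTensor in older PyTorch releases) with the same shape as the tensor being indexed, here group_idx2. The assignment writes to exactly those positions where mask is true. The right-hand side can be a scalar, as with the 10 above, or a tensor with one element per selected position, as with group_first[mask] in query_ball_point.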
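The toy example above assigns a scalar, whereas query_ball_point assigns a tensor on the right-hand side. Here is a short sketch of that case on hand-written indices (the values are made up for illustration), together with a torch.where formulation that should give the same result without in-place assignment.

```python
import torch

N, nsample = 6, 4
# Sorted, truncated indices for B=1, S=2; the value N marks an empty slot.
group_idx = torch.tensor([[[1, 3, 4, 6],     # one empty slot
                           [2, 6, 6, 6]]])   # three empty slots
B, S, _ = group_idx.shape

# First index of each group, repeated across the nsample slots.
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
mask = group_idx == N

# In-place masked assignment, as in query_ball_point (on a copy, to keep the input intact).
filled = group_idx.clone()
filled[mask] = group_first[mask]
print(filled)   # tensor([[[1, 3, 4, 1], [2, 2, 2, 2]]])

# Equivalent out-of-place form: pick group_first where mask is true, group_idx elsewhere.
alt = torch.where(mask, group_first, group_idx)
print(torch.equal(filled, alt))   # True
```

Both group_idx[mask] and group_first[mask] are 1-D tensors with one element per true entry of mask, in the same order, which is why the assignment lines up slot by slot.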

 

 
