The query_ball_point function corresponds to the Grouping layer. This layer uses the ball query method to generate N' local regions. Following the paper, there are two hyperparameters: the number of points per region, K, and the radius of the ball. The radius dominates: points are searched within a ball of that radius, with K as an upper bound on how many are kept. Both the radius and K are set manually.
query_ball_point finds the points that fall inside a spherical neighborhood. Inputs: radius is the radius of the spherical neighborhood, nsample is the number of points to sample from each neighborhood, new_xyz holds the S neighborhood centers (obtained earlier by farthest point sampling), and xyz is the full point cloud. Output: the indices of the nsample sampled points in each spherical neighborhood of each sample, with shape [B, S, nsample].
def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3], S denotes the number of center points
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device)\
        .view(1, 1, N).repeat([B, S, 1])
    # sqrdists: [B, S, N], squared Euclidean distance between each center and all points
    sqrdists = square_distance(new_xyz, xyz)
    # Set group_idx to N wherever the squared distance exceeds radius^2; the rest keep their index
    group_idx[sqrdists > radius ** 2] = N
    # Sort ascending: the entries set to N are the largest, so they move to the end,
    # and the first nsample entries are taken from the points that remain inside the ball
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    # The first nsample entries may still contain N (the ball holds fewer than nsample points);
    # such entries must be discarded and are simply replaced by the first point of the group.
    # group_first: [B, S, nsample], the first point of each group broadcast to [B, S, nsample]
    # so the replacement below is convenient.
    # view is needed because group_idx[:, :, 0] is a 2D tensor; view makes it 3D again.
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    # Entries equal to N: mask is a 0/1 tensor of shape [B, S, nsample]
    mask = group_idx == N
    # Replace those entries with the value of the first point
    group_idx[mask] = group_first[mask]
    return group_idx
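The function depends on square_distance, which is defined elsewhere in the same codebase. For completeness, here is a minimal sketch of a typical implementation (using the expansion ||a − b||² = ||a||² − 2a·b + ||b||²), followed by a quick shape check of query_ball_point on random data:
import torch

def square_distance(src, dst):
    # Pairwise squared Euclidean distance between src [B, N, C] and dst [B, M, C] -> [B, N, M]
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist

# Quick shape check: B=2 batches, N=1024 points, S=128 centers
xyz = torch.rand(2, 1024, 3)
new_xyz = xyz[:, :128, :]   # stand-in for the farthest point sampling result
group_idx = query_ball_point(radius=0.2, nsample=32, xyz=xyz, new_xyz=new_xyz)
print(group_idx.shape)      # torch.Size([2, 128, 32])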
1. Understanding group_idx:
group_idx = torch.arange(N, dtype=torch.long).to(device)\
.view(1, 1, N).repeat([B, S, 1])
N is the total number of points in one sample. torch.arange(N) generates tensor([0, 1, ..., N-1]), and .to(device) copies that tensor to the device where xyz lives. .view(1, 1, N) then reshapes it into tensor([[[0, 1, ..., N-1]]]), i.e. a single row with N columns. Finally, .repeat([B, S, 1]) tiles it B times along dimension 0 and S times along dimension 1 (there was originally one of each): think of B batches, each with S rows of N columns. The resulting shape of group_idx is therefore [B, S, N]. To show this in code:
import torch
N=5
B=3
S=2
group_idx0 = torch.arange(N, dtype=torch.long)
group_idx1=group_idx0.view(1, 1, N)
group_idx2=group_idx1.repeat([B, S, 1])
print("g0:",group_idx0)
print("g1:",group_idx1)
print("g2:",group_idx2)
# Result:
g0: tensor([0, 1, 2, 3, 4])
g1: tensor([[[0, 1, 2, 3, 4]]])
g2: tensor([[[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4]],
[[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4]],
[[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4]]])
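A side note on why .repeat rather than .expand: .expand would produce the same [B, S, N] shape for free as a broadcast view over shared storage, but query_ball_point later writes into group_idx in place (group_idx[sqrdists > radius ** 2] = N), so every (batch, center) row needs its own memory, which is what .repeat allocates. A small sketch of the difference:
import torch
N, B, S = 5, 3, 2
base = torch.arange(N, dtype=torch.long).view(1, 1, N)
tiled = base.expand(B, S, N)      # broadcast view: no copy, shares storage with base
copied = base.repeat([B, S, 1])   # real copies: each row owns its memory
base[0, 0, 0] = 99
print(tiled[2, 1])                # tensor([99, 1, 2, 3, 4]) -- every expanded row aliases base
print(copied[2, 1])               # tensor([0, 1, 2, 3, 4])  -- repeat is unaffected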
2. Understanding group_idx.sort:
torch.sort(input, dim=-1, descending=False, out=None): dim=-1 means the last dimension, which in the source code is dim=2. Note that sort returns a (values, indices) tuple, which is why the source indexes it with [0] to keep only the sorted values; the printed results below show only the values tensor.
a=torch.randn(2,3,4)
print("a",a)
print("dim=0",torch.sort(a,0))
print("dim=1",torch.sort(a,1))
print("dim=2",torch.sort(a,2))
print("dim=-1",torch.sort(a,-1))
# Result:
a tensor([[[ 0.1644, -0.9524, -0.0522, -1.7683],
[-0.0426, -1.3940, -0.9358, -2.5367],
[ 0.6171, 0.2587, 1.6798, 0.3828]],
[[ 1.0571, -0.2126, -0.1489, 0.5902],
[ 0.1673, -0.5937, -0.3240, 1.1439],
[-0.4273, -0.4449, -0.8735, -0.6969]]])
dim=0 tensor([[[ 0.1644, -0.9524, -0.1489, -1.7683],
         [-0.0426, -1.3940, -0.9358, -2.5367],
         [-0.4273, -0.4449, -0.8735, -0.6969]],
        [[ 1.0571, -0.2126, -0.0522,  0.5902],
         [ 0.1673, -0.5937, -0.3240,  1.1439],
         [ 0.6171,  0.2587,  1.6798,  0.3828]]])
dim=1 tensor([[[-0.0426, -1.3940, -0.9358, -2.5367],
         [ 0.1644, -0.9524, -0.0522, -1.7683],
         [ 0.6171,  0.2587,  1.6798,  0.3828]],
        [[-0.4273, -0.5937, -0.8735, -0.6969],
         [ 0.1673, -0.4449, -0.3240,  0.5902],
         [ 1.0571, -0.2126, -0.1489,  1.1439]]])
dim=2 tensor([[[-1.7683, -0.9524, -0.0522,  0.1644],
         [-2.5367, -1.3940, -0.9358, -0.0426],
         [ 0.2587,  0.3828,  0.6171,  1.6798]],
        [[-0.2126, -0.1489,  0.5902,  1.0571],
         [-0.5937, -0.3240,  0.1673,  1.1439],
         [-0.8735, -0.6969, -0.4449, -0.4273]]])
dim=-1 tensor([[[-1.7683, -0.9524, -0.0522,  0.1644],
         [-2.5367, -1.3940, -0.9358, -0.0426],
         [ 0.2587,  0.3828,  0.6171,  1.6798]],
        [[-0.2126, -0.1489,  0.5902,  1.0571],
         [-0.5937, -0.3240,  0.1673,  1.1439],
         [-0.8735, -0.6969, -0.4449, -0.4273]]])
After group_idx.sort(dim=-1)[0][:, :, :nsample], group_idx has shape [B, S, nsample].
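A small sketch tying this back to the function: out-of-ball entries are pushed to the value N, sorting moves them to the end, and [0] picks the values element of the (values, indices) tuple:
import torch
N, nsample = 5, 3
group_idx = torch.tensor([[[0, 1, 2, 3, 4]]])                # [B=1, S=1, N=5]
inside = torch.tensor([[[True, False, True, False, True]]])  # hypothetical in-ball mask
group_idx[~inside] = N                                       # out-of-ball points become N
values, indices = group_idx.sort(dim=-1)                     # sort returns (values, indices)
print(values)                                                # tensor([[[0, 2, 4, 5, 5]]])
print(values[:, :, :nsample])                                # tensor([[[0, 2, 4]]]),
                                                             # i.e. sort(dim=-1)[0][:, :, :nsample]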
3. Understanding group_idx[mask] = group_first[mask]:
import torch
N=5
B=3
S=2
group_idx0 = torch.arange(N, dtype=torch.long)
group_idx1=group_idx0.view(1, 1, N)
group_idx2=group_idx1.repeat([B, S, 1])
mask = group_idx2 == 3
print("mask:", mask)
print("group_idx2[mask]:", group_idx2[mask])
group_idx2[mask] = 10
print("group_idx2:", group_idx2)
# Result:
mask: tensor([[[0, 0, 0, 1, 0],
[0, 0, 0, 1, 0]],
[[0, 0, 0, 1, 0],
[0, 0, 0, 1, 0]],
[[0, 0, 0, 1, 0],
[0, 0, 0, 1, 0]]], dtype=torch.uint8)
group_idx2[mask]: tensor([3, 3, 3, 3, 3, 3])
group_idx2: tensor([[[ 0, 1, 2, 10, 4],
[ 0, 1, 2, 10, 4]],
[[ 0, 1, 2, 10, 4],
[ 0, 1, 2, 10, 4]],
[[ 0, 1, 2, 10, 4],
[ 0, 1, 2, 10, 4]]])
From this we can conclude: mask must be a ByteTensor (a BoolTensor in recent PyTorch versions, where comparison operators return bool), with the same shape as the tensor being indexed and with elements that are only 0 or 1. Wherever mask is 1, the element at the corresponding index is replaced by value; value can be a scalar, as in group_idx2[mask] = 10 above, or a tensor whose elements line up with the masked positions, as in group_idx[mask] = group_first[mask] in the source.
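A small sketch of the tensor-valued variant used in the source, where the replacement values come from another tensor indexed by the same mask:
import torch
group_idx = torch.tensor([[[0, 2, 5],
                           [1, 5, 5]]])    # 5 plays the role of N ("not enough points in the ball")
group_first = group_idx[:, :, 0].view(1, 2, 1).repeat([1, 1, 3])
mask = group_idx == 5                      # a BoolTensor in recent PyTorch versions
group_idx[mask] = group_first[mask]        # masked positions take the first point's index
print(group_idx)
# tensor([[[0, 2, 0],
#          [1, 1, 1]]])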