tensor.gather(dim, index)
和torch.gather(input, dim, index)
两者没有本质差别。
这里挑tensor.gather(dim, index)
来讲。
官网解释
中文解释:
输入dim和index,index和tensor的维度数目一样,比如都是3个维度的数组(A,B,C)这种。
dim指明你的索引是在第几维,index要求必须是和输入tensor相同维度的张量,返回的是这些索引对应的值,返回张量的size与index相同。
对于你的index中的元素值,它有自己的索引,此时要指定是某一个维度dim,将这个元素自己的索引中对应dim的维度改变为该元素值,其他维度上的值不变,然后根据这个新的索引在tensor中取索引值。
官网的3D情况说明(看完下面的举例你会觉得很清晰):
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
a = torch.randn((3,4))
print('a\n',a)
i1 = torch.tensor([[1,2],
[1,1],
[0,2],
[1,1]])
i2 = torch.tensor([[1,2],
[1,1],
[0,2]])
print('ans1\n',a.gather(0,i1))
print('ans2\n',a.gather(1,i2))
output:
a
tensor([[-0.2655, 0.3619, -0.7515, -0.8025],
[ 0.5486, 0.0390, -0.3317, -0.1171],
[ 2.1218, -1.6343, -1.0830, 1.3824]])
ans1
tensor([[ 0.5486, -1.6343],
[ 0.5486, 0.0390],
[-0.2655, -1.6343],
[ 0.5486, 0.0390]])
ans2
tensor([[ 0.3619, -0.7515],
[ 0.0390, 0.0390],
[ 2.1218, -1.0830]])
以索引i1
为例,他的首行元素是[1,2]
,
1在这个数组中的索引是[0][0]
,
2在这个数组中的索引是[0][1]
,
选择的dim=0
,所以在元素索引的0维上改变为元素值:
1的索引[0][0]->[1][0]
,然后在变量a
中索引到了值a[1][0]=0.5486
2的索引[0][1]->[2][1]
,然后在变量a
中索引到了值a[2][1]=-1.6343
所以这个ans1
我们可以写为:
tensor([[a[1][0], a[2][1]],
[a[1][0], a[1][1]],
[a[0][0], a[2][1]],
[a[1][0], a[1][1]]])
可以对比一下刚刚的数据:
i1 = torch.tensor([[1,2],
[1,1],
[0,2],
[1,1]])
a = tensor([[-0.2655, 0.3619, -0.7515, -0.8025],
[ 0.5486, 0.0390, -0.3317, -0.1171],
[ 2.1218, -1.6343, -1.0830, 1.3824]])
ans1 = tensor( [[ 0.5486, -1.6343],
[ 0.5486, 0.0390],
[-0.2655, -1.6343],
[ 0.5486, 0.0390]])
若索引i2
为例,他的首行元素是[1,2]
,
1在这个数组中的索引是[0][0]
,
2在这个数组中的索引是[0][1]
,
选择的dim=1
,所以在元素索引的1维上改变为元素值:
1的索引[0][0]->[0][1]
,然后在变量a
中索引到了值a[0][1]=0.3619
2的索引[0][1]->[0][2]
,然后在变量a
中索引到了值a[0][2]=-0.7515
所以这个ans2
我们可以写为:
tensor([[a[0][1], a[0][2]],
[a[1][1], a[1][1]],
[a[2][0], a[2][2]]])
可以对比一下刚刚的数据:
a = tensor([[-0.2655, 0.3619, -0.7515, -0.8025],
[ 0.5486, 0.0390, -0.3317, -0.1171],
[ 2.1218, -1.6343, -1.0830, 1.3824]])
i2 = torch.tensor([[1,2],
[1,1],
[0,2]])
ans2 = tensor([[ 0.3619, -0.7515],
[ 0.0390, 0.0390],
[ 2.1218, -1.0830]])
a = torch.randn((3,5,3))
print('a\n',a)
i1 = torch.tensor([[[1,2],
[1,1],
[0,2],
[1,1]]])
print('ans\n',a.gather(2,i1))
output:
a
tensor([[[-0.0114, -1.0284, -0.5340],
[ 0.5844, 1.4223, 0.4038],
[ 0.0575, 1.0408, 0.4988],
[ 0.3994, -0.0080, 0.5033],
[-1.3644, 0.4155, -0.6559]],
[[ 1.7330, 0.2755, -0.9000],
[-0.2527, 0.5685, 1.6011],
[ 2.0909, -0.4134, -1.2176],
[ 0.8040, 1.1630, 0.3964],
[-0.6463, 0.2030, -0.8429]],
[[ 1.0368, -0.7876, 1.3825],
[ 1.5968, -1.1934, 0.9004],
[-0.6002, -0.8837, -2.1700],
[-0.9114, -0.1575, 1.3854],
[-0.0854, 0.5144, 0.0932]]])
ans
tensor([[[-1.0284, -0.5340],
[ 1.4223, 1.4223],
[ 0.0575, 0.4988],
[-0.0080, -0.0080]]])
相同的道理,他的第三行元素是[0,2]
,
0在这个数组中的索引是[0][2][0]
,
2在这个数组中的索引是[0][2][1]
,
选择的dim=2
,所以在元素索引的2维上改变为元素值:
0的索引[0][2][0]->[0][2][0]
,然后在变量a
中索引到了值a[0][2][0]=0.0575
2的索引[0][2][1]->[0][2][2]
,然后在变量a
中索引到了值a[0][2][2]=0.4988
其他的话就一样了,这里不再赘述。
这时再看官网的说明,很清晰了吧(水字数)!
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
假设我现在有一个神经网络的输出是点云数据。
例如这个输出的维度是output.shape=[128,1000,3] ,
可以理解为batch_size为128,然后每个batch有1000个点,每个点的坐标是 xyz 3个特征。
然后我现在有一个索引 index.shape=[128,3000](这3000个索引肯定是有重复的,我要做的就是根据每个batch中这3000个索引把output中的值放入到这个结果result中。
遍历所以batch,将output每个batch中对应index的值赋值给result的每个batch。
result = torch.empty((128,3000,3))
for i in range(128):
result[i]=output[i][index[i]]
索引取得是xyz,可以先扩展一维,然后复制3份,取dim为1,表示现在result[i][j][k] = output[i][index[i][j][k]][k]
即对于每个batch,都把每行换成了此行所索引的output中的对应行值。
index=index.unsqueeze(-1).repeat(-1,-1,3) #(128,3000,3)
result = output.gather(dim=1,index=index)