讲清楚tensor.gather(dim,index)和torch.gather(input, dim, index),举例,应用

目录

  • 前言
  • 正题
  • 举例,维度为2
  • 举例,维度为3
  • 应用
    • 常规思路
    • 用gather

前言

tensor.gather(dim, index)torch.gather(input, dim, index)两者没有本质差别。
这里挑tensor.gather(dim, index)来讲。

正题

官网解释

中文解释:
输入dim和index,index和tensor的维度数目一样,比如都是3个维度的数组(A,B,C)这种。
dim指明你的索引是在第几维,index要求必须是和输入tensor相同维度的张量,返回的是这些索引对应的值,返回张量的size与index相同。
对于你的index中的元素值,它有自己的索引,此时要指定是某一个维度dim,将这个元素自己的索引中对应dim的维度改变为该元素值,其他维度上的值不变,然后根据这个新的索引在tensor中取索引值。

官网的3D情况说明(看完下面的举例你会觉得很清晰):

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

举例,维度为2

a = torch.randn((3,4))
print('a\n',a)
i1 = torch.tensor([[1,2],
                  [1,1],
                  [0,2],
                  [1,1]])
i2 = torch.tensor([[1,2],
                  [1,1],
                  [0,2]])
print('ans1\n',a.gather(0,i1))
print('ans2\n',a.gather(1,i2)) 
output:
a
 tensor([[-0.2655,  0.3619, -0.7515, -0.8025],
        [ 0.5486,  0.0390, -0.3317, -0.1171],
        [ 2.1218, -1.6343, -1.0830,  1.3824]])
ans1
 tensor([[ 0.5486, -1.6343],
        [ 0.5486,  0.0390],
        [-0.2655, -1.6343],
        [ 0.5486,  0.0390]])
ans2
 tensor([[ 0.3619, -0.7515],
        [ 0.0390,  0.0390],
        [ 2.1218, -1.0830]])

以索引i1为例,他的首行元素是[1,2],
1在这个数组中的索引是[0][0],
2在这个数组中的索引是[0][1],
选择的dim=0,所以在元素索引的0维上改变为元素值:
1的索引[0][0]->[1][0],然后在变量a中索引到了值a[1][0]=0.5486
2的索引[0][1]->[2][1],然后在变量a中索引到了值a[2][1]=-1.6343
所以这个ans1我们可以写为:

tensor([[a[1][0], a[2][1]],
        [a[1][0], a[1][1]],
        [a[0][0], a[2][1]],
        [a[1][0], a[1][1]]])

可以对比一下刚刚的数据:
i1 = torch.tensor([[1,2],   
                   [1,1],
                   [0,2],
                   [1,1]])
  
a = tensor([[-0.2655,  0.3619, -0.7515, -0.8025],
        	[ 0.5486,  0.0390, -0.3317, -0.1171],
        	[ 2.1218, -1.6343, -1.0830,  1.3824]])
 
ans1 = tensor( [[ 0.5486, -1.6343],
				[ 0.5486,  0.0390],
				[-0.2655, -1.6343],
				[ 0.5486,  0.0390]])

若索引i2为例,他的首行元素是[1,2],
1在这个数组中的索引是[0][0],
2在这个数组中的索引是[0][1],
选择的dim=1,所以在元素索引的1维上改变为元素值:
1的索引[0][0]->[0][1],然后在变量a中索引到了值a[0][1]=0.3619
2的索引[0][1]->[0][2],然后在变量a中索引到了值a[0][2]=-0.7515
所以这个ans2我们可以写为:

tensor([[a[0][1], a[0][2]],
        [a[1][1], a[1][1]],
        [a[2][0], a[2][2]]])

可以对比一下刚刚的数据:
a = tensor([[-0.2655,  0.3619, -0.7515, -0.8025],
        	[ 0.5486,  0.0390, -0.3317, -0.1171],
        	[ 2.1218, -1.6343, -1.0830,  1.3824]])
        	
i2 = torch.tensor([[1,2],
                   [1,1],
                   [0,2]])

ans2 = tensor([[ 0.3619, -0.7515],
        	   [ 0.0390,  0.0390],
               [ 2.1218, -1.0830]])

举例,维度为3

a = torch.randn((3,5,3))
print('a\n',a)
i1 = torch.tensor([[[1,2],
                    [1,1],
                    [0,2],
                    [1,1]]])
print('ans\n',a.gather(2,i1))
output:
a
tensor([[[-0.0114, -1.0284, -0.5340],
        [ 0.5844,  1.4223,  0.4038],
        [ 0.0575,  1.0408,  0.4988],
        [ 0.3994, -0.0080,  0.5033],
        [-1.3644,  0.4155, -0.6559]],

       [[ 1.7330,  0.2755, -0.9000],
        [-0.2527,  0.5685,  1.6011],
        [ 2.0909, -0.4134, -1.2176],
        [ 0.8040,  1.1630,  0.3964],
        [-0.6463,  0.2030, -0.8429]],

       [[ 1.0368, -0.7876,  1.3825],
        [ 1.5968, -1.1934,  0.9004],
        [-0.6002, -0.8837, -2.1700],
        [-0.9114, -0.1575,  1.3854],
        [-0.0854,  0.5144,  0.0932]]])
ans
tensor([[[-1.0284, -0.5340],
         [ 1.4223,  1.4223],
         [ 0.0575,  0.4988],
         [-0.0080, -0.0080]]])

相同的道理,他的第三行元素是[0,2],
0在这个数组中的索引是[0][2][0],
2在这个数组中的索引是[0][2][1],
选择的dim=2,所以在元素索引的2维上改变为元素值:
0的索引[0][2][0]->[0][2][0],然后在变量a中索引到了值a[0][2][0]=0.0575
2的索引[0][2][1]->[0][2][2],然后在变量a中索引到了值a[0][2][2]=0.4988
其他的话就一样了,这里不再赘述。

这时再看官网的说明,很清晰了吧(水字数)!

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

应用

假设我现在有一个神经网络的输出是点云数据。
例如这个输出的维度是output.shape=[128,1000,3] ,
可以理解为batch_size为128,然后每个batch有1000个点,每个点的坐标是 xyz 3个特征。
然后我现在有一个索引 index.shape=[128,3000](这3000个索引肯定是有重复的,我要做的就是根据每个batch中这3000个索引把output中的值放入到这个结果result中。

常规思路

遍历所以batch,将output每个batch中对应index的值赋值给result的每个batch。

result = torch.empty((128,3000,3))
for i in range(128):
    result[i]=output[i][index[i]]

用gather

索引取得是xyz,可以先扩展一维,然后复制3份,取dim为1,表示现在result[i][j][k] = output[i][index[i][j][k]][k]
即对于每个batch,都把每行换成了此行所索引的output中的对应行值。

index=index.unsqueeze(-1).repeat(-1,-1,3) #(128,3000,3)
result = output.gather(dim=1,index=index)

你可能感兴趣的:(笔记,pytorch,python,pytorch,深度学习)