pytorch.gather()函数深入理解(dim=1,2,3三种维度分析)

首先我要吐槽torch.gather()函数的官方文档,请问它在说个啥?后来根据csdn以及自己的学习,总结出torch.gather()的用法:

首先,给出torch.gather()中的几个参数

torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor

常用的就是input,dim,index三个参数:

  1. input: 你要输入的torch.tensor();
  2. dim: 要处理的维度,一个[ ]表示一个维度,比如[ [ 2,3 ] ]中的2和3就是在第二维,dim可以取0,1,2;
  3. index: 必须为torch.LongTensor()的类型,且维度大小必须和input相同,index中每一个值表示input在dim维中的下标,下标从0开始

说了这么多估计也没说明白,正常正常,先上几个例子自己理解理解:

1、dim=1时的情况:

input=torch.arange(15).view(3,5)
print("input:\n",input)

index1=torch.tensor([
    [1, 0],
    [0, 0],
    [1, 2]])
print("index:\n",index)

print("dim=1时:\n",torch.gather(input,dim=1,index=index1))

结果为:

input:
 tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]])
index:
 tensor([[1, 0],
        [0, 0],
        [1, 2]])
dim=1:
 tensor([[ 1,  0],
        [ 5,  5],
        [11, 12]])

dim=1表示取第二维,也就是第二个中括号中的元素进行处理,仔细观察index:

  1. index中[1,0]中的1表示在input中第二维下标为1的元素,也就是1
  2. index中[1,0]中的0表示在input中第二维下标为1的元素,也就是0
  3. index中[1,2]中的1表示在input中第二维第三组值下标为1的元素。也就是11
  4. index中[1,2]中的2表示在input中第二维第三组值下标为2的元素。也就是12

注意: dim=1时,index中组的个数要与input组的个数相同

2、dim=0时的情况:

input=torch.arange(15).view(3,5)
print("input:\n",input)

index1=torch.tensor([
    [1,0,0,0,0],
    [0,0,1,2,1],
    [1,2,0,0,0]])
print("index:\n",index1)

print("dim=0时:\n",torch.gather(input,dim=0,index=index1))

输出结果:

input:
 tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]])
index:
 tensor([[1, 0, 0, 0, 0],
        [0, 0, 1, 2, 1]])
dim=0:
 tensor([[ 5,  1,  2,  3,  4],
        [ 0,  1,  7, 13,  9]])

当dim=0时,表示在第一维中检索下标,input第一维度的数据可以看作:

[0,5,10],[1,6,11],[2,7,12],[3,8,13],[4,0,14]

那么在index中的[1,0,0,0,0]中的1表示在[0,5,10]中取下标为1的元素,也就是5。后面可以依次取出并集合到一个tensor中去。

注意: dim=0时,index每一组中的元素个数要与input中的元素个数相同,也就是都为5个。

3、dim=3时

相信对torch.gather()有一定的了解,那么在下面举例dim=3的情况:

input=torch.tensor([[
        [1,2,3],
        [4,5,6],
        [7,8,9]]
])
index1=torch.tensor([[
        [0,0],
        [0,0],
        [0,0]
]])
print("input:\n",input)
print("index:\n",index1)
print("dim=3时:\n",torch.gather(test,dim=2,index=index1))

结果为:

input:
 tensor([[[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]]])
index:
 tensor([[[0, 0],
         [0, 0],
         [0, 0]]])
dim=3:
 tensor([[[1, 1],
         [4, 4],
         [7, 7]]])

以上就是有关gather()的笔记,不知道我是否有讲清楚,有帮助的点个赞吧~

你可能感兴趣的:(笔记,pytorch,python,深度学习)