pytorch中gather函数的理解。

函数torch.gather(input, dim, index, out=None) → Tensor
沿给定轴 dim ,将输入索引张量 index 指定位置的值进行聚合.
对一个 3 维张量,输出可以定义为:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

Parameters:

  • input (Tensor) – 源张量
  • dim (int) – 索引的轴
  • index (LongTensor) – 聚合元素的下标(index需要是torch.longTensor类型)
  • out (Tensor, optional) – 目标张量

使用说明举例:

  1. dim = 1
import torch
a = torch.randint(0, 30, (2, 3, 5))
print(a)
'''
tensor([[[ 18.,   5.,   7.,   1.,   1.],
         [  3.,  26.,   9.,   7.,   9.],
         [ 10.,  28.,  22.,  27.,   0.]],

        [[ 26.,  10.,  20.,  29.,  18.],
         [  5.,  24.,  26.,  21.,   3.],
         [ 10.,  29.,  10.,   0.,  22.]]])
'''
index = torch.LongTensor([[[0,1,2,0,2],
                          [0,0,0,0,0],
                          [1,1,1,1,1]],
                        [[1,2,2,2,2],
                         [0,0,0,0,0],
                         [2,2,2,2,2]]])
print(a.size()==index.size())
b = torch.gather(a, 1,index)
print(b)
'''
True
tensor([[[ 18.,  26.,  22.,   1.,   0.],
         [ 18.,   5.,   7.,   1.,   1.],
         [  3.,  26.,   9.,   7.,   9.]],

        [[  5.,  29.,  10.,   0.,  22.],
         [ 26.,  10.,  20.,  29.,  18.],
         [ 10.,  29.,  10.,   0.,  22.]]])
可以看到沿着dim=1,也就是列的时候。输出tensor第一页内容,
第一行分别是 按照index指定的,
input tensor的第一页 
第一列的下标为0的元素 第二列的下标为1元素 第三列的下标为2的元素,第四列下标为0元素,第五列下标为2元素
index-->0,1,2,0,2    output--> 18.,  26.,  22.,   1.,   0.
'''
  1. dim =2
c = torch.gather(a, 2,index)
print(c)
'''
tensor([[[ 18.,   5.,   7.,  18.,   7.],
         [  3.,   3.,   3.,   3.,   3.],
         [ 28.,  28.,  28.,  28.,  28.]],

        [[ 10.,  20.,  20.,  20.,  20.],
         [  5.,   5.,   5.,   5.,   5.],
         [ 10.,  10.,  10.,  10.,  10.]]])
dim = 2的时候就安装 行 聚合了。参照上面的举一反三。
'''
  1. dim = 0
index2 = torch.LongTensor([[[0,1,1,0,1],
                          [0,1,1,1,1],
                          [1,1,1,1,1]],
                        [[1,0,0,0,0],
                         [0,0,0,0,0],
                         [1,1,0,0,0]]])
d = torch.gather(a, 0,index2)
print(d)
'''
tensor([[[ 18.,  10.,  20.,   1.,  18.],
         [  3.,  24.,  26.,  21.,   3.],
         [ 10.,  29.,  10.,   0.,  22.]],

        [[ 26.,   5.,   7.,   1.,   1.],
         [  3.,  26.,   9.,   7.,   9.],
         [ 10.,  29.,  22.,  27.,   0.]]])
这个有点特殊,dim = 0的时候(三维情况下),是从不同的页收集元素的。
这里举的例子只有两页。所有index在0,1两个之间选择。
输出的矩阵元素也是按照index的指定。分别在第一页和第二页之间跳着选的。
index [0,1,1,0,1]的意思就是。
在第一页选这个位置的元素,在第二页选这个位置的元素,在第二页选,第一页选,第二页选。

'''

你可能感兴趣的:(pytorch中gather函数的理解。)