PyTorch中gather()函数的用法

torch.gather(input, dim, index, out=None) → Tensor

沿给定轴,按照索引张量将原张量的指定位置的元素重新聚合成一个新的张量

参数含义:

  • input (Tensor) – 源张量
  • dim (int) – 索引的轴
  • index (LongTensor) – 聚合元素的下标(index类型需是torch.longTensor)
  • out (Tensor, optional) – 目标张量

官方给出的解释是这样的:
沿给定轴dim,将输入索引张量index指定位置的值进行聚合。
对一个3维张量,输出可以定义为:
out [i] [j] [k] = tensor [index[i][j][k]] [j] [k] # dim=0
out [i] [j] [k] = tensor [i] [index[i][j][k]] [k] # dim=1
out [i] [j] [k] = tensor [i] [j] [index[i][j][k]] # dim=2

对一个2维张量,输出可以定义为:
out [i] [j] = tensor [index[i][j]] [j] # dim=0
out [i] [j] = tensor [i] [index[i][j]] # dim=1

刚开始看上去很难理解,但经过研究之后会发现这个想表述的意思很简单,先给出几个代码例子让大家自行体会一下
先以3维张量为例
例 1 、维度 dim = 1

a = torch.arange(24).view(2,3,4)
print(a)
'''
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],
         
         [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
'''
index = torch.LongTensor([[[1,0,1,2],
                           [2,0,2,1],
                           [0,2,1,1]],
                          [[2,1,0,1],
                           [1,2,0,2],
                           [0,2,1,2]]])
print(a.size()==index.size())
b = torch.gather(a, 1,index)
print(b)

'''
True
tensor([[[ 4,  1,  6, 11],
         [ 8,  1, 10,  7],
         [ 0,  9,  6,  7]],

	 [[20, 17, 14, 19],
	  [16, 21, 14, 23],
	  [12, 21, 18, 23]]])
'''

解析:
在本例中,指定维度dim=1,那么就是从 列 来根据索引排列元素。比如 对于索引张量index的第一行元素[1,0,1,2],依次指

  • 在原张量的第一行中,第一处元素为该 列 索引为1对应的元素,即4;
  • 第二处元素为 第二列 索引为 0 的元素,即1;
  • 第三处元素为 第三列 索引为 1 的元素,即6;
  • 第四处元素为 第四列 索引为 2 的元素,即11

简单来说,对于 index[0]

tensor([[1, 0, 1, 2],
        [2, 0, 2, 1],
        [0, 2, 1, 1]])

聚合后的张量output[0]为

output[0] =[[[第一列第2个,第二列第1个,第三列第2个,第四列第3个],
          [第一列第3个,第二列第1个,第三列第3个,第四列第2个],
          [第一列第1个,第二列第3个,第三列第2个,第四列第2个]]
   

例2、维度dim = 2

c = torch.gather(a,2,index)
print(c)
'''
tensor([[[ 1,  0,  1,  2],
         [ 6,  4,  6,  5],
         [ 8, 10,  9,  9]],

        [[14, 13, 12, 13],
         [17, 18, 16, 18],
         [20, 22, 21, 22]]])
'''

简单来说,对于 index[0]

tensor([[1, 0, 1, 2],
        [2, 0, 2, 1],
        [0, 2, 1, 1]])

聚合后的张量output[0]为

output[0] =[[[第一行第2个,第一行第1个,第一行第2个,第一行第3个],
          [第二行第3个,第二行第1个,第二行第3个,第二行第2个],
          [第三行第1个,第三行第3个,第三行第2个,第三行第2个]] 

例3、维度 dim = 0

dim = 0 指的是 最外侧的维度,而 原张量 a 的形状是(2,3,4),其最外侧维度的维数为2,所以索引index中所有元素只能是 0 或 1

index2 = torch.LongTensor([[[1,0,1,0],
                           [1,0,0,1],
                           [0,0,1,1]],
                          [[0,1,0,1],
                           [1,1,0,1],
                           [0,1,1,0]]])
d = torch.gather(a,0,index2)  
print(d)
'''
tensor([[[12,  1, 14,  3],
         [16,  5,  6, 19],
         [ 8,  9, 22, 23]],

        [[ 0, 13,  2, 15],
         [16, 17,  6, 19],
         [ 8, 21, 22, 11]]])
'''

当索引index中元素为0时,指的是此处替换a[0]中相同位置的值;
当索引index中元素为1时,指的是此处替换a[1]中相同位置的值。
简单来说,对于index[0]

tensor([[1,0,1,0],
        [1,0,0,1],
        [0,0,1,1]])

聚合后的张量output[0]为

output[0] =[[[第2维,第1维,第2维,第1维],
            [第2维,第1维,第1维,第2维],
            [第1维,第1维,第2维,第2维]] 

即对于索引张量index的第一行元素[1,0,1,0],依次指

  • 在输出张量d[0] 第一行元素中,第一个元素为a[1]对应位置的元素,即12;
  • 在输出张量d[0] 第一行元素中,第二个元素为a[0]对应位置的元素,即1;
  • 在输出张量d[0] 第一行元素中,第三个元素为a[1]对应位置的元素,即14;
  • 在输出张量d[0] 第一行元素中,第四个元素为a[0]对应位置的元素,即3;

对于二维张量,此函数只针对列(0)或行(1)进行元素聚合

import torch
a = torch.arange(4,10).view(2,3)
print(a)
'''
tensor([[4, 5, 6],
        [7, 8, 9]])
'''
index = torch.LongTensor([[1,0,2],
                          [2,0,1]])
print(a.size()==index.size())  # True
b = torch.gather(a,1,index)
print(b)
'''
tensor([[5, 4, 6],
        [9, 7, 8]])
'''
index2 = torch.LongTensor([[1,0,1],
                          [0,0,1]])
c = torch.gather(a,0,index2)  
print(c)
'''
tensor([[7, 5, 9],
        [4, 5, 9]])
'''

你可能感兴趣的:(PyTorch)