torch.gather(input, dim, index, out=None) → Tensor
沿给定轴,按照索引张量将原张量的指定位置的元素重新聚合成一个新的张量
参数含义:
官方给出的解释是这样的:
沿给定轴dim,将输入索引张量index指定位置的值进行聚合。
对一个3维张量,输出可以定义为:
out [i] [j] [k] = tensor [index[i][j][k]] [j] [k] # dim=0
out [i] [j] [k] = tensor [i] [index[i][j][k]] [k] # dim=1
out [i] [j] [k] = tensor [i] [j] [index[i][j][k]] # dim=2
对一个2维张量,输出可以定义为:
out [i] [j] = tensor [index[i][j]] [j] # dim=0
out [i] [j] = tensor [i] [index[i][j]] # dim=1
刚开始看上去很难理解,但经过研究之后会发现这个想表述的意思很简单,先给出几个代码例子让大家自行体会一下
先以3维张量为例
例 1 、维度 dim = 1
a = torch.arange(24).view(2,3,4)
print(a)
'''
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
'''
index = torch.LongTensor([[[1,0,1,2],
[2,0,2,1],
[0,2,1,1]],
[[2,1,0,1],
[1,2,0,2],
[0,2,1,2]]])
print(a.size()==index.size())
b = torch.gather(a, 1,index)
print(b)
'''
True
tensor([[[ 4, 1, 6, 11],
[ 8, 1, 10, 7],
[ 0, 9, 6, 7]],
[[20, 17, 14, 19],
[16, 21, 14, 23],
[12, 21, 18, 23]]])
'''
解析:
在本例中,指定维度dim=1,那么就是从 列 来根据索引排列元素。比如 对于索引张量index的第一行元素[1,0,1,2],依次指
简单来说,对于 index[0]
tensor([[1, 0, 1, 2],
[2, 0, 2, 1],
[0, 2, 1, 1]])
聚合后的张量output[0]为
output[0] =[[[第一列第2个,第二列第1个,第三列第2个,第四列第3个],
[第一列第3个,第二列第1个,第三列第3个,第四列第2个],
[第一列第1个,第二列第3个,第三列第2个,第四列第2个]]
例2、维度dim = 2
c = torch.gather(a,2,index)
print(c)
'''
tensor([[[ 1, 0, 1, 2],
[ 6, 4, 6, 5],
[ 8, 10, 9, 9]],
[[14, 13, 12, 13],
[17, 18, 16, 18],
[20, 22, 21, 22]]])
'''
简单来说,对于 index[0]
tensor([[1, 0, 1, 2],
[2, 0, 2, 1],
[0, 2, 1, 1]])
聚合后的张量output[0]为
output[0] =[[[第一行第2个,第一行第1个,第一行第2个,第一行第3个],
[第二行第3个,第二行第1个,第二行第3个,第二行第2个],
[第三行第1个,第三行第3个,第三行第2个,第三行第2个]]
例3、维度 dim = 0
dim = 0 指的是 最外侧的维度,而 原张量 a 的形状是(2,3,4),其最外侧维度的维数为2,所以索引index中所有元素只能是 0 或 1
index2 = torch.LongTensor([[[1,0,1,0],
[1,0,0,1],
[0,0,1,1]],
[[0,1,0,1],
[1,1,0,1],
[0,1,1,0]]])
d = torch.gather(a,0,index2)
print(d)
'''
tensor([[[12, 1, 14, 3],
[16, 5, 6, 19],
[ 8, 9, 22, 23]],
[[ 0, 13, 2, 15],
[16, 17, 6, 19],
[ 8, 21, 22, 11]]])
'''
当索引index中元素为0时,指的是此处替换a[0]中相同位置的值;
当索引index中元素为1时,指的是此处替换a[1]中相同位置的值。
简单来说,对于index[0]
tensor([[1,0,1,0],
[1,0,0,1],
[0,0,1,1]])
聚合后的张量output[0]为
output[0] =[[[第2维,第1维,第2维,第1维],
[第2维,第1维,第1维,第2维],
[第1维,第1维,第2维,第2维]]
即对于索引张量index的第一行元素[1,0,1,0],依次指
对于二维张量,此函数只针对列(0)或行(1)进行元素聚合
import torch
a = torch.arange(4,10).view(2,3)
print(a)
'''
tensor([[4, 5, 6],
[7, 8, 9]])
'''
index = torch.LongTensor([[1,0,2],
[2,0,1]])
print(a.size()==index.size()) # True
b = torch.gather(a,1,index)
print(b)
'''
tensor([[5, 4, 6],
[9, 7, 8]])
'''
index2 = torch.LongTensor([[1,0,1],
[0,0,1]])
c = torch.gather(a,0,index2)
print(c)
'''
tensor([[7, 5, 9],
[4, 5, 9]])
'''