gather可以对一个Tensor进行聚合,声明为:
torch.gather(input, dim, index, out=None) → Tensor
一般来说有三个参数:输入的变量input、指定在某一维上聚合的dim、聚合的使用的索引index,输出为Tensor类型的结果(index必须为LongTensor类型)。
#参数介绍:
input (Tensor) – The source tensor
dim (int) – The axis along which to index
index (LongTensor) – The indices of elements to gather
out (Tensor, optional) – Destination tensor
#当输入为三维时的计算过程:
out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
out[i][j][k] = input[i][j][index[i][j][k]] # dim=2
#样例:
t = torch.Tensor([[1,2],[3,4]])
torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
# 1 1
# 4 3
#[torch.FloatTensor of size 2x2]
用下面的代码在二维上做测试,以便更好地理解
t = torch.Tensor([[1,2,3],[4,5,6]])
index_a = torch.LongTensor([[0,0],[0,1]])
index_b = torch.LongTensor([[0,1,1],[1,0,0]])
print(t)
print(torch.gather(t,dim=1,index=index_a))
print(torch.gather(t,dim=0,index=index_b))
输出为:
>>tensor([[1., 2., 3.],
[4., 5., 6.]])
>>tensor([[1., 1.],
[4., 5.]])
>>tensor([[1., 5., 6.],
[4., 2., 3.]])
由于官网给的计算过程不太直观,下面给出较为直观的解释:
对于index_a,dim为1表示在第二个维度上进行聚合,索引为列号,[[0,0],[0,1]]
表示结果的第一行取原数组第一行列号为[0,0]
的数,也就是[1,1]
,结果的第二行取原数组第二行列号为[0,1]
的数,也就是[4,5]
,这样就得到了输出的结果[[1,1],[4,5]]
。
对于index_b,dim为0表示在第一个维度上进行聚合,索引为行号,[[0,1,1],[1,0,0]]
表示结果的第一行第d(d=0,1,2)列取原数组第d列行号为[0,1,1]
的数,也就是[1,5,6]
,类似的,结果的第二行第d列取原数组第d列行号为[1,0,0]
的数,也就是[4,2,3]
,这样就得到了输出的结果[[1,5,6],[4,2,3]]
接下来以index_a为例直接用官网的式子计算一遍加深理解:
output[0,0] = input[0,index[0,0]] #1 = input[0,0]
output[0,1] = input[0,index[0,1]] #1 = input[0,0]
output[1,0] = input[1,index[1,0]] #4 = input[1,0]
output[1,1] = input[1,index[1,1]] #5 = input[1,1]
以下两种写法得到的结果是一样的:
r1 = torch.gather(t,dim=1,index=index_a)
r2 = t.gather(1,index_a)