最近在看别人的代码时,看到一个神奇的函数torch.gather(),网上查了一下,看了半天才终于看懂对方在说什么,个人觉得这种东西应该画个图来帮助理解,于是就写了这篇博客。
官方文档解释:
torch.gather(input, dim, index, out=None) → Tensor
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
out[i][j][k] = input[i][j][index[i][j][k]] # dim=2
Parameters:
input (Tensor) – The source tensor
dim (int) – The axis along which to index
index (LongTensor) – The indices of elements to gather
out (Tensor, optional) – Destination tensor
Example:
>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
1 1
4 3
[torch.FloatTensor of size 2x2]
以二维矩阵为例,代码如下:
b = torch.Tensor([[1,2,3],[4,5,6]])
print(b)
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print(torch.gather(b, dim=1, index=index_1))
print(torch.gather(b, dim=0, index=index_2))
输出如下:
tensor([[1., 2., 3.],
[4., 5., 6.]])
tensor([[1., 2.],
[6., 4.]])
tensor([[1., 5., 6.],
[1., 2., 3.]])
启动ppt来画图解释一下。
当dim=0
时,是按行来,根据索引来决定两行里面选择哪一行:
选择第0行,填入数字1
选择第0行,填入数字2
选择第0行,填入数字3
当dim=1
时,是按列来,根据索引来决定 3列里面选择哪一列:
选择第0列,填入数字1
选择第1列,填入数字2
选择第2列,填入数字6
选择第0列,填入数字4
完毕,与输出一模一样。
参考:Pytorch中的torch.gather函数的含义