百度了一圈gather的用法,看的一知半解,结合了几篇博客的讲解,终于理解了这个的用法,记录下来,用于以后忘记的时候自己可以快速复习,同时不懂得小伙伴也可以参考下我这得理解,或许能帮助到你!!!
torch.gather(input, dim, index, out=None) → Tensor
Parameters:
input (Tensor) – The source tensor
dim (int) – The axis along which to index
index (LongTensor) – The indices of elements to gather
out (Tensor, optional) – Destination tensor
input :需要索引的 tensor
dim : 指索引的维度 (0代表横向 1代表纵向 以此类推)
index: 索引的下标
import torch
b = torch.Tensor([[1,2,3],[4,5,6]])
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print (torch.gather(b, dim=1, index=index_1))
print (torch.gather(b, dim=0, index=index_2))
输出:
tensor([[1., 2.],
[6., 4.]])
tensor([[1., 5., 6.],
[1., 2., 3.]])
input : b =
1,2,3
4,5,6
dim = 1 :代表的是维度1也就是列
index =
0,1
2,0
了解了输入后我们分步进行解析
(,0),(,1)
(,2),(,0)
这样我们就完成了每个输出所在input中的坐标的列的定位
(0,0),(0,1)
(1,2),(1,0)
1,2
6,4
input : b =
1,2,3
4,5,6
dim = 0 :代表的是维度0也就是行
index =
0,1,1
0,0,0
(0,),(1,),(1,)
(0,),(0,),(0,)
(0,0),(1,1),(1,2)
(0,0),(0,1),(0,2)
1,5,6
1,2,3
gather的用法就是index所提供要索引的dim维的位置,其余维度的位置也就是index对应的位置 ,也就是输出的坐标,把dim维的替换成index中对应的数字 。
还不理解的话,再举个官方的例子:
>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
1 1
4 3
[torch.FloatTensor of size 2x2]
index 的中每个元素的坐标为:
(0,0),(0,1)
(1,0),(1,1)
dim= 1 ,也就是把第二个维度的坐标替换成index中的值
(0,0),(0,0)
(1,1),(1,0)
最后写出对应input中的值
1,1
4,3
如果还不懂的话,推荐一个博客,看看别人的讲解吧:https://blog.csdn.net/edogawachia/article/details/80515038