图解torch.gather()的用法

最近在看别人的代码时,看到一个神奇的函数torch.gather(),网上查了一下,看了半天才终于看懂对方在说什么,个人觉得这种东西应该画个图来帮助理解,于是就写了这篇博客。

官方文档解释:

torch.gather(input, dim, index, out=None) → Tensor

    Gathers values along an axis specified by dim.

    For a 3-D tensor the output is specified by:

    out[i][j][k] = input[index[i][j][k]][j][k]  # dim=0
    out[i][j][k] = input[i][index[i][j][k]][k]  # dim=1
    out[i][j][k] = input[i][j][index[i][j][k]]  # dim=2

    Parameters:	

        input (Tensor) – The source tensor
        dim (int) – The axis along which to index
        index (LongTensor) – The indices of elements to gather
        out (Tensor, optional) – Destination tensor

    Example:

    >>> t = torch.Tensor([[1,2],[3,4]])
    >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
     1  1
     4  3
    [torch.FloatTensor of size 2x2]

二维情况

以二维矩阵为例,代码如下:

b = torch.Tensor([[1,2,3],[4,5,6]])
print(b)
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print(torch.gather(b, dim=1, index=index_1))
print(torch.gather(b, dim=0, index=index_2))

输出如下:

tensor([[1., 2., 3.],
        [4., 5., 6.]])
        
tensor([[1., 2.],
        [6., 4.]])
        
tensor([[1., 5., 6.],
        [1., 2., 3.]])

启动ppt来画图解释一下。

首先,我们有一个输入的二维矩阵:
图解torch.gather()的用法_第1张图片

我们对它,按行、列都标上索引:
图解torch.gather()的用法_第2张图片

dim = 0

dim=0时,是按行来,根据索引来决定两行里面选择哪一行:

选择第0行,填入数字1
图解torch.gather()的用法_第3张图片

图解torch.gather()的用法_第4张图片

选择第1行,填入数字5
图解torch.gather()的用法_第5张图片
图解torch.gather()的用法_第6张图片

选择第1行,填入数字6
图解torch.gather()的用法_第7张图片
图解torch.gather()的用法_第8张图片

选择第0行,填入数字1
图解torch.gather()的用法_第9张图片
图解torch.gather()的用法_第10张图片
选择第0行,填入数字2
图解torch.gather()的用法_第11张图片
图解torch.gather()的用法_第12张图片
选择第0行,填入数字3
图解torch.gather()的用法_第13张图片
图解torch.gather()的用法_第14张图片

dim = 1

dim=1时,是按列来,根据索引来决定 3列里面选择哪一列:

选择第0列,填入数字1
图解torch.gather()的用法_第15张图片
图解torch.gather()的用法_第16张图片
选择第1列,填入数字2
图解torch.gather()的用法_第17张图片
图解torch.gather()的用法_第18张图片
选择第2列,填入数字6
图解torch.gather()的用法_第19张图片
图解torch.gather()的用法_第20张图片
选择第0列,填入数字4
图解torch.gather()的用法_第21张图片
图解torch.gather()的用法_第22张图片

完毕,与输出一模一样。

dim = 2

那么,tensor是三维的情况怎么办?处理方法如下图所示。
图解torch.gather()的用法_第23张图片

参考:Pytorch中的torch.gather函数的含义

你可能感兴趣的:(❤️,机器学习)