torch.gather函数通俗理解(一堆废话)

torch.gather函数通俗理解(一堆废话)


水平有限,理解可能有些偏差,勿怪。。。

废话先不说我们先看下官方文档上面怎么说的:

torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
  • 沿给定轴,将输入索引指定位置的值进行聚合
  • 对一个 3 维张量,输出可以定义为如下形式
out[i][j][k] = input[index[i][j][k]] [j] [k]  # if dim == 0
out[i][j][k] = input[i] [index[i][j][k]] [k]  # if dim == 1
out[i][j][k] = input[i] [j] [index[i][j][k]]  # if dim == 2

看完了官方文档的解释,我:“嗯哼?你说啥嘞???”

那么我们先忽略文字解释,但要知道的是:torch.gather()函数可以理解为,根据index张量和index,将原先的Tensor映射为了另一个Tensor。映射后Tensor内部的值都是原先Tensor的值,只不过相对位置发生了变化。

上面这三行代码,可以比较清楚得告诉我们不同轴方向下映射的关系。如果是一个copy()的操作,那么索引一定是一一对应的,而这边gather()的操作下,在设定轴dim的位置上索引由index决定。

原谅我讲了这么多废话。

下面看下例子吧~

import torch
import numpy as np
a = torch.LongTensor(np.arange(24)).view(2,3,4)
print(a)
'''
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
'''

index = torch.LongTensor([[[0 ,1 ,2 ,0],
                          [0, 0, 0 ,0],
                          [1, 1, 1, 1]],
                        [[2, 2, 2, 2],
                         [1, 1, 1, 1],
                         [0, 0, 0, 0]]])

b = torch.gather(a, 1, index)
print(b)
"""
tensor([[[ 0,  5, 10,  3],
         [ 0,  1,  2,  3],
         [ 4,  5,  6,  7]],

        [[20, 21, 22, 23],
         [16, 17, 18, 19],
         [12, 13, 14, 15]]])
"""

我们设定dim=1,这时候映射的情况下,改变的只有第1个维度。对于例子中这个三维的张量a,对a[0]a[1]两块,我们分别取出a[0]a[1]的每一列,也就是[0,4,8][1,5,9]……[15,19,23]每一列值之间第0个维度的索引和第2个维度的索引都是固定不变的,而改变的,只有dim=1的维度。这也是我们设定的。

同样的,取出index的这些列,也就是[0,0,1][1,0,1]……[2,1,0],这些值恰巧就是 映射后的张量 每一列值的对映射前每一列值的索引(表示不太清楚)。
具体的,我们拿出张量b的这些列(b是映射后的张量),也就是[0,0,4][5,1,5]……[23,19,15]
我们拿出aindexb,的第一列。分别是[0,4,8][0,0,1][0,0,4],可以发现,index的列对a列的索引,刚好就是b的列。其余列也是一一对应。
如果用代码表示的话:

x,y,z = [0, 4, 8], [0, 0, 1], [0, 0, 4]
print(np.array(z) == np.array(x)[y])
"""
[ True  True  True] 
"""

而这一切成立的条件就是,我们设定了dim=1,而我们取出的每一列的值,变化的维度也是dim=1。。

如果dim=0,就应该取出a[0]a[1]两块一一对应的值,例如a中的:[0,12],[1,13]……
如果dim=2,就应该取出每一行的值。。

大体意思就是这样,虽然讲的还不是很清楚,但是允许我烂尾一下(好像也没有好的开头不能这么讲哈哈哈,随意拉~)。。

你可能感兴趣的:(pytorch,gather)