torch.gather()
函数解析在使用过程中,对于gather的具体使用方法不甚了解,经过一番测试之后,终于弄懂了其使用方法,这里进行整理。
首先看一下这个方法的参数:
torch.gather(inputs, dim, index, *, sparse_grad=False, out=None) → Tensor
关键参数为inputs, dim, index
,我将会按照自己的理解在后面详细解释每个参数的含义,这里只需要有一个简单的印象即可。
For a 3-D tensor the output is specified by:
out[i][j][k] = inputs[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = inputs[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = inputs[i][j][index[i][j][k]] # if dim == 2
其实这里的式子以及解释的很清楚了,奈何本人愚钝,很久才理解…
这个式子具体的含义其实很简单,就是对于给定的inputs,将参数dim
所指定维度的索引值用参数index
中相应的元素代替表示(因此index的维度必须和inputs相等,因为需要索引到inputs中每个元素),其余位置的索引按照顺序即0,1,2…
这么说可能比较模糊,我这里给出一个具体的例子:
import torch
inputs = torch.rand(2,2,2)
inputs, inputs.shape
# (tensor([[[0.2465, 0.5183],
# [0.7851, 0.9985]],
# [[0.8217, 0.2534],
# [0.0478, 0.4805]]]),
# torch.Size([2, 2, 2]))
index = torch.tensor([[[1, 0],[0, 0]]])
index, index.shape
# (tensor([[[1, 0],
# [0, 0]]]),
# torch.Size([1, 2, 2]))
这里的inputs的索引,可以表示为inputs[i][j][k]
,例如inputs[0][1][1]=0.9985
则inputs中的每个元素的索引分别为:
inputs[0][0][0], inputs[0][0][1]
inputs[0][1][0], inputs[0][1][1]
inputs[1][0][0], inputs[1][0][1]
inputs[1][1][0], inputs[1][1][1]
当我们指定dim=0
时,调用gather()
方法:
inputs.gather(0, index)
# tensor([[[0.4327, 0.2616],
# [0.7656, 0.1954]]])
也就是说,对于inputs
的第0维,我们将其索引i
用张量index
中的每个元素代替
例如第一个结果0.04327
,我们知道它在结果矩阵的第一个位置上,因此它的j=0,k=0
,原本该位置的i=0
,经过gather()
方法之后,替换为i=index[0][0][0]
,而index[0][0][0]=1
,故该位置的元素为源矩阵[1][0][0]
位置的元素,inputs[1][0][0]=0.4327
同理可得到每个位置的索引分别为:
inputs[1][0][0], inputs[0][0][1]
inputs[0][1][0], inputs[0][1][1],
也就是说,对于inputs
中dim=0
位置的索引i
,我们将其原本的顺序替换为index
参数中的值。
我们再试一下令dim=1
:
inputs.gather(1, index)
# tensor([[[0.7656, 0.2616],
# [0.2183, 0.2616]]])
每个位置的索引即为:
inputs[0][1][0], inputs[0][0][1]
inputs[0][0][0], inputs[0][0][1],
一句话概括gather()
方法的作用就是,对于给定的inputs,将参数dim
所指定维度的索引值用参数index
中相应的元素代替表示。