input(tensor): 待操作数。不妨设其维度为(x1, x2, …, xn)
dim(int): 待操作的维度。
index(LongTensor): 如何对input进行操作。其维度有限定,例如当dim=i时,index的维度为(x1, x2, …y, …,xn),既是将input的第i维的大小更改为y,且要满足y>=1(除了第i维之外的其他维度,大小要和input保持一致)。
out: 注意输出和index的维度是一致的
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
在序列标注问题上,我们给每一个单词都标上一个标签。不妨假设我们有4个句子,每个句子的长度不一定相同,标签如下:
input = [
[2, 3, 4, 5],
[1, 4, 3],
[4, 2, 2, 5, 7],
[1]
]
上例中有四个句子,长度分别为4,3,5,1,其中第一个句子的标签为2,3,4,5。我们知道,处理自然语言问题时,一般都需要进行padding,即将不同长度的句子padding到同一长度,以0为padding,那么上述经padding后变为:
input = [
[2, 3, 4, 5, 0, 0],
[1, 4, 3, 0, 0, 0],
[4, 2, 2, 5, 7, 0],
[1, 0, 0, 0, 0, 0]
]
那么问题来了,现在我们想获得每个句子中最后一个词语的标签,该怎么得到呢?既是,第一句话中的5,第二句话中的3,第三句话中7,第四句话中的1。
此时就需要用gather函数了(当然你说可以循环什么的,当我没问)。
此时我们的input就是填充之后的tensor,dim=1, index就是各个句子的长度,即[[4],[3],[5],[1]]。之所以维度是4*1,是为了满足index维度和input维度之间的关系(讲解参数时有讲)。
import torch
input = [
[2, 3, 4, 5, 0, 0],
[1, 4, 3, 0, 0, 0],
[4, 2, 2, 5, 7, 0],
[1, 0, 0, 0, 0, 0]
]
input = torch.tensor(input)
#注意index的类型
length = torch.LongTensor([[4],[3],[5],[1]])
#index之所以减1,是因为序列维度是从0开始计算的
out = torch.gather(input, 1, length-1)
out
此函数的作用感觉一句话说不出来,硬说的话,我感觉应该是:
利用index来索引input特定位置的数值
例如上例中的length,再加上dim=1,指定了索引每句话中的最后一个单词(length-1)。
另外可以琢磨一下gather的计算公式