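It is worth reading the official PyTorch documentation for torch.gather first.
Its main job is to pick values out of the input tensor along the specified dimension dim, at the positions given by index. Once you are comfortable with this function, you no longer need to fall back on a brute-force for loop.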
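To make the "no more for loops" point concrete, here is a minimal sketch that compares torch.gather with a hand-written loop on a small 2-D tensor. The helper name gather_dim1_loop and the sample values are made up for illustration; the sketch only covers the 2-D, dim=1 case.

```python
import torch

def gather_dim1_loop(inp, index):
    """Loop-based reference for torch.gather(inp, 1, index) on 2-D tensors (illustrative helper)."""
    out = torch.empty(index.shape, dtype=inp.dtype)
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            # For dim=1 the rule is out[i][j] = inp[i][index[i][j]]
            out[i, j] = inp[i, index[i, j]]
    return out

# Hypothetical sample data
inp = torch.tensor([[10, 11, 12],
                    [20, 21, 22]])
index = torch.tensor([[2, 0],
                      [1, 1]])

looped = gather_dim1_loop(inp, index)
gathered = torch.gather(inp, 1, index)

print(gathered)
print(torch.equal(looped, gathered))  # the two approaches agree
```

The output always has the same shape as index, which is why a single call can replace the nested loop.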
(1) Mask and padding positions in natural language processing
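import torch

# Each row is one sequence, padded with zeros to length 5
a = torch.Tensor([
    [4, 1, 2, 0, 0],
    [2, 4, 0, 0, 0],
    [1, 1, 1, 6, 5],
    [1, 2, 2, 2, 2],
    [3, 0, 0, 0, 0],
    [2, 2, 0, 0, 0]])
# Number of non-padding entries in each row, shape (6, 1)
index = torch.LongTensor([[3], [2], [5], [5], [1], [2]])
print(a.size(), index.size())
# index - 1 is the 0-based position of the last non-padding entry in each row
b = torch.gather(a, 1, index - 1)
print(b)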
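Note the choice of dimension: with dim=1, each index selects a column within its own row, so index has one row per row of a (here 6x1 against 6x5).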
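The effect of the dim argument can be sketched on a small 3x3 tensor. The values and the variable names col_index and row_index are assumptions chosen for illustration, not taken from the examples in this post.

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

# dim=1: index has one column; row i picks x[i][col_index[i][0]]
col_index = torch.tensor([[2], [0], [1]])
per_row = torch.gather(x, 1, col_index)   # shape (3, 1)

# dim=0: index has one row; column j picks x[row_index[0][j]][j]
row_index = torch.tensor([[2, 0, 1]])
per_col = torch.gather(x, 0, row_index)   # shape (1, 3)

print(per_row)
print(per_col)
```

With dim=1 the index tensor is laid out as a column (one entry per row); with dim=0 it is laid out as a row (one entry per column). In both cases the result takes the shape of the index tensor.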
(2) Predicted value for a specific class in a neural network
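import torch

# Each row holds the predicted scores of one sample over 5 classes
a = torch.Tensor([
    [0.4, 0.1, 0.2, 0., 0.3],
    [0.2, 0.4, 0.2, 0.1, 0.1],
    [0.1, 0.1, 0.1, 0.6, 0],
    [0.1, 0.3, 0.2, 0.2, 0.2],
    [0.3, 0.0, 0.7, 0, 0],
    [0.2, 0.2, 0.0, 0.5, 0.1]])
# 1-based class of interest for each sample, shape (6, 1)
index = torch.LongTensor([[1], [2], [4], [2], [3], [4]])
print(a.size(), index.size())
# Subtract 1 to get 0-based column indices, then pick one score per row
b = torch.gather(a, 1, index - 1)
print(b)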
Output: