首先我要吐槽torch.gather()函数的官方文档,请问它在说个啥?后来根据csdn以及自己的学习,总结出torch.gather()的用法:
首先,给出torch.gather()中的几个参数:
torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
常用的就是input,dim,index
三个参数:
[ ]
表示一个维度,比如[ [ 2,3 ] ]
中的2和3就是在第二维,dim可以取0,1,2;说了这么多估计也没说明白,正常正常,先上几个例子自己理解理解:
input=torch.arange(15).view(3,5)
print("input:\n",input)
index1=torch.tensor([
[1, 0],
[0, 0],
[1, 2]])
print("index:\n",index)
print("dim=1时:\n",torch.gather(input,dim=1,index=index1))
结果为:
input:
tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])
index:
tensor([[1, 0],
[0, 0],
[1, 2]])
dim=1时:
tensor([[ 1, 0],
[ 5, 5],
[11, 12]])
dim=1表示取第二维,也就是第二个中括号中的元素进行处理,仔细观察index:
[1,0]
中的1表示在input中第二维下标为1的元素,也就是1[1,0]
中的0表示在input中第二维下标为1的元素,也就是0[1,2]
中的1表示在input中第二维第三组值下标为1的元素。也就是11[1,2]
中的2表示在input中第二维第三组值下标为2的元素。也就是12注意: dim=1时,index中组的个数要与input组的个数相同
input=torch.arange(15).view(3,5)
print("input:\n",input)
index1=torch.tensor([
[1,0,0,0,0],
[0,0,1,2,1],
[1,2,0,0,0]])
print("index:\n",index1)
print("dim=0时:\n",torch.gather(input,dim=0,index=index1))
输出结果:
input:
tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])
index:
tensor([[1, 0, 0, 0, 0],
[0, 0, 1, 2, 1]])
dim=0时:
tensor([[ 5, 1, 2, 3, 4],
[ 0, 1, 7, 13, 9]])
当dim=0时,表示在第一维中检索下标,input第一维度的数据可以看作:
[0,5,10],[1,6,11],[2,7,12],[3,8,13],[4,0,14]
那么在index中的[1,0,0,0,0]
中的1表示在[0,5,10]
中取下标为1的元素,也就是5。后面可以依次取出并集合到一个tensor中去。
注意: dim=0时,index每一组中的元素个数要与input中的元素个数相同,也就是都为5个。
相信对torch.gather()有一定的了解,那么在下面举例dim=3的情况:
input=torch.tensor([[
[1,2,3],
[4,5,6],
[7,8,9]]
])
index1=torch.tensor([[
[0,0],
[0,0],
[0,0]
]])
print("input:\n",input)
print("index:\n",index1)
print("dim=3时:\n",torch.gather(test,dim=2,index=index1))
结果为:
input:
tensor([[[1, 2, 3],
[4, 5, 6],
[7, 8, 9]]])
index:
tensor([[[0, 0],
[0, 0],
[0, 0]]])
dim=3时:
tensor([[[1, 1],
[4, 4],
[7, 7]]])
以上就是有关gather()的笔记,不知道我是否有讲清楚,有帮助的点个赞吧~