参考文章:Pytorch中的torch.gather函数的含义
demo
b = torch.Tensor([[1,2,3,4],[5,6,7,8],[9,10,11,12]])
print(b)
index_1 = torch.LongTensor([[0,1],[2,0],[1,1]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print(b.gather(dim=1, index=index_1))#print(torch.gather(b, dim=1, index=index_1))
print(b.gather(dim=0, index=index_2))
gather函数的功能可以解释为根据 index 参数(即是索引)返回数组里面对应位置的值
这里的b.gather()写法和torch.gather(b)的写法都可以,重点是两个参数,dim和index
dim=0表示按行来索引,也就是说index的值表示的是第几行
dim=1表示按列来索引,也就是指index的值表示的是第几列
b.gather(dim=1, index=index_1)
可以看到index_1 = torch.LongTensor([[0,1],[2,0],[1,1]])
是一个3行2列的矩阵,根据dim=1,index_1里面的值表示的就是第几列,第几行就由index_1决定(共3行),那么[0,1]表示的就是【第0行第0列,第0行第1列】;[2,0]表示【第1行第2列,第1行第0列】, [1,1]表示【第3行第1列,第3行第1列】
b.gather(dim=0, index=index_2)
可以看到index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
是一个2行3列的矩阵,根据dim=0,index_2里面的值表示的就是第几行,第几列就由index_2决定(共3列),那么[0,1,1]表示的就是【第0行第0列,第1行第1列,第1行第2列】;[0,0,0]表示【第0行第0列,第0行第1列,第0行第2列】
'''b'''
######### 0列 1列 2列 3列
tensor([[ 1., 2., 3., 4.],# 0行
[ 5., 6., 7., 8.],# 1行
[ 9., 10., 11., 12.]])# 2行
'''b.gather(dim=1, index=index_1)'''
tensor([[ 1., 2.],
[ 7., 5.],
[10., 10.]])
'''b.gather(dim=0, index=index_2)'''
tensor([[1., 6., 7.],
[1., 2., 3.]])
b.gather(dim=1, index=index_1)
可以看到index_1 = torch.LongTensor([[0,1],[2,0],[1,1]])
是一个3行2列的矩阵,index_1的[0,1]中的0的索引是(0,0),1的索引是(0,1);[2,0]中的2索引是(1,0),0的索引是(1,1);[1,1]中左边1的索引是(2,0),右边1的索引是(2,1)。然后根据dim=1,需要把这些索引的dim=1维度的值全部替换成对应index_1中的值,操作如下:
[0,1]中的0的索引是(0,0)转变为(0,0),1的索引是(0,1)转变为(0,1)
[2,0]中的2索引是(1,0)转变为(1,2),0的索引是(1,1)转变为(1,0)
[1,1]中左边1的索引是(2,0)转变为(2,1),右边1的索引是(2,1)转变为(2,1)
转变之后的索引对应到b上,把对应索引的数值取出来即可