gather(input, dim, index) | 根据index,在dim维度上选取数据,输出的size与index一样 |
文章摘自:一、PyTorch基础:Tensor和Autograd_白水小琪七ya的博客-CSDN博客
a = t.arange(0, 16).view(4, 4)
# 选取对角线的元素
index = t.LongTensor([[0,1,2,3]])
a.gather(0, index)
# 选取反对角线上的元素
index = t.LongTensor([[3],[2],[1],[0]])# 等价于index = t.LongTensor([[3,2,1,0]]).t()
a.gather(1, index)
# 选取反对角线上的元素,注意与上面的不同
index = t.LongTensor([[3,2,1,0]])
a.gather(0, index)
# 选取两个对角线上的元素
index = t.LongTensor([[0,1,2,3],[3,2,1,0]]).t()
b = a.gather(1, index)
gather函数的功能可以解释为根据 index 参数(即是索引)返回数组里面对应位置的值
这里的a.gather()写法和torch.gather(a)的写法都可以,重点是两个参数,dim和index
简单的的理解方式
dim=0表示按行来索引,也就是说index的值表示的是第几行
dim=1表示按列来索引,也就是指index的值表示的是第几列
举例来讲:,a.gather(0, index),index = t.LongTensor([[0,1,2,3]]) 可以看出index = t.LongTensor([[0,1,2,3]]) 是一个1行4列的矩阵,根据dim=0,inde里面的值就是表示第几行,而第几列中是有index决定(4列),那么[0]就是表示0行0列 , [1]表示1行1列, [2]表示2行2列 [3]表示3行3列
index = t.LongTensor([[3,2,1,0]]).t() , a.gather(1, index) 可以看出index = t.LongTensor([[3,2,1,0]]).t() 是一个1行4列的矩阵,根据dim=1,index里面的值表示的就是第几列,第几行就由index决定(共4行),那么 [3] 表示第0行3列 ,[2]表示第1行2列, [1] 第2行1列 [0]第3行0列
结果执行:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
# 选取对角线的元素
tensor([[ 0, 5, 10, 15]])
# 选取反对角线上的元素
tensor([[ 3],
[ 6],
[ 9],
[12]]
# 选取反对角线上的元素,注意与上面的不同
tensor([[12, 9, 6, 3]])
# 选取两个对角线上的元素
tensor([[ 0, 3],
[ 5, 6],
[10, 9],
[15, 12]])