函数torch.gather(input, dim, index, out=None) → Tensor
沿给定轴 dim ,将输入索引张量 index 指定位置的值进行聚合.
对一个 3 维张量,输出可以定义为:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
Parameters:
- input (Tensor) – 源张量
- dim (int) – 索引的轴
- index (LongTensor) – 聚合元素的下标(index需要是torch.longTensor类型)
- out (Tensor, optional) – 目标张量
使用说明举例:
- dim = 1
import torch
a = torch.randint(0, 30, (2, 3, 5))
print(a)
'''
tensor([[[ 18., 5., 7., 1., 1.],
[ 3., 26., 9., 7., 9.],
[ 10., 28., 22., 27., 0.]],
[[ 26., 10., 20., 29., 18.],
[ 5., 24., 26., 21., 3.],
[ 10., 29., 10., 0., 22.]]])
'''
index = torch.LongTensor([[[0,1,2,0,2],
[0,0,0,0,0],
[1,1,1,1,1]],
[[1,2,2,2,2],
[0,0,0,0,0],
[2,2,2,2,2]]])
print(a.size()==index.size())
b = torch.gather(a, 1,index)
print(b)
'''
True
tensor([[[ 18., 26., 22., 1., 0.],
[ 18., 5., 7., 1., 1.],
[ 3., 26., 9., 7., 9.]],
[[ 5., 29., 10., 0., 22.],
[ 26., 10., 20., 29., 18.],
[ 10., 29., 10., 0., 22.]]])
可以看到沿着dim=1,也就是列的时候。输出tensor第一页内容,
第一行分别是 按照index指定的,
input tensor的第一页
第一列的下标为0的元素 第二列的下标为1元素 第三列的下标为2的元素,第四列下标为0元素,第五列下标为2元素
index-->0,1,2,0,2 output--> 18., 26., 22., 1., 0.
'''
- dim =2
c = torch.gather(a, 2,index)
print(c)
'''
tensor([[[ 18., 5., 7., 18., 7.],
[ 3., 3., 3., 3., 3.],
[ 28., 28., 28., 28., 28.]],
[[ 10., 20., 20., 20., 20.],
[ 5., 5., 5., 5., 5.],
[ 10., 10., 10., 10., 10.]]])
dim = 2的时候就安装 行 聚合了。参照上面的举一反三。
'''
- dim = 0
index2 = torch.LongTensor([[[0,1,1,0,1],
[0,1,1,1,1],
[1,1,1,1,1]],
[[1,0,0,0,0],
[0,0,0,0,0],
[1,1,0,0,0]]])
d = torch.gather(a, 0,index2)
print(d)
'''
tensor([[[ 18., 10., 20., 1., 18.],
[ 3., 24., 26., 21., 3.],
[ 10., 29., 10., 0., 22.]],
[[ 26., 5., 7., 1., 1.],
[ 3., 26., 9., 7., 9.],
[ 10., 29., 22., 27., 0.]]])
这个有点特殊,dim = 0的时候(三维情况下),是从不同的页收集元素的。
这里举的例子只有两页。所有index在0,1两个之间选择。
输出的矩阵元素也是按照index的指定。分别在第一页和第二页之间跳着选的。
index [0,1,1,0,1]的意思就是。
在第一页选这个位置的元素,在第二页选这个位置的元素,在第二页选,第一页选,第二页选。
'''