gather( input', dim, index, out=None, sparse_grad=False)
dim
input (Tensor)
: the source tensor 源tensor(Tensor类型)dim (int)
: the axis along which to index 要进行索引的轴方向(int类型)index (LongTensor)
: the indices of elements to gather(LongTensor类型)out (Tensor, optional)
: the destination tensor 返回值(Tensor类型)sparse_grad(bool,optional)
: If True
, gradient w.r.t. :attr:input
will be a sparse tensor. 若为真,这关于input的梯度为sparse tensor#按照dim = 0, 取一个2*2tensor的对角线上的数值
import torch
a = torch.Tensor([[1, 2],
[3, 4]])
b = torch.gather(a, dim = 0, index=torch.LongTensor([[0, 1]]))
print('a = ', a)
print('b = ', b)
输出如下:
a = tensor([[1., 2.],
[3., 4.]])
b = tensor([[1., 4.]])
说明:
可以看到a的dim=0, 即行方向的维度和index的维度是匹配的,就是说a和index由行方向从左往右看,有2列,即有2个样本,行方向是匹配的。另外,函数输出的tensor和index大小相同。
上面代码的操作逻辑是:
在a中,由行看,有两个样本,索引分别为0和1;每个样本有两个特征,每个特征中索引分别为0和1;依据index中的索引值,取第0样本的第0个特征1,再取第1个样本的第1个特征4
2) 按照dim = 1, 取一个2*2 tensor的对角线上的数值
#按照dim = 1, 取一个2*2 tensor的对角线上的数值
import torch
a = torch.Tensor([[1, 2],
[3, 4]])
c = torch.gather(a, dim = 1, index=torch.LongTensor([[0],
[1]]))
print('a = ', a)
print('c = ', c)
输出如下:
a = tensor([[1., 2.],
[3., 4.]])
c = tensor([[1.],
[4.]])
说明:
可以看到a的dim=1, 即列方向的维度和index的维度是匹配的,就是说a和index由列方向从上往下看,有2行,即有2个样本,列方向是匹配的。另外,函数输出的tensor和index大小相同。
上面代码的操作逻辑是:
在a中,由列看,有两个样本,索引分别为0和1;每个样本有两个特征,每个特征中索引分别为0和1;依据index中的索引值,取第0样本的第0个特征1,再取第1个样本的第1个特征4。
3)更复杂一点的例子
index变为2*2的longtensor
#
import torch
a = torch.Tensor([[1, 2],
[3, 4]])
d = torch.gather(a, dim= 0, index=torch.LongTensor([[0, 0],
[1, 0]]))
print('a = ', a)
print('d = ', d)
输出:
a = tensor([[1., 2.],
[3., 4.]])
d = tensor([[1., 2.],
[3., 2.]])
说明:
index可看做是行[[0, 0]] 和 [[1, 0]]的组合,从上往下,先[[0, 0]] 再[[1, 0]],根据例子1)中的逻辑可知输出为d。如果是dim = 1, 则index按照列[[0, 1]] T 和 [[0, 0]]T的组合(T表示转置),从左往右,先[[0, 1]] T 再 [[0, 0]]T,按照2)中的逻辑,得可输出。
myY_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
myY = torch.LongTensor([0, 2])
print(myY.view(-1, 1))
print(myY_hat.gather(1, myY.view(-1, 1)))
输出:
tensor([[0],
[2]])
tensor([[0.1000],
[0.5000]])
4. 关于坐标(矩阵内位置)的简单获取(计算)方法:
import torch
tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)
输出:
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)
输出结果
tensor([[9, 7, 5]])
更多内容请参考 图解PyTorch中的torch.gather函数 - 知乎