Pytorch中torch.gather函数祥解


原创申明:本文为作者原创,转载请注明出处!


引言:在多分类中,torch.gather常用来取出标签所对应的概率,但对于刚开始接触Pytorch的同学来说,torch.gather()可能不太好理解,这里做一些说明和演示,帮助理解。


  1. 官方说明
    gather( input', dim, index, out=None, sparse_grad=False)
    Gathers values along an axis specified by dim
    沿着给定的维度dim收集值
    Args: 参数(初学者可只看前三个参数)
    input (Tensor): the source tensor 源tensor(Tensor类型)
    dim (int): the axis along which to index 要进行索引的轴方向(int类型)
    index (LongTensor): the indices of elements to gather(LongTensor类型)
    out (Tensor, optional): the destination tensor 返回值(Tensor类型)
    sparse_grad(bool,optional): If True, gradient w.r.t. :attr:input will be a sparse tensor. 若为真,这关于input的梯度为sparse tensor
    注意:index的维度要和input中dim所指的维度相同

  1. 例子说明
    1) 按照dim = 0, 取一个2*2 tensor的对角线上的数值
#按照dim = 0,  取一个2*2tensor的对角线上的数值
import torch
a = torch.Tensor([[1, 2],
                  [3, 4]])

b = torch.gather(a, dim = 0, index=torch.LongTensor([[0, 1]]))

print('a = ', a)
print('b = ', b)

输出如下:

a =  tensor([[1., 2.],
             [3., 4.]])

b =  tensor([[1., 4.]])

说明:
可以看到a的dim=0, 即方向的维度和index的维度是匹配的,就是说a和index由行方向从左往右看,有2列,即有2个样本,行方向是匹配的。另外,函数输出的tensor和index大小相同。
上面代码的操作逻辑是:
在a中,由行看,有两个样本,索引分别为0和1;每个样本有两个特征,每个特征中索引分别为0和1;依据index中的索引值,取第0样本的第0个特征1,再取第1个样本的第1个特征4

2) 按照dim = 1, 取一个2*2 tensor的对角线上的数值

#按照dim = 1,  取一个2*2 tensor的对角线上的数值
import torch
a = torch.Tensor([[1, 2],
                  [3, 4]])

c = torch.gather(a, dim = 1, index=torch.LongTensor([[0], 
                                                     [1]]))
print('a = ', a)
print('c = ', c)

输出如下:

a =  tensor([[1., 2.],
             [3., 4.]])

c =  tensor([[1.],
             [4.]])

说明:
可以看到a的dim=1, 即方向的维度和index的维度是匹配的,就是说a和index由列方向从上往下看,有2行,即有2个样本,列方向是匹配的。另外,函数输出的tensor和index大小相同。
上面代码的操作逻辑是:
在a中,由列看,有两个样本,索引分别为0和1;每个样本有两个特征,每个特征中索引分别为0和1;依据index中的索引值,取第0样本的第0个特征1,再取第1个样本的第1个特征4

3)更复杂一点的例子
index变为2*2的longtensor

#
import torch
a = torch.Tensor([[1, 2],
                  [3, 4]])
d = torch.gather(a, dim= 0, index=torch.LongTensor([[0, 0],
                                                   [1, 0]]))
print('a = ', a)
print('d = ', d)

输出:

a =  tensor([[1., 2.],
             [3., 4.]])

d =  tensor([[1., 2.],
             [3., 2.]])

说明:
index可看做是行[[0, 0]] 和 [[1, 0]]的组合,从上往下,先[[0, 0]] 再[[1, 0]],根据例子1)中的逻辑可知输出为d。如果是dim = 1, 则index按照列[[0, 1]] T 和 [[0, 0]]T的组合(T表示转置),从左往右,先[[0, 1]] T 再 [[0, 0]]T,按照2)中的逻辑,得可输出。

  1. 实际中的一个例子
    有三个标签[0, 1, 2],即三个类别。现在知道两个样本(A 和 B)所得到的三个标签的概率分别为[0.1, 0.3, 0.6]和[0.3, 0.2, 0.5], 用myY_hat表示, 这两个样本的真实标签分别为0和2, 那么我们很容易知道A所预测的真实标签的概率为0.1, B所预测的真实标签的概率为0.5,A误分类,B正确分类。那么用程序这么获得标签对应的概率呢,这里就可以用gather函数。
myY_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
myY = torch.LongTensor([0, 2])
print(myY.view(-1, 1))
print(myY_hat.gather(1, myY.view(-1, 1)))

输出:

tensor([[0],
        [2]])
tensor([[0.1000],
        [0.5000]])

附:
Tensor的基本数据类型有五种:
32位浮点型:torch.FloatTensor,pyorch.Tensor()默认的就是这种类型。
64位整型:torch.LongTensor。
32位整型:torch.IntTensor。
16位整型:torch.ShortTensor。
64位浮点型:torch.DoubleTensor。

你可能感兴趣的:(Pytorch中torch.gather函数祥解)