Pytorch中torch.gather函数

在学习 CS231n中的NetworkVisualization-PyTorch任务,讲解了使用torch.gather函数,gather函数是用来根据你输入的位置索引 index,来对张量位置的数据进行合并,然后再输出。
其中 gather有两种使用方式,一种为 torch.gather
另一种为 对象.gather。

首先介绍 对象.gather

import torch

torch.manual_seed(2)   #为CPU设置种子用于生成随机数,以使得结果是确定的 
def gather_example():
    N, C = 4, 5
    s = torch.randn(N, C)
    y = torch.LongTensor([1, 2, 1, 3])  # 必须要为 LongTensor 不然会报错
    print(s)
    print(y)
    print(s.gather(1,y.view(-1, 1)).squeeze())
    

gather_example()


'''
输出:

tensor([[-1.0408,  0.9166, -1.3042, -1.1097,  0.0299],
        [-0.0498,  1.0651,  0.8860, -0.8110,  0.6737],
        [-1.1233, -0.0919,  0.1405,  1.1191,  0.3152],
        [ 1.7528, -0.7396, -1.2425, -0.1752,  0.6990]])
tensor([1, 2, 1, 3])
tensor([ 0.9166,  0.8860, -0.0919, -0.1752])

'''

对于上图的代码,首先通过 torch.randn 随机输出化出结果为

tensor([[-1.0408,  0.9166, -1.3042, -1.1097,  0.0299],
        [-0.0498,  1.0651,  0.8860, -0.8110,  0.6737],
        [-1.1233, -0.0919,  0.1405,  1.1191,  0.3152],
        [ 1.7528, -0.7396, -1.2425, -0.1752,  0.6990]])

然后 我们根据索引 tensor([1, 2, 1, 3]) 对每一行进行索引,在第0行索引到位置=1的元素,即 0.9166,在第二行索引到位置=2的元素即 0.8860 以此类推,即为最后的结果。

另一种为 torch.gather

b = torch.Tensor([[1,2,3],[4,5,6]])
print (b)
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print (torch.gather(b, dim=1, index=index_1)) # 按行来进行索引
print (torch.gather(b, dim=0, index=index_2)) # 按列来进行索引


'''
输出为:

1  2  3
 4  5  6
[torch.FloatTensor of size 2x3]


 1  2
 6  4
[torch.FloatTensor of size 2x2]


 1  5  6
 1  2  3
[torch.FloatTensor of size 2x3]

'''

当dim=0时候,那么就是按照行进行索引,输出第一行位置为0,1的元素,即1,2 。第二行位置为2,0的元素,即 6,4。

当dim=1时候,那么就是按照列进行索引,输出第一列位置为0,第二列位置为1,第三列位置为1的元素,即1,5,6。输出第二列位置为0,第二列位置为0,第三列位置为0的元素,即1,2,3。

综上,总结一下,gather函数是用来根据你输入的位置索引 index,来对张量位置的数据进行合并,然后再输出。你可以选择按照行和列的位置进行索引。

你可能感兴趣的:(Deep,Learning,python,深度学习)