在学习 CS231n中的NetworkVisualization-PyTorch任务,讲解了使用torch.gather函数,gather函数是用来根据你输入的位置索引 index,来对张量位置的数据进行合并,然后再输出。
其中 gather有两种使用方式,一种为 torch.gather
另一种为 对象.gather。
首先介绍 对象.gather
import torch
torch.manual_seed(2) #为CPU设置种子用于生成随机数,以使得结果是确定的
def gather_example():
N, C = 4, 5
s = torch.randn(N, C)
y = torch.LongTensor([1, 2, 1, 3]) # 必须要为 LongTensor 不然会报错
print(s)
print(y)
print(s.gather(1,y.view(-1, 1)).squeeze())
gather_example()
'''
输出:
tensor([[-1.0408, 0.9166, -1.3042, -1.1097, 0.0299],
[-0.0498, 1.0651, 0.8860, -0.8110, 0.6737],
[-1.1233, -0.0919, 0.1405, 1.1191, 0.3152],
[ 1.7528, -0.7396, -1.2425, -0.1752, 0.6990]])
tensor([1, 2, 1, 3])
tensor([ 0.9166, 0.8860, -0.0919, -0.1752])
'''
对于上图的代码,首先通过 torch.randn 随机输出化出结果为
tensor([[-1.0408, 0.9166, -1.3042, -1.1097, 0.0299],
[-0.0498, 1.0651, 0.8860, -0.8110, 0.6737],
[-1.1233, -0.0919, 0.1405, 1.1191, 0.3152],
[ 1.7528, -0.7396, -1.2425, -0.1752, 0.6990]])
然后 我们根据索引 tensor([1, 2, 1, 3]) 对每一行进行索引,在第0行索引到位置=1的元素,即 0.9166,在第二行索引到位置=2的元素即 0.8860 以此类推,即为最后的结果。
另一种为 torch.gather
b = torch.Tensor([[1,2,3],[4,5,6]])
print (b)
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print (torch.gather(b, dim=1, index=index_1)) # 按行来进行索引
print (torch.gather(b, dim=0, index=index_2)) # 按列来进行索引
'''
输出为:
1 2 3
4 5 6
[torch.FloatTensor of size 2x3]
1 2
6 4
[torch.FloatTensor of size 2x2]
1 5 6
1 2 3
[torch.FloatTensor of size 2x3]
'''
当dim=0时候,那么就是按照行进行索引,输出第一行位置为0,1的元素,即1,2 。第二行位置为2,0的元素,即 6,4。
当dim=1时候,那么就是按照列进行索引,输出第一列位置为0,第二列位置为1,第三列位置为1的元素,即1,5,6。输出第二列位置为0,第二列位置为0,第三列位置为0的元素,即1,2,3。
综上,总结一下,gather函数是用来根据你输入的位置索引 index,来对张量位置的数据进行合并,然后再输出。你可以选择按照行和列的位置进行索引。