pytorch中的gather和scatter函数

最近看代码时遇到了两个函数,查阅pytorch官方文档后一时半会儿也没弄懂,现在写篇笔记来加深一下印象。

gather

torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
沿着给定的维度dim,将输入input指定位置的值聚合起来,指定位置由index决定。
indexinput必须有相同数量的维度,且满足1 <= index[dim] <= input[dim]index[other_dims] == input[other_dims]
对于3维的张量,公式为:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

下面给出一个例子:

import torch
a = torch.randint(0, 30, (2, 3, 5))
print(a)
'''
tensor([[[ 18.,   5.,   7.,   1.,   1.],
         [  3.,  26.,   9.,   7.,   9.],
         [ 10.,  28.,  22.,  27.,   0.]],

        [[ 26.,  10.,  20.,  29.,  18.],
         [  5.,  24.,  26.,  21.,   3.],
         [ 10.,  29.,  10.,   0.,  22.]]])
'''
index = torch.LongTensor([[[0,1,2,0,2],
                          [0,0,0,0,0],
                          [1,1,1,1,1]],
                        [[1,2,2,2,2],
                         [0,0,0,0,0],
                         [2,2,2,2,2]]])
b = torch.gather(a, 1,index)
print(b)

'''
tensor([[[ 18.,  26.,  22.,   1.,   0.],
         [ 18.,   5.,   7.,   1.,   1.],
         [  3.,  26.,   9.,   7.,   9.]],

        [[  5.,  29.,  10.,   0.,  22.],
         [ 26.,  10.,  20.,  29.,  18.],
         [ 10.,  29.,  10.,   0.,  22.]]])
当dim=1时:
out[0][0][0] = out[0][index[0][0][0]][0] = out[0][0][0] = 18
out[0][0][1] = out[0][index[0][0][1]][1] = out[0][1][1] = 26
index中的值指示着input第1维中的位置.out第0维和第2维的位置相对于input不变.
index[0]第一行为0,1,2,0,2, out的列顺序不变,由此可知out[0]第一行为18,26,22,1,0,
可以描述为分别取input第0行、1行、2行、0行、2行中对应列的值.
'''

c = torch.gather(a, 2,index)
print(c)

'''
tensor([[[ 18.,   5.,   7.,  18.,   7.],
         [  3.,   3.,   3.,   3.,   3.],
         [ 28.,  28.,  28.,  28.,  28.]],

        [[ 10.,  20.,  20.,  20.,  20.],
         [  5.,   5.,   5.,   5.,   5.],
         [ 10.,  10.,  10.,  10.,  10.]]])
当dim=2时:
out[0][0][0] = out[0][0][index[0][0][0]] = out[0][0][0] = 18
out[0][1][0] = out[0][1][index[0][1][0]] = out[0][1][0] = 3
index中的值指示着input第2维中的位置.out第0维和第1维的位置相对于input不变.
index[0]第一列的值为0,0,1,out的行顺序不变,由此可知out[0]第一列为18,3,28,
可以描述为分别取input第0列、第0列、第1列中对应行的值.
'''

index2 = torch.LongTensor([[[0,1,1,0,1],
                          [0,1,1,1,1],
                          [1,1,1,1,1]],
                        [[1,0,0,0,0],
                         [0,0,0,0,0],
                         [1,1,0,0,0]]])
d = torch.gather(a, 0,index2)
print(d)

'''
tensor([[[ 18.,  10.,  20.,   1.,  18.],
         [  3.,  24.,  26.,  21.,   3.],
         [ 10.,  29.,  10.,   0.,  22.]],

        [[ 26.,   5.,   7.,   1.,   1.],
         [  3.,  26.,   9.,   7.,   9.],
         [ 10.,  29.,  22.,  27.,   0.]]])
         
当dim=0时:
注意index的值必须在[0,1]之间,因为index此时指示第0维的位置,out第1维和第2维的位置相对于input不变.
index[0]第一行为0,1,1,0,1,out的行和列顺序都不变,由此可知out[0]第一行为18,10,20,1,18,
可以描述为取input第0页、第1页、第1页、第0页、第1页中对应行、列位置的值。
'''

gather函数的应用:在KNN分类问题中,假设我们得到了B,N的张量sim_matrix代表测试数据与训练数据的余弦相似度,利用sim_matrix.topk()可以得到B,K的索引indices,训练数据的labels共有N个,我们可以利用torch.gather(labels.expand(B, -1), dim=-1, index=indices)来获取一个B,K的张量,它代表对于这B个测试数据,与其最相邻的K个训练数据的分类标签。(实际上labels[indices]这种写法也可以)

scatter

scatter_(dim, index, src) → Tensor

沿着给定的维度dim,按照src的值来修改self的指定位置,指定位置由index决定。
对于3维的张量,公式为:

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

看起来就是gather函数的逆操作。
举例:

>>> x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
        [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
'''
tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
        [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
        [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])

dim=0时,index中的值指示着self第0维的位置,第1维的位置不变
index第一行为0,1,2,0,0, 这代表self第0行、第1行、第2行、第0行、第0行中对应列的值
被修改为x的第一行,即0.3992, 0.2908, 0.9044, 0.4850, 0.6004
同理可知index第二行对于self的修改
'''

>>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23)
>>> z

'''
tensor([[ 0.0000,  0.0000,  1.2300,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  1.2300]])
dim=1时,index指示self第1维的位置.因此我们可以看到self第0行第2列和第1行第3列被修改成了1.23
'''

scatter常用于生成one-hot标签:假设已有B,K的一个张量labels,其值在[0, C-1]之间,C为类别总数。我们可以利用torch.zeros(B*K, C).scatter(dim=-1, index=labels.view(-1, 1), value=1)来获取到B*K, Cone-hot标签。
与之作用相同的写法:

one_hot_label = torch.zeros(B * K, C)
one_hot_label[torch.tensor(range(0, B * K)), labels.view(-1, 1).squeeze()] = 1.0

你可能感兴趣的:(pytorch)