torch.gather() 和torch.sactter_()的用法简析

torch.gather() 和torch.sactter_()的用法简析_第1张图片

 

torch.gather(input, dim, index, out=None)  和 torch.scatter_(dim, index, src)是一对作用相反的方法

 

先来看torch.gather, 核心操作其实就是这样:

out[i][j][k] = input[index[i][j][k]] [j][k]  # if dim == 0

out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1

out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

是对于out指定位置上的值,去寻找input里面对应的索引位置,根据是index

官方文档给的例子是:

>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
 1  1
 4  3
[torch.FloatTensor of size 2x2]

具体过程就是这里的input = [[1,2],[3,4]], index = [[0,0],[1,0]], dim = 1, 则

out[0][0] = input[0][ index[0][0] ] = input[0][0] = 1

out[0][1] = input[0][ index[0][1] ] = input[0][0] = 1

out[1][0] = input[1][ index[1][0] ] = input[1][1] = 4

out[1][1] = input[1][ index[1][1] ] = input[1][0] = 3

 

 

torch.scatter_(dim, index, src)

核心操作:

self[ index[i][j][k] ][ j ][ k ] = src[i][j][k]  # if dim == 0

self[ i ][ index[i][j][k] ][ k ] = src[i][j][k]  # if dim == 1

self[ i ][ j ][ index[i][j][k] ] = src[i][j][k]  # if dim == 2

这个就是对于src(或者说input)指定位置上的值,去分配给output对应索引位置,根据是index,所以其实把src放在左边更容易理解,官方给的例子如下:

x = torch.rand(2, 5)
>>> x

 0.4319  0.6500  0.4080  0.8760  0.2355
 0.2609  0.4711  0.8486  0.8573  0.1029
[torch.FloatTensor of size 2x5]

>>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)

 0.4319  0.4711  0.8486  0.8760  0.2355
 0.0000  0.6500  0.0000  0.8573  0.0000
 0.2609  0.0000  0.4080  0.0000  0.1029
[torch.FloatTensor of size 3x5]

此例中,src就是x,index就是[[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]],  dim=0

我们把src写在左边,把self写在右边,这样好理解一些,

但要注意是把src的值赋给self,所以用箭头指过去:

0.4319 = Src[0][0] ----->self[ index[0][0] ][0] ----> self[0][0]

0.6500 = Src[0][1] ----->self[ index[0][1] ][1] ----> self[1][1]

0.4080 = Src[0][2] ----->self[ index[0][2] ][2] ----> self[2][2]

0.8760 = Src[0][3] ----->self[ index[0][3] ][3] ----> self[0][3]

0.2355 = Src[0][4] ----->self[ index[0][4] ][4] ----> self[0][4]

 

0.2609 = Src[1][0] ----->self[ index[1][0] ][0] ----> self[2][0]

0.4711 = Src[1][1] ----->self[ index[1][1] ][1] ----> self[0][1]

0.8486 = Src[1][2] ----->self[ index[1][2] ][2] ----> self[0][2]

0.8573 = Src[1][3] ----->self[ index[1][3] ][3] ----> self[1][3]

0.1029 = Src[1][4] ----->self[ index[1][4] ][4] ----> self[2][4]

 

则我们把src也就是 x的每个值都成功的分配了出去,然后我们再把self对应位置填好

剩下的未得到分配的位置,就填0补充。

你可能感兴趣的:(pytorch)