torch.gather(input, dim, index, out=None) 和 torch.scatter_(dim, index, src)是一对作用相反的方法
先来看torch.gather, 核心操作其实就是这样:
out[i][j][k] = input[index[i][j][k]] [j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
是对于out指定位置上的值,去寻找input里面对应的索引位置,根据是index
官方文档给的例子是:
>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
1 1
4 3
[torch.FloatTensor of size 2x2]
具体过程就是这里的input = [[1,2],[3,4]], index = [[0,0],[1,0]], dim = 1, 则
out[0][0] = input[0][ index[0][0] ] = input[0][0] = 1
out[0][1] = input[0][ index[0][1] ] = input[0][0] = 1
out[1][0] = input[1][ index[1][0] ] = input[1][1] = 4
out[1][1] = input[1][ index[1][1] ] = input[1][0] = 3
torch.scatter_(dim, index, src)
核心操作:
self[ index[i][j][k] ][ j ][ k ] = src[i][j][k] # if dim == 0
self[ i ][ index[i][j][k] ][ k ] = src[i][j][k] # if dim == 1
self[ i ][ j ][ index[i][j][k] ] = src[i][j][k] # if dim == 2
这个就是对于src(或者说input)指定位置上的值,去分配给output对应索引位置,根据是index,所以其实把src放在左边更容易理解,官方给的例子如下:
x = torch.rand(2, 5)
>>> x
0.4319 0.6500 0.4080 0.8760 0.2355
0.2609 0.4711 0.8486 0.8573 0.1029
[torch.FloatTensor of size 2x5]
>>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
0.4319 0.4711 0.8486 0.8760 0.2355
0.0000 0.6500 0.0000 0.8573 0.0000
0.2609 0.0000 0.4080 0.0000 0.1029
[torch.FloatTensor of size 3x5]
此例中,src就是x,index就是[[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]], dim=0
我们把src写在左边,把self写在右边,这样好理解一些,
但要注意是把src的值赋给self,所以用箭头指过去:
0.4319 = Src[0][0] ----->self[ index[0][0] ][0] ----> self[0][0]
0.6500 = Src[0][1] ----->self[ index[0][1] ][1] ----> self[1][1]
0.4080 = Src[0][2] ----->self[ index[0][2] ][2] ----> self[2][2]
0.8760 = Src[0][3] ----->self[ index[0][3] ][3] ----> self[0][3]
0.2355 = Src[0][4] ----->self[ index[0][4] ][4] ----> self[0][4]
0.2609 = Src[1][0] ----->self[ index[1][0] ][0] ----> self[2][0]
0.4711 = Src[1][1] ----->self[ index[1][1] ][1] ----> self[0][1]
0.8486 = Src[1][2] ----->self[ index[1][2] ][2] ----> self[0][2]
0.8573 = Src[1][3] ----->self[ index[1][3] ][3] ----> self[1][3]
0.1029 = Src[1][4] ----->self[ index[1][4] ][4] ----> self[2][4]
则我们把src也就是 x的每个值都成功的分配了出去,然后我们再把self对应位置填好,
剩下的未得到分配的位置,就填0补充。