torch.scatter(input, dim, index, src)
src
. When empty, the operation returns self
unchanged.'add'
or 'multiply'
.将src中的数据根据index中的索引按照dim的方向填入到input中
Writes all values from the tensor
src
intoself
at the indices specified in theindex
tensor. For each value insrc
, its output index is specified by its index insrc
fordimension != dim
and by the corresponding value inindex
fordimension = dim
.
看了上述的官方文档还是不理解,我们继续看官方的例子,这里官方只给了三维,我在这里又加入了二维,在这之前有一个规定
index.size(d) <= src.size(d)
index.size(d) <= self.size(d)
self[index[i][j]][j] = src[i][j] # if dim == 0
self[i][index[i][j]] = src[i][j] # if dim == 1
先上代码
torch.manual_seed(0)
x = torch.arange(0, 12).reshape(2, 6)
x= x.type(torch.float32)
print(x)
'''
tensor([[ 0., 1., 2., 3., 4., 5.],
[ 6., 7., 8., 9., 10., 11.]])
'''
index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
print(index)
'''
tensor([[0, 1, 2, 0, 0],
[2, 0, 0, 1, 2]])
'''
y = torch.zeros(3, 6)
y = torch.scatter(y, 0, index, x)
print(y)
'''
tensor([[ 0., 7., 8., 3., 4., 0.],
[ 0., 1., 0., 9., 0., 0.],
[ 6., 0., 2., 0., 10., 0.]])
'''
'''in-place operation'''
yy = torch.zeros(3, 6)
yy.scatter_(0, index, x)
print(yy)
'''
tensor([[ 0., 7., 8., 3., 4., 0.],
[ 0., 1., 0., 9., 0., 0.],
[ 6., 0., 2., 0., 10., 0.]])
'''
下面将上述的执行过程绘制出来
首先下面是我们的初始化,我们初始化了
src
,index
,input
并且设置了dim=0
填充公式为self [ index[i][j] ][j] = src[i][j]
因为dim=0
,所以需要填充的input的行的索引就由index数值也就是index[i][j]来确定,需要填充的input的列的索引对应于index的列,往self里面填充的具体数值由index对应的src来确定
看下面例子序号3:
行=index[i][j]=index[0][2]=2
列=index列=j=2
self[行][列]=self[2][2]=src[i][j]=2
下面我们继续上述步骤
可发现当我们进行到第六步的时候,index[0][5]
并不存在,所以直接跳过就可以了
在这一步我们将input
填充完毕
如图所示,这里我们取图中的序号11进行验证
序号11:
行=index[i][j]=index[1][4]=2
列=index列=j=4
self[行][列]=self[2][4]=src[i][j]=10
为什么在第二步我们遇到的问题吗:当我们进行到序号6的时候,index[0][5]
并不存在,我们选择了跳过
可以跳过而没有报错呢,因为最初的文档对src, index, self
的维度有过定义
index.size(d) <= src.size(d)
index.size(d) <= self.size(d)