更新tensor指定位置的值可以使用tensor. scatter_add_
(dim, index, src) //把src中的值加到指定的tensor上
tensor.scatter_(dim, index,src) //直接用src的值去替换原来tensor中的值。
官网上给出了一个例子:
>>> x = torch.rand(2, 5)
>>> x
tensor([[0.7404, 0.0427, 0.6480, 0.3806, 0.8328],
[0.7953, 0.2009, 0.9154, 0.6782, 0.9620]])
>>> torch.ones(3, 5).scatter_add_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[1.7404, 1.2009, 1.9154, 1.3806, 1.8328],
[1.0000, 1.0427, 1.0000, 1.6782, 1.0000],
[1.7953, 1.0000, 1.6480, 1.0000, 1.9620]])
实际上最有用的是这个说明:
For a 3-D tensor,
self
is updated as:self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0 self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2
我的情况是想用一个tensor去更新矩阵中对应的行。所以我使用的方法如下:
s = torch.tensor([[0.0,0.0,0.0],
[1.0,1.0,1.0],
[2.0,2.0,2.0],
[3.0,3.0,3.0]])
x = torch.tensor([[0.5,0.6,0.7],
[0.8,0.9,0.4]])
>>> s.scatter_add_(0, torch.tensor([[2,2,2],[3,3,3]]),x)
tensor([[0.0000, 0.0000, 0.0000],
[1.0000, 1.0000, 1.0000],
[2.5000, 2.6000, 2.7000],
[4.3000, 3.9000, 3.4000]])
对于3维的tensor,其直接更新tensor的值的方法如下:
>>> t = torch.rand(3,2,4)
>>> s = torch.randn(3,3,4)
>>> index = torch.tensor([[[0,0,0,0],[1,1,1,1]],[[1,1,1,1],[0,0,0,0]],[[0,0,0,0],[2,2,2,2]]])
#我这里想更新的是第一个tensor矩阵的第0,第1行,第二个tensor矩阵的第1,第0行,第三个tensor矩阵的第0和第2行。
>>> t
tensor([[[0.4599, 0.3912, 0.7227, 0.9956],
[0.7050, 0.9624, 0.3776, 0.6071]],
[[0.9987, 0.5976, 0.5211, 0.9859],
[0.8634, 0.9152, 0.4526, 0.8258]],
[[0.4119, 0.8666, 0.6350, 0.9806],
[0.7178, 0.2165, 0.6278, 0.7487]]])
>>> s
tensor([[[ 0.1415, 0.4588, 0.1907, 0.7518],
[ 0.4902, 0.1167, 0.5882, 0.2648],
[ 0.2600, 0.1659, 1.9888, -1.5446]],
[[ 0.5899, 0.9541, 0.4002, 0.5883],
[ 0.1144, 0.3988, 0.2115, 0.5314],
[ 1.2400, 1.1940, -2.3084, -1.0346]],
[[ 0.6628, 0.7274, 0.4928, 0.3759],
[ 0.1345, -1.0521, -0.0848, 1.9390],
[ 0.1522, 0.1753, 0.3924, 0.3414]]])
>>> s.scatter_(1,index,t)
tensor([[[ 0.4599, 0.3912, 0.7227, 0.9956],
[ 0.7050, 0.9624, 0.3776, 0.6071],
[ 0.2600, 0.1659, 1.9888, -1.5446]],
[[ 0.8634, 0.9152, 0.4526, 0.8258],
[ 0.9987, 0.5976, 0.5211, 0.9859],
[ 1.2400, 1.1940, -2.3084, -1.0346]],
[[ 0.4119, 0.8666, 0.6350, 0.9806],
[ 0.1345, -1.0521, -0.0848, 1.9390],
[ 0.7178, 0.2165, 0.6278, 0.7487]]])
>>>
参考文献:https://pytorch.org/docs/stable/tensors.html?highlight=scatter_#torch.Tensor.scatter_add_