更新tensor指定位置的值 pytorch scatter_add_和scatter_

更新tensor指定位置的值可以使用tensor. scatter_add_(dimindexsrc)  //把src中的值加到指定的tensor上

tensor.scatter_(dim, index,src) //直接用src的值去替换原来tensor中的值。

官网上给出了一个例子:

>>> x = torch.rand(2, 5)
>>> x
tensor([[0.7404, 0.0427, 0.6480, 0.3806, 0.8328],
        [0.7953, 0.2009, 0.9154, 0.6782, 0.9620]])
>>> torch.ones(3, 5).scatter_add_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[1.7404, 1.2009, 1.9154, 1.3806, 1.8328],
        [1.0000, 1.0427, 1.0000, 1.6782, 1.0000],
        [1.7953, 1.0000, 1.6480, 1.0000, 1.9620]])

实际上最有用的是这个说明:

For a 3-D tensor, self is updated as:

self[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2

我的情况是想用一个tensor去更新矩阵中对应的行。所以我使用的方法如下:

s = torch.tensor([[0.0,0.0,0.0],
                  [1.0,1.0,1.0],
                  [2.0,2.0,2.0],
                  [3.0,3.0,3.0]])

x = torch.tensor([[0.5,0.6,0.7],
                  [0.8,0.9,0.4]])

>>> s.scatter_add_(0, torch.tensor([[2,2,2],[3,3,3]]),x)
tensor([[0.0000, 0.0000, 0.0000],
        [1.0000, 1.0000, 1.0000],
        [2.5000, 2.6000, 2.7000],
        [4.3000, 3.9000, 3.4000]])

对于3维的tensor,其直接更新tensor的值的方法如下:

>>> t = torch.rand(3,2,4)
>>> s = torch.randn(3,3,4)
>>> index = torch.tensor([[[0,0,0,0],[1,1,1,1]],[[1,1,1,1],[0,0,0,0]],[[0,0,0,0],[2,2,2,2]]]) 
#我这里想更新的是第一个tensor矩阵的第0,第1行,第二个tensor矩阵的第1,第0行,第三个tensor矩阵的第0和第2行。
>>> t
tensor([[[0.4599, 0.3912, 0.7227, 0.9956],
         [0.7050, 0.9624, 0.3776, 0.6071]],

        [[0.9987, 0.5976, 0.5211, 0.9859],
         [0.8634, 0.9152, 0.4526, 0.8258]],

        [[0.4119, 0.8666, 0.6350, 0.9806],
         [0.7178, 0.2165, 0.6278, 0.7487]]])
>>> s
tensor([[[ 0.1415,  0.4588,  0.1907,  0.7518],
         [ 0.4902,  0.1167,  0.5882,  0.2648],
         [ 0.2600,  0.1659,  1.9888, -1.5446]],

        [[ 0.5899,  0.9541,  0.4002,  0.5883],
         [ 0.1144,  0.3988,  0.2115,  0.5314],
         [ 1.2400,  1.1940, -2.3084, -1.0346]],

        [[ 0.6628,  0.7274,  0.4928,  0.3759],
         [ 0.1345, -1.0521, -0.0848,  1.9390],
         [ 0.1522,  0.1753,  0.3924,  0.3414]]])
>>> s.scatter_(1,index,t) 
tensor([[[ 0.4599,  0.3912,  0.7227,  0.9956],
         [ 0.7050,  0.9624,  0.3776,  0.6071],
         [ 0.2600,  0.1659,  1.9888, -1.5446]],

        [[ 0.8634,  0.9152,  0.4526,  0.8258],
         [ 0.9987,  0.5976,  0.5211,  0.9859],
         [ 1.2400,  1.1940, -2.3084, -1.0346]],

        [[ 0.4119,  0.8666,  0.6350,  0.9806],
         [ 0.1345, -1.0521, -0.0848,  1.9390],
         [ 0.7178,  0.2165,  0.6278,  0.7487]]])
>>>

参考文献:https://pytorch.org/docs/stable/tensors.html?highlight=scatter_#torch.Tensor.scatter_add_

你可能感兴趣的:(pytorch基础)