Pytorch scatter_()用法

import torch

'''
A.scatter_(dim, index, B) # 基本用法, tensor A 被就地scatter到 tensor B
源tensor的每个元素,都按照 index 被scatter(可以理解为填充)到目标tensor中。
B 为源tensor,A为目标tensor。

dim 和 index:两个参数是配套的;
index和源tensor维度一致(可以为空,代表不改变目标tensor),对于n-D tensor,dim可以为0~N-1。
index为几,就把对应位置的元素放入目标tensor的第几行;

reduce参数: 
    默认是None,直接覆盖
    multiply: src元素 * target元素
    add:src元素 + target元素
    对于全0矩阵,None和add效果一致;对于全1矩阵,None和multiply效果一致。
'''
a = torch.randn(2, 3)  # 源tensor
print(a)
b = torch.zeros(2, 3).scatter_(dim=1, index=torch.tensor([[1, 2], [0, 1]]), src=a)
print(b)
'''
上例结果:
tensor([[-0.5172,  0.0915, -1.9869],
        [-0.1619,  1.3641,  0.1983]])
tensor([[ 0.0000, -0.5172,  0.0915],
        [-0.1619,  1.3641,  0.0000]])
'''
c = torch.zeros(2, 3).scatter_(dim=0, index=torch.tensor([[1, 0], [0, 1]]), src=a)
print(c)
'''
上例结果:
a:    tensor([[ 0.2210, -1.2891,  1.1144],
              [-0.3524,  0.1736,  2.0364]])
c:    tensor([[-0.3524, -1.2891,  0.0000],
              [ 0.2210,  0.1736,  0.0000]])
'''
d = torch.ones(2, 3).scatter_(dim=0, index=torch.tensor([[1, 0], [0, 1]]), src=a, reduce="multiply")
# print(d)
'''
tensor([[-8.7126e-01,  1.3744e+00, -5.1777e-04],
        [-1.6414e+00,  1.1157e+00, -1.9982e+00]])
tensor([[-1.6414,  1.3744,  1.0000],
        [-0.8713,  1.1157,  1.0000]])
'''
e = torch.zeros(2, 3).scatter_(dim=0, index=torch.tensor([[1, 0], [0, 1]]), src=a, reduce="add")
print(e)
'''
tensor([[-0.7597,  1.3491, -0.2875],
        [ 1.5010, -1.6951,  2.6675]])
tensor([[ 1.5010,  1.3491,  0.0000],
        [-0.7597, -1.6951,  0.0000]])
'''

参考:

https://zhuanlan.zhihu.com/p/339043454 

你可能感兴趣的:(机器学习,深度学习)