import torch ''' A.scatter_(dim, index, B) # 基本用法, tensor A 被就地scatter到 tensor B 源tensor的每个元素,都按照 index 被scatter(可以理解为填充)到目标tensor中。 B 为源tensor,A为目标tensor。 dim 和 index:两个参数是配套的; index和源tensor维度一致(可以为空,代表不改变目标tensor),对于n-D tensor,dim可以为0~N-1。 index为几,就把对应位置的元素放入目标tensor的第几行; reduce参数: 默认是None,直接覆盖 multiply: src元素 * target元素 add:src元素 + target元素 对于全0矩阵,None和add效果一致;对于全1矩阵,None和multiply效果一致。 ''' a = torch.randn(2, 3) # 源tensor print(a) b = torch.zeros(2, 3).scatter_(dim=1, index=torch.tensor([[1, 2], [0, 1]]), src=a) print(b) ''' 上例结果: tensor([[-0.5172, 0.0915, -1.9869], [-0.1619, 1.3641, 0.1983]]) tensor([[ 0.0000, -0.5172, 0.0915], [-0.1619, 1.3641, 0.0000]]) ''' c = torch.zeros(2, 3).scatter_(dim=0, index=torch.tensor([[1, 0], [0, 1]]), src=a) print(c) ''' 上例结果: a: tensor([[ 0.2210, -1.2891, 1.1144], [-0.3524, 0.1736, 2.0364]]) c: tensor([[-0.3524, -1.2891, 0.0000], [ 0.2210, 0.1736, 0.0000]]) ''' d = torch.ones(2, 3).scatter_(dim=0, index=torch.tensor([[1, 0], [0, 1]]), src=a, reduce="multiply") # print(d) ''' tensor([[-8.7126e-01, 1.3744e+00, -5.1777e-04], [-1.6414e+00, 1.1157e+00, -1.9982e+00]]) tensor([[-1.6414, 1.3744, 1.0000], [-0.8713, 1.1157, 1.0000]]) ''' e = torch.zeros(2, 3).scatter_(dim=0, index=torch.tensor([[1, 0], [0, 1]]), src=a, reduce="add") print(e) ''' tensor([[-0.7597, 1.3491, -0.2875], [ 1.5010, -1.6951, 2.6675]]) tensor([[ 1.5010, 1.3491, 0.0000], [-0.7597, -1.6951, 0.0000]]) '''
参考:
https://zhuanlan.zhihu.com/p/339043454