output = torch.Tensor.scatter_(dim, index, src)

dim = 0, 按照数值方向操作;

dim = 1, 按照水平方向操作;

a = torch.rand(2, 5)
print(a)
b = torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), a)
print(b) 

output = torch.Tensor.scatter_(dimindexsrc)

a = torch.rand(2,5)
print(a)
tensor([[0.0548, 0.1293, 0.1842, 0.6538, 0.7267],
        [0.6978, 0.7296, 0.7779, 0.7206, 0.0884]])
b = torch.zeros(3,5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), a)
print()
print(b)
tensor([[0.0548, 0.7296, 0.7779, 0.6538, 0.7267],
        [0.0000, 0.1293, 0.0000, 0.7206, 0.0000],
        [0.6978, 0.0000, 0.1842, 0.0000, 0.0884]])

 有代码可知,scatter_函数中dim=0,即按照竖直方向从a中取数据,然后按照竖直方向将从a中取出的数据放入对应列的index行中。例如,取a的第一列第一行元素0.0548放入b中第一列中第index=0行中,即b[0,0]=0.0548;取a的第一列第二行元素0.6978放入b中第一列中第index=2行中,即b[2,0]=0.6978。其他a中的元素放入方式,和上述方法相同。

参考文献:

[1]https://zhuanlan.zhihu.com/p/59346637

你可能感兴趣的:(pytorch)