pytorch scatter_

scatter() 和 scatter_() 的作用是一样的,只不过 scatter() 不会直接修改原来的 Tensor,而 scatter_() 会

PyTorch 中,一般函数加下划线代表直接在原来的 Tensor 上修改

scatter(dim, index, src) 的参数有 3 个

dim:沿着哪个维度进行索引
index:用来 scatter 的元素索引
src:用来 scatter 的源元素,可以是一个标量或一个张量
这个 scatter 可以理解成放置元素或者修改元素

简单说就是通过一个张量 src 来修改另一个张量,哪个元素需要修改、用 src 中的哪个元素来修改由 dim 和 index 决定。

import torch

logits = torch.rand([2,4])
print(logits)
'''
tensor([[0.0966, 0.4437, 0.7198, 0.2911],
        [0.6818, 0.9492, 0.9412, 0.7439]])
'''
logits_1 = torch.max(logits, 1)
print(logits_1)
'''
(tensor([0.7198, 0.9492]), tensor([2, 1]))
'''

logits_2 = torch.max(logits, 1)[1]
print(logits_2)
'''
tensor([2, 1])
'''
logits_3 = logits_2.view(-1,1)
print(logits_3)
'''
tensor([[2],
        [1]])
'''
one_hots = torch.zeros(logits.size())

one_hots.scatter_(1, logits_3, 1)
print(one_hots)

'''
tensor([[0., 0., 1., 0.],
        [0., 1., 0., 0.]])
'''

你可能感兴趣的:(pytorch)