Pytorch 怎么样把labels转为one-hot(独热编码)的形式

直接上代码:

>>>v=torch.Tensor([[1],[2],[3]])
>>> v
tensor([[1.],
        [2.],
        [3.]])
>>> v.size(0)
3
>>> n=v.size(0)
>>> one_hot = torch.zeros(n,10).long()
>>> one_hot.scatter_(dim=1, index=v.long(), src=torch.ones(n, 10).long())
tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]])
>>> one_hot
tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]])

搞定,注意里面的数据类型。

你可能感兴趣的:(Pytorch 怎么样把labels转为one-hot(独热编码)的形式)