直接上代码:
>>>v=torch.Tensor([[1],[2],[3]])
>>> v
tensor([[1.],
[2.],
[3.]])
>>> v.size(0)
3
>>> n=v.size(0)
>>> one_hot = torch.zeros(n,10).long()
>>> one_hot.scatter_(dim=1, index=v.long(), src=torch.ones(n, 10).long())
tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]])
>>> one_hot
tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]])
搞定,注意里面的数据类型。