# 转载：PyTorch one-hot 编码（来源：GXLiu 的 CSDN 博客）
# 方案一：使用 scatter_
# 将标签转换为 one-hot 编码
import torch

# Number of distinct classes; each label must lie in [0, num_class).
num_class = 5
label = torch.tensor([0, 2, 1, 4, 1, 3])

# scatter_ along dim 1 needs an int64 index with the same rank as the target,
# so the 1-D label vector is turned into an (N, 1) column first.
index_column = label.long().unsqueeze(1)

# Start from an all-zero float matrix of shape (N, num_class) and write a 1
# into the column named by each row's label (in place).
one_hot = torch.zeros(label.shape[0], num_class)
one_hot.scatter_(1, index_column, 1)
print(one_hot)

# Expected output:
# tensor([[1., 0., 0., 0., 0.],
#         [0., 0., 1., 0., 0.],
#         [0., 1., 0., 0., 0.],
#         [0., 0., 0., 0., 1.],
#         [0., 1., 0., 0., 0.],
#         [0., 0., 0., 1., 0.]])
# 方案二：使用 F.one_hot 自动实现
import torch
import torch.nn.functional as F

# Number of distinct classes; each label must lie in [0, num_class).
num_class = 5
label = torch.tensor([0, 2, 1, 4, 1, 3])

# F.one_hot builds the (N, num_classes) encoding in one call. The result is an
# integer (int64) tensor, which is why the printed values below have no
# trailing dots.
one_hot = F.one_hot(label, num_classes=num_class)
print(one_hot)

# Expected output:
# tensor([[1, 0, 0, 0, 0],
#         [0, 0, 1, 0, 0],
#         [0, 1, 0, 0, 0],
#         [0, 0, 0, 0, 1],
#         [0, 1, 0, 0, 0],
#         [0, 0, 0, 1, 0]])