Notes on nn.functional.one_hot(x, num_classes)
- Purpose: converts each element of x into a 0/1 one-hot label vector
- Return shape: (*x.shape, num_classes), i.e. one extra dimension of size num_classes is appended to x's shape (see the minimal sketch below)
- num_classes is the target number of classes; it must be >= the actual number of classes (every element value of x must be less than num_classes)
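A minimal sketch of the shape rule, using a hypothetical 1-D input (the names x and y below are illustrative, not part of the example that follows):
import torch

x = torch.tensor([0, 2, 1])                          # one_hot expects integer (int64) values in [0, num_classes)
y = torch.nn.functional.one_hot(x, num_classes=3)    # shape: (*x.shape, num_classes) = (3, 3)
>>>
tensor([[1, 0, 0],
        [0, 0, 1],
        [0, 1, 0]])
- With the default num_classes=-1, PyTorch infers the class count as max(x) + 1.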
a = torch.randint(0, 24, (3, 4, 5))   # random integers in [0, 24), shape (3, 4, 5)
>>>
tensor([[[21, 2, 0, 5, 19],
         [10, 6, 10, 1, 5],
         [10, 7, 11, 0, 2],
         [17, 15, 16, 1, 6]],
        [[10, 2, 3, 8, 12],
         [10, 5, 18, 13, 6],
         [ 3, 14, 10, 16, 6],
         [15, 5, 4, 6, 3]],
        [[21, 14, 11, 2, 11],
         [ 3, 17, 0, 16, 5],
         [16, 15, 8, 3, 10],
         [21, 13, 15, 10, 14]]])
b = torch.argmin(a, dim=2)   # index of the minimum along the last dim, so values lie in [0, 5)
>>>
tensor([[2, 3, 3, 3],
        [1, 1, 0, 4],
        [3, 2, 3, 3]]) torch.Size([3, 4])
c = torch.nn.functional.one_hot(b, 5)
- Note: num_classes = 5 must be at least max(b) + 1; b holds the indices 0, 1, 2, 3, 4, so 5 is the smallest valid value
>>>
tensor([[[0, 0, 1, 0, 0],    # 2
         [0, 0, 0, 1, 0],    # 3
         [0, 0, 0, 1, 0],    # 3
         [0, 0, 0, 1, 0]],   # 3
        [[0, 1, 0, 0, 0],    # 1
         [0, 1, 0, 0, 0],    # 1
         [1, 0, 0, 0, 0],    # 0
         [0, 0, 0, 0, 1]],   # 4
        [[0, 0, 0, 1, 0],    # 3
         [0, 0, 1, 0, 0],    # 2
         [0, 0, 0, 1, 0],    # 3
         [0, 0, 0, 1, 0]]])  # 3   torch.Size([3, 4, 5])
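A side note on the num_classes requirement: if num_classes is smaller than max(b) + 1, one_hot raises an error rather than truncating. A small sketch (the exact error text may vary between PyTorch versions):
try:
    torch.nn.functional.one_hot(b, 4)   # b contains the index 4, so 4 classes are too few
except RuntimeError as e:
    print(e)                            # e.g. "Class values must be smaller than num_classes."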
- For example, num_classes = 6 also works, but the extra class is redundant: its column is always 0, as the output below shows
c = torch.nn.functional.one_hot(b, 6)
>>>
tensor([[[0, 0, 1, 0, 0, 0],
         [0, 0, 0, 1, 0, 0],
         [0, 0, 0, 1, 0, 0],
         [0, 0, 0, 1, 0, 0]],
        [[0, 1, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0]],
        [[0, 0, 0, 1, 0, 0],
         [0, 0, 1, 0, 0, 0],
         [0, 0, 0, 1, 0, 0],
         [0, 0, 0, 1, 0, 0]]]) torch.Size([3, 4, 6])
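To confirm that the extra class column changes nothing essential, argmax along the last dimension recovers b from either encoding (a small sketch reusing b from above):
print(torch.equal(torch.nn.functional.one_hot(b, 5).argmax(dim=-1), b))   # True
print(torch.equal(torch.nn.functional.one_hot(b, 6).argmax(dim=-1), b))   # True; the extra class column is never set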