Pytorch踩坑:CrossEntropyLoss不支持one-hot label 报错:RuntimeError: multi-target not supported

报错:

RuntimeError: multi-target not supported at /pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:15

 

原因:

使用nn.CrossEntropyLoss时,label必须是[0, #classes] 区间的一个数字 ,而不可以是one-hot encoded 目标向量

当你的label如: [1, 0, 0], PyTorch 认为你想给数据赋予多个标签,这是不支持的。

 

解决方法:

把 one-hot-encoded targets换成整数,例如:

[1, 0, 0] --> 0

[0, 1, 0] --> 1

你可能感兴趣的:(Pytorch神经网络,python,pytorch)