Pytorch Error:RuntimeError: Assertion cur_target 0 cur_target n_classes failed

ERROR

使用pytorch的函数 torch.nn.CrossEntropyLoss()计算Loss时报错
或者
loss = criterion(output, target)
报错:

RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed

解决方法:

原因一:模型输出与分类数不一致

看模型的输出尺寸分类数差异是否明显,核查代码是否存在错误。
如果没有错误,只是映射维度不对,可以考虑在模型的最后一层加一层FC层,将输出尺寸映射到分类大小

原因二:标签的设置不是从0开始

如果模型的输出尺寸分类数大小相同,看一下标签的设定是否是从0开始的。
如果标签是从1开始设置的,重新设置标签。
在使用CrossEntropyLoss()这个函数进行验证时,标签必须从0开始设置,否则便会报错。

你可能感兴趣的:(python,pytorch)