Pytorch 多分类结果测试

在模型训练过程中需要对当前的效果进行验证,或者训练结束后需要在测试集上对模型进行测试。比如多分类问题,网络的前向传播的结果是一个概率值Tensor,如果是一个10分类问题,并且batch=4,结果是一个4*10的Tensor,Tensor的每一行表示某张图片分别在10分类下的预测概率值。

Pytorch中的argmax()函数可以返回Tensor中每一行最大值的索引,torch.eq()函数可以比较两个Tensor对应位置处的值是否相等,返回一个Tensor的结果,0表示不相等,1表示相等。

我们可以使用argmax()与torch.eq()进行多分类问题准确率的计算:

Pytorch 多分类结果测试_第1张图片

下面的代码是在训练完成后,在测试集上进行测试准确率的代码片段:

Pytorch 多分类结果测试_第2张图片

相信代码可参考https://blog.csdn.net/weicao1990/article/details/98754647

你可能感兴趣的:(深度学习,pytorch)