Python 多分类求MAP ValueError: multiclass format is not supported

问题描述

训练多分类模型时,为了观察训练的效果,使用average_precision_score()函数求准确度MAP时报错:

average_precision_score(signs.flatten(),results.flatten())

raise ValueError(“{0} format is not supported”.format(y_type))
ValueError: multiclass format is not supported

原因分析:

之前的二分类任务使用该函数没有问题,想到这次的任务是多分类,该函数应该只可以处理二分类问题,也可以把多分类转变成二分类问题,输出的标签正确与否设置为0和1。

但是感觉那样比较麻烦,不如自己写代码求准确度:

解决方案:

思路:求网络输出最大值的下标,即预测的标签
然后判断预测标签和真实标签,相同则统计+1,
最后根据统计值/训练样本数量 求得准确度

optimizer.zero_grad()
outputs = model(inputs)
ret, prediction = torch.max(outputs,1)
right_num += np.sum(prediction.cpu().numpy() == labels.cpu().numpy())
loss_contrastive = criterion(outputs, labels)
#反向梯度
loss_contrastive.backward()
#更新权重
optimizer.step()
print("Epoch number: {} , Current loss: {:.4f}\n".format(epoch, loss_contrastive.item()))
print('Test Accuracy =',right_num/data_size)

OK~!
Python 多分类求MAP ValueError: multiclass format is not supported_第1张图片

你可能感兴趣的:(python,深度学习)