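The utils column skips conceptual detail and focuses on practical use. If you have questions, feel free to leave a comment. If this helps you, give it a like; there is no followers-only content or anything of that sort, so a like is all I ask.
Confusion matrix
To compute the confusion matrix, import confusion_matrix from sklearn.metrics and pass it two arguments, y_true and y_pred; it returns the confusion matrix directly: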
from sklearn.metrics import confusion_matrix
confMatrix = confusion_matrix(label, pre)
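As a small illustration of what the call returns, here is a sketch with made-up labels for a 3-class problem (the values of y_true and y_pred are assumptions chosen for the example, not data from this post). In sklearn's convention, rows are the true classes and columns are the predicted classes.

```python
from sklearn.metrics import confusion_matrix

# Made-up labels for a 3-class problem (illustration only)
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]

# cm[i][j] counts the samples whose true class is i and predicted class is j
cm = confusion_matrix(y_true, y_pred)
print(cm)
# [[1 1 0]
#  [0 2 0]
#  [1 0 1]]
```

The diagonal holds the correctly classified samples of each class; everything off the diagonal is a misclassification.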
This is the confusion matrix from an arbitrary dataset, with pretrained weights loaded and one epoch run:
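Computing accuracy, F1 and recall
We compute these directly from the confusion matrix, which conveniently gives the following counts for each class: true positives (TP), false positives (FP), false negatives (FN) and true negatives (TN).
Accuracy: (TP + TN) / (TP + TN + FP + FN)
Precision: TP / (TP + FP)
Recall: TP / (TP + FN)
F1: 2 * Precision * Recall / (Precision + Recall)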
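To make the four metrics concrete, the following is a minimal sketch that reads the per-class counts off a confusion matrix with NumPy. The matrix values are hypothetical, and the sketch assumes every class has at least one true and one predicted sample, so no denominator is zero. For the multi-class case, accuracy is taken as the number of correct predictions over all samples.

```python
import numpy as np

# Hypothetical 3-class confusion matrix (rows: true class, columns: predicted class)
cm = np.array([[5, 1, 0],
               [2, 6, 2],
               [0, 1, 8]])

tp = np.diag(cm)               # samples of each class predicted correctly
fp = cm.sum(axis=0) - tp       # predicted as the class but actually another class
fn = cm.sum(axis=1) - tp       # belong to the class but predicted as another class
tn = cm.sum() - tp - fp - fn   # everything else

accuracy = tp.sum() / cm.sum()           # overall accuracy: 19 / 25 = 0.76
precision = tp / (tp + fp)               # per class: 5/7, 6/8, 8/10
recall = tp / (tp + fn)                  # per class: 5/6, 6/10, 8/9
f1 = 2 * precision * recall / (precision + recall)

print(accuracy, precision, recall, f1)
```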
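The theory is not covered here; straight to the implementation. The function returns the overall precision, the overall recall, the overall F1_score, a table of per-class results, a geometric mean, and the confusion matrix. Adjust the code to your needs, for example if you only need the recall of a single class:
Note: the table of precision, recall and F1 is built with prettytable, so import it first: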
import prettytable
import numpy as np


def calculate_prediction_recall(label, pre, classes=None):
    """
    Compute precision and recall from the predictions and their ground-truth labels.
    :param label: ground-truth labels
    :param pre: corresponding predictions
    :param classes: class names (if None, numeric class indices are used instead)
    :return: overall precision, overall recall, overall F1, per-class table,
             geometric mean of per-class precision, confusion matrix
    """
    confMatrix = confusion_matrix(label, pre)
    print(confMatrix)
    if classes is None:  # fall back to numeric class indices
        classes = list(range(len(confMatrix)))
    total_prediction = 0
    total_recall = 0
    result_table = prettytable.PrettyTable()
    class_multi = 1
    result_table.field_names = ['Type', 'Precision', 'Recall', 'F1_Score']
    for i in range(len(confMatrix)):
        label_total_sum_col = confMatrix.sum(axis=0)[i]  # samples predicted as class i
        label_total_sum_row = confMatrix.sum(axis=1)[i]  # samples whose true class is i
        if label_total_sum_col:  # avoid division by zero
            prediction = confMatrix[i][i] / label_total_sum_col
        else:
            prediction = 0
        if label_total_sum_row:
            recall = confMatrix[i][i] / label_total_sum_row
        else:
            recall = 0
        if (prediction + recall) != 0:
            F1_score = prediction * recall * 2 / (prediction + recall)
        else:
            F1_score = 0
        result_table.add_row([classes[i], np.round(prediction, 3), np.round(recall, 3),
                              np.round(F1_score, 3)])
        total_prediction += prediction
        total_recall += recall
        class_multi *= prediction
    # Macro averages: unweighted mean over classes
    total_prediction = total_prediction / len(confMatrix)
    total_recall = total_recall / len(confMatrix)
    if (total_prediction + total_recall) != 0:  # avoid division by zero
        total_F1_score = total_prediction * total_recall * 2 / (total_prediction + total_recall)
    else:
        total_F1_score = 0
    geometric_mean = pow(class_multi, 1 / len(confMatrix))
    return total_prediction, total_recall, total_F1_score, result_table, geometric_mean, confMatrix
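One way to sanity-check the overall values is to compare them with sklearn's own macro averages. The sketch below uses made-up labels (an assumption for illustration) and recomputes the macro precision and recall the same way calculate_prediction_recall does. Note one difference: the function above takes the F1 of the macro precision and macro recall, whereas sklearn's f1_score with average='macro' averages the per-class F1 values, so the two F1 numbers generally do not match.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score

# Made-up labels for a 3-class problem (illustration only)
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]

cm = confusion_matrix(y_true, y_pred)
per_class_precision = np.diag(cm) / cm.sum(axis=0)  # diagonal over column sums
per_class_recall = np.diag(cm) / cm.sum(axis=1)     # diagonal over row sums

# Unweighted means over classes, as in calculate_prediction_recall
macro_p = per_class_precision.mean()
macro_r = per_class_recall.mean()
f1_of_macros = 2 * macro_p * macro_r / (macro_p + macro_r)

# sklearn's macro averages for comparison
sk_p = precision_score(y_true, y_pred, average='macro')
sk_r = recall_score(y_true, y_pred, average='macro')
sk_f1 = f1_score(y_true, y_pred, average='macro')  # mean of the per-class F1 values

print(macro_p, sk_p)         # these agree
print(macro_r, sk_r)         # these agree
print(f1_of_macros, sk_f1)   # these differ: two different definitions of macro F1
```

Which F1 definition to report is a design choice; it is worth stating which one is used when comparing against numbers from other tools.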
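An example of the table output (note: the table shows the precision, recall and F1 of each class, while the function returns the overall precision, recall and F1; modify the code to suit your needs):
Plotting the confusion matrix needs little explanation, so straight to the code: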
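import itertools
import os

import matplotlib.pyplot as plt
import numpy as np


def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """
    Visualize a confusion matrix, given the matrix and the class names (or numeric indices).
    :param cm: confusion matrix
    :param classes: class names
    :param normalize: if True, divide each row by its sum so cells show per-class fractions
    :param title: plot title
    :param cmap: matplotlib colormap
    :return: None
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    # Write each cell's value, in white on dark cells and black on light ones
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()  # called after the labels are set so they are not clipped
    os.makedirs('runs/picture', exist_ok=True)  # savefig does not create missing directories
    plt.savefig('runs/picture/confMatrix.jpg')
    plt.show()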
Result:
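The normalize=True branch of plot_confusion_matrix divides each row by its sum, so each diagonal cell becomes the recall of that class. If a class appears only in the predictions, its row sums to zero and the plain division produces NaN. The following is a minimal sketch of a guarded normalization; the matrix is hypothetical and the zero-row handling is a suggestion, not part of the original function.

```python
import numpy as np

# Hypothetical confusion matrix whose last class has no true samples
cm = np.array([[5, 1, 0],
               [2, 6, 2],
               [0, 0, 0]])

row_sums = cm.sum(axis=1, keepdims=True)
# `where` skips rows whose sum is 0, leaving the pre-filled zeros from `out`
cm_norm = np.divide(cm, row_sums,
                    out=np.zeros(cm.shape, dtype=float),
                    where=row_sums != 0)

print(cm_norm)
print(np.diag(cm_norm))  # per-class recall; 0 for the class with no true samples
```

The resulting cm_norm could be passed to plot_confusion_matrix with normalize=False and the format string adjusted, or the same guard could be moved into the function.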