之前写过一篇关于在scikit-learn工具包中,可视化estimator分类模型分类结果的confusion matrix混淆矩阵可视化的方法,具体可以参考看这里,看这里。今天这篇介绍一下如何使用scikit-learn工具中提供的相关方法,可视化其他任意框架(比如深度学习框架)的分类模型预测结果的混淆矩阵。
下面先说一下几个关键步骤:
1、确定类别列表,类别列表和one-hot的编码顺序一致,这里使用cifar-10的类别列表作为演示的例子。
classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"
2、准备好样本的真实label,这里我手动构造一个1000个样本的label,每一类100个。
# 生成数据集的GT标签
gt_labels = np.zeros(1000).reshape(10, -1)
for i in range(10):
gt_labels[i] = i
gt_labels = gt_labels.reshape(1, -1).squeeze()
print("gt_labels.shape : {}".format(gt_labels.shape))
print("gt_labels : {}".format(gt_labels[::5]))
3、准备好样本的预测label,这里我也手动构造这1000个样本的预测label,构造时才用了一点规则,构造出来的预测结果保证从第0类到第9类的预测准确率是逐渐降低的。
# 生成数据集的预测标签
pred_labels = np.zeros(1000).reshape(10, -1)
for i in range(10):
# 标签生成规则:对于真值类别编号为i的数据,生成的预测类别编号为[0, i-1]之间的随机值
# 这样生成的预测准确率从0到9逐渐递减
pred_labels[i] = np.random.randint(0, i + 1, 100)
pred_labels = pred_labels.reshape(1, -1).squeeze()
print("pred_labels.shape : {}".format(pred_labels.shape))
print("pred_labels : {}".format(pred_labels[::5]))
4、计算真是label和预测label的混淆矩阵,直接调用scikit-learn中的confusion_matrix方法
# 使用sklearn工具中confusion_matrix方法计算混淆矩阵
confusion_mat = confusion_matrix(gt_labels, pred_labels)
print("confusion_mat.shape : {}".format(confusion_mat.shape))
print("confusion_mat : {}".format(confusion_mat))
5、混淆矩阵可视化,在scikit-learn工具中有一个plot_confusion_matrix方法可以可视化sklearn训练的模型estimator的混淆矩阵,具体参数如下:
但是,现在的问题是我们使用的是别的框架训练的模型,也就没有这个estimator参数可以供sklearn使用,怎么办?
我们看一下plot_confusion_matrix函数的代码可以发现,他其实内部调用了以下方法:
那么,我们也仿照这个调用方式来写一下试试,代码如下:
# 使用sklearn工具包中的ConfusionMatrixDisplay可视化混淆矩阵,参考plot_confusion_matrix
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_mat, display_labels=classes)
disp.plot(
include_values=True, # 混淆矩阵每个单元格上显示具体数值
cmap="viridis", # 不清楚啥意思,没研究,使用的sklearn中的默认值
ax=None, # 同上
xticks_rotation="horizontal", # 同上
values_format="d" # 显示的数值格式
)
6、将以上代码整合一下,输入数据的真实label和预测label,就可以可视化混淆矩阵了,并且不仅局限于评估scikit-learn的estimator,可以适用于所有框架的输出结果,完整代码如下:
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from matplotlib import pyplot as plt
classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
# 生成数据集的GT标签
gt_labels = np.zeros(1000).reshape(10, -1)
for i in range(10):
gt_labels[i] = i
gt_labels = gt_labels.reshape(1, -1).squeeze()
print("gt_labels.shape : {}".format(gt_labels.shape))
print("gt_labels : {}".format(gt_labels[::5]))
# 生成数据集的预测标签
pred_labels = np.zeros(1000).reshape(10, -1)
for i in range(10):
# 标签生成规则:对于真值类别编号为i的数据,生成的预测类别编号为[0, i-1]之间的随机值
# 这样生成的预测准确率从0到9逐渐递减
pred_labels[i] = np.random.randint(0, i + 1, 100)
pred_labels = pred_labels.reshape(1, -1).squeeze()
print("pred_labels.shape : {}".format(pred_labels.shape))
print("pred_labels : {}".format(pred_labels[::5]))
# 使用sklearn工具中confusion_matrix方法计算混淆矩阵
confusion_mat = confusion_matrix(gt_labels, pred_labels)
print("confusion_mat.shape : {}".format(confusion_mat.shape))
print("confusion_mat : {}".format(confusion_mat))
# 使用sklearn工具包中的ConfusionMatrixDisplay可视化混淆矩阵,参考plot_confusion_matrix
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_mat, display_labels=classes)
disp.plot(
include_values=True, # 混淆矩阵每个单元格上显示具体数值
cmap="viridis", # 不清楚啥意思,没研究,使用的sklearn中的默认值
ax=None, # 同上
xticks_rotation="horizontal", # 同上
values_format="d" # 显示的数值格式
)
plt.show()
7、混淆矩阵的可视化结果
上图中的可视化结果符合我们在生成预测label标签时使用的规则,就是对于每个类别 i 的预测结果是0-i之间的随机值,这样的话,每个类别的预测误差只会出现在类别编号比它小的部分,也就是上图中展示的下三角矩阵。
在混淆矩阵中,横轴上的标签标示样本的预测label,纵轴上的标签标示样本的实际label。所以,对角线上的数字表示预测label和真是label一致的数量,也就是预测正确的数量。对于其他位置的数字就表示预测错误的,举个例子,比如第2行、第1列,也就是对应着(airplane, automobile)位置的数字51,表示有51个真实label为automobile的样本被预测为了airplane。
通过可视化的混淆矩阵,模型的误差,以及效果分类不好的类别,以及为什么不好,以及容易和哪个类之间出现误识别就一目了然了。
参考:https://blog.csdn.net/cxx654/article/details/107296343