使用scikit-learn中的metrics.plot_confusion_matrix混淆矩阵函数分析分类器的误差来源

在前面的文章中介绍了使用scikit-learn绘制ROC曲线和使用scikit-learn绘制误差学习曲线,通过绘制ROC曲线和误差学习曲线可以让我们知道我们的模型现在整体上做的有多好,可以判断模型的状态是过拟合还是欠拟合,从而确定后续的优化方向。但是绘制学习曲线的方法只能让我们从整体上了解模型的性能,并不能具体展示具体的误差来源。在吴恩达老师的视频中,多次强调误差分析的重要性,就是针对模型处理出错的样本进行重点研究分析,然后选择可能的优化方向。今天这篇短文就来讲一下针对分类问题,如何使用scikit-learn工具进行简单的误差分析。

本示例使用SVM分类器,对手写数字进行分类。

1、加载数据集,并划分训练集和验证集

%matplotlib inline
from sklearn import datasets
from sklearn.svm import SVC
import warnings
warnings.filterwarnings("ignore")

from sklearn.model_selection import train_test_split

digits = datasets.load_digits()
X, y = digits.data, digits.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=100)

X_train.shape, X_test.shape

((1257, 64), (540, 64))

2、在训练集上训练一个SVM分类器

svc_clf = SVC(kernel="linear", C=1.0)
svc_clf.fit(X_train, y_train)

3、绘制分类器在测试集上的结果混淆矩阵

from sklearn import metrics

metrics.plot_confusion_matrix(svc_clf, X_test, y_test)

使用scikit-learn中的metrics.plot_confusion_matrix混淆矩阵函数分析分类器的误差来源_第1张图片

混淆矩阵的横坐标表示模型的预测结果,纵坐标表示真实结果。淆矩阵的每行元素表示类别i(行编号)被预测为类别j(列编号)的数量。所以在上图中,除对角线以外的非0数值都是被误分类的样本数量。比如第8行的第6列的数值为2,表示有2个数字8倍误分类程了数字2。通过混淆矩阵可以方便的看出模型在矩阵那些类别之间的误分类比较严重,从而有利于确定下一步的优化方向。

在上图中,出对角线以外的非0元素总和是:1 + 1 + 1 + 1 + 1 + 1 + 2 = 8,那么该分类模型在测试集上的准确率应该是(540 - 8)/540 = 0.985。使用metrics.accuracy函数计算一下看看是否一致:

metrics.accuracy_score(y_test, svc_clf.predict(X_test))

0.9851851851851852。结果与预期一致。

参考:scikit-learn中文翻译

你可能感兴趣的:(机器学习,python编程)