多分类任务的混淆矩阵处理

多分类任务的混淆矩阵处理

在多分类任务中,不适合使用PR曲线和ROC曲线来进行指标评价,但我们仍可以通过混淆矩阵来进行处理。可以通过matplotlib的matshow()函数,直观地展示分类结果的好坏。

先使用cross_val_predict得出各个分类值的分数

 y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv= 3 )

再使用confusion_matrix()得出最终的混淆矩阵

conf_mx = confusion_matrix(y_train, y_train_pred)

然后使用 Matplotlib 的 matshow() 函数,将混淆矩阵以图像的方式呈现

plt.matshow(conf_mx, cmap=plt.cm.gray)

如下图所示,行代表了实际的类别,列代表了预测的结果,从图中可看出大致都在正对角线上,说明分类结果还不错。
多分类任务的混淆矩阵处理_第1张图片
但是我们应该关注仅包含误差数据的图像呈现,所以将混淆矩阵的每一个值除以相应类别的图片的总数目。这样子,你可以比较错误率,而不是绝对的错误数(这对大的类别不公平)

row_sums = conf_mx.sum(axis= 1 , keepdims= True )
norm_conf_mx = conf_mx / row_sums

然后用 0 来填充对角线(使正确的分类不可见),这样子就只保留了被错误分类的数据。

np.fill_diagonal(norm_conf_mx,  0 )
plt.matshow(norm_conf_mx, cmap=plt.cm.gray)

如下图所示,8,9列比较亮,说明有很多都被错误地分到了8,9类中去。相似的,第 8、9 行也相当亮,也就是说8,9类也经常被误以为是其他类别。
多分类任务的混淆矩阵处理_第2张图片
所以通过这个混淆矩阵图像,分析混淆矩阵通常可以给你提供深刻的见解去改善你的分类器。回顾这幅图,看样子你应该努力改善分类器在类别8 和类别 9 上的表现,和纠正 3/5 的混淆。

举例子,你可以尝试去收集更多的数据,或者你可以构造新的、有助于分类器的特征。举例子,写一个算法去数闭合的环(比如,数字 8 有两个环,数字 6 有一个, 5 没有)。又或者你可以预处理图片(比如,使用 Scikit-Learn,Pillow, OpenCV)去构造一个模式,比如闭合的环。

你可能感兴趣的:(深度学习)