凌云时刻 · 技术
导读:机器学习算法中有一个重要环节就是评判算法的好坏,我们在之间的笔记中讲过多种评价回归算法的评测标准,比如均方误差(MSE)、均方根误差(RMSE)、平均绝对误差(MAE)、 (R Squared)。但是在分类问题中我们一直使用分类准确度这一个指标,也就是预测对分类的样本数量除以总预测样本数量。但是这个方法存在很大的一个缺陷,所以这篇笔记主要介绍评价分类问题的方式方法。
作者 | 计缘
来源 | 凌云时刻(微信号:linuxpk)
精准率和召回率之间的平衡
在上一篇笔记中,我们了解了逻辑回归的决策边界,比如在二分类问题中,决策边界公式为:
当 大于0时,我们认为分类是1,当小于0时,我们认为分类为0。
如上图所示,黑色直线表示 ,橘黄色直线所在位置表示区分类别为1还是0的分界点,既大于0是蓝色点类型,小于0是红色点类型。那如果我们让 不等于0,而等于一个阀值threshold呢?
那上面的图就会是下面这样:
从上面的图看,threshold是大于0的,这样就相当于调整了区别分类的分界点位置。那么会影响到什么呢?
从上图可以看到,当threshold为0时,示例中的精准率是0.86,召回率是0.75。
当调整threshold大于0后,示例中的精准率是1,召回率是0.38。
当调整threshold小于0后,示例中的精准率是0.7,召回率是0.88。
从这三种情况可以看出,精准率和召回率是互相牵制的,精准率高了,召回率就低。召回率高,精准率就低。所以threshold就又是一个超参数,用来调节使精准率和召回率达到平衡。
通过程序验证精准率和召回率的平衡关系
我们还是使用手写数据的样本数据来验证:
|
我们如何设置threshold呢,其实Scikit Learn中的逻辑回归提供了一个获取评判分数的函数,也就是上图中黑色直线的Score值:
|
Scikit Learn中的confusion_matrix
、precision_score
、recall_score
函数都是基于threshold为0计算的,也就是判断decision_score
中的所有值,如果大于0就分类为1,如果小于0就分类为0。那我们现在将threshold调大一点,比如将5作为区分1和0的分界点,那么我们的预测值就可以这样求:
|
然后我们再来看看精准率和召回率:
|
再将threshold调小看看:
|
通过代码我们可以很明显的看到调节threshold后,精准率和召回率的变化。
PR曲线
通过上一小节我们知道精准率和召回率是相互牵制的,我也认识了一个新的超参数threshold,通过它能调节精准率和召回率。那么我们如何找到一个平衡点,使得精准率和召回率都在一个比较好的水平,换句话说也就是如何找到好的超参数threshold。
这一小节就介绍一个工具,帮助我们更好的找到这个超参数,这就是PR曲线(Precision-Recall曲线)。我们直接来看看Scikit Learn中提供的函数:
|
上图的横轴是threshold值,蓝色曲线是精准率,黄色曲线是召回率,他们相交点的threshold值,就是PR达到平衡的点。
|
上图中,横轴是精准率,纵轴是召回率。这个图反应了PR的总体趋势。通过这个PR曲线我们除了可以判断选择最优的threshold值,还可以判断不同模型的好坏程度。
比如上图中的模型A和模型B可以是通过不同的算法训练的出的模型,也可以是同一个算法,通过不同超参数组合训练出的模型。显然模型B要比模型A好,因为模型B无论是精准率还是召回率都要比模型A的高。
ROC曲线
这一小节我们来看一个新的指标,ROC曲线,既接收者操作特征曲线,是Receiver Operation Characteristic Curve缩写,最早出现在信号检测理论中,后来被广泛应用在不同领域。在机器学习中,ROC用来描述分类模型的TPR和FPR之间的关心,从而确定分类模型的好坏。
FPR和TPR
FPR和TPR同样是基于混淆矩阵而来的,FPR的公式为:
TPR的公式为:
可以看到TPR其实就是Recall指标,而FPR是和TPR相反的指标。下面我们使用Scikit Learn中封装的方法来看看手写数据的TPR、FPR和ROC曲线:
|
从ROC曲线图可以看出,随着FPR的增大,TPR也是随之增大的。我们通过观察这根曲线下的面积大小来判断分类模型的好坏程度,面积越大,说明分类模型越好。Scikit Learn中也提供了计算这个面积的函数:
|
ROC曲线和PR曲线有一个不同之处是,ROC曲线对极度有偏的数据是不敏感的。所以如果样本数据有极度有偏的情况时,通常还是主要使用PR曲线来判断模型的好坏,ROC曲线辅助判断。
多分类问题中的混淆矩阵
我们之前讲的混淆矩阵和精准、召回率都是在二分类问题的前提下。这篇笔记的最后来看看多分类问题中的混淆矩阵。我们同样使用手写数字数据,但这次不再对数据做极度有偏处理了:
|
Scikit Learn 的precision_score
方法有一个average
参数,默认值为binary
,既默认计算二分类问题。如果要计算多分类问题,需要将average
参数设置为micro
:
|
下面来看看这个手写数字十分类问题的混淆矩阵:
|
看多分类问题的混淆矩阵和二分类问题的混淆矩阵方法一样,同样行表示真值,列表示预测值。从上面的结果可看到,混淆矩阵的对角线数值最大,这个对角线就是真值和预测值相同的TP值。我们将这个多分类混淆矩阵通过Matplotlib的matshow
方法绘制出来,直观的看一下:
|
上面这个图可以很清晰的看到TP值,但是我们希望能从图上直观的分析问题,既这个模型预测错误的数据。下面我们将混淆矩阵做一下转换,求出错误矩阵,既FP值矩阵:
|
上图中,颜色约亮的格子表示预测错误的数量越多,比如左上角那个白色的格子就表示真值为3,但是有不少样本数据被预测成了8。左下角的白色格子表示真值为8,但是有不少样本数据被预测成了1。所以从这个错误矩阵上可以很好的分析出具体的预测错误点,从而根据这些信息调整分类模型或者样本数据。
END
往期精彩文章回顾
机器学习笔记(二十三):算法精准率、召回率
机器学习笔记(二十二):逻辑回归中使用模型正则化
机器学习笔记(二十一):决策边界
机器学习笔记(二十):逻辑回归(2)
机器学习笔记(十九):逻辑回归
机器学习笔记(十八):模型正则化
机器学习笔记(十七):交叉验证
机器学习笔记(十六):多项式回归、拟合程度、模型泛化
机器学习笔记(十五):人脸识别
机器学习笔记(十四):主成分分析法(PCA)(2)
长按扫描二维码关注凌云时刻
每日收获前沿技术与科技洞见