凌云时刻 · 技术
导读:在上一篇笔记中我们又提到了训练数据集和测试数据集,拆分样本数据的这种做法目的就是通过测试数据集判断模型的好坏,如果我们发现训练出的模型产生了过拟合的现象,既在训练数据集上预测评分很好,但是在测试数据集上预测评分不好的情况,那可能就需要重新调整超参数训练模型,以此类推,最终找到一个或一组参数使得模型在测试数据集上的预测评分也很好,也就是训练出的模型泛化能力比较好。
作者 | 计缘
来源 | 凌云时刻(微信号:linuxpk)
验证数据集与交叉验证
那么这种方式会产生一个问题,就是有可能会针对测试数据过拟合,因为每次都是找到参数训练模型,然后看看在测试数据集上的表现如何,这就让我们的模型又被测试数据集左右了,既可以理解为训练出的模型对特定的训练数据集和特定的测试数据集表现都不错,但是再来一种类似的样本数据,表现可能又不尽如人意了。
那么要彻底解决这个问题,就要引入验证数据集的概念,既将样本数据分为三份,训练数据集、验证数据集、测试数据集。
训练数据集和之前的用途一样,是用来训练模型的。
验证数据集的作用和之前的测试数据集一样,是用来验证由训练数据集训练出的模型的好坏程度的,或者说是调整超参数使用的数据集。
此时的测试数据集和之前的作用就不一样了,这里的测试数据集是当训练出的模型在训练数据集和验证数据集上都表现不错的前提下,最终衡量该模型性能的数据集。测试数据集在整个训练模型的过程中是不参与的。
交叉验证(Cross Validation)
我们在验证数据集概念的基础上,再来看看交叉验证。交叉验证其实解决的是随机选取验证数据集的问题,因为如果验证数据集是固定的,那么万一验证数据集过拟合了,那就没有可用的验证数据集了,所以交叉验证提供了随机的、可持续的、客观的模型验证方式。
交叉验证的思路是将训练数据分成若干份,假设分为A、B、C三份,分别将这三份各作为一次验证数据集,其他两份作为训练数据集训练模型,然后将训练出的三个模型评分取均值,将这个均值作为衡量算法训练模型的结果来调整参数,如果平均值不够好,那么再调整参数,再训练出三个模型,以此类推。
实现交叉验证
我们使用KNN算法,用训练数据集和测试数据集方式进行超参数k
和p
的调整查找(KNN的k
,p
两个超参数查阅第二篇学习笔记):
|
从结果看,通过上面的算法,我们找到了最好评分98.6%和对应的k
值3和p
值4。但是需要注意的是,这个结果有可能是对测试数据集过拟合的结果。
下面我们再来看看如何使用交叉验证方法进行超参数的调参:
|
我们直接使用Scikit Learn中提供的交叉验证函数,默认会将训练数据分成三份,所以会有三个模型的评分。然后再修改一下上面调参的算法:
|
可以看到,使用交叉验证法最后搜寻到的最佳k
值是2,最佳p
值是2,然后我们再用搜寻出的这两个超参数来训练模型,然后使用测试数据集来计算评分:
|
最终我们搜寻到的最佳超参数训练出的模型,通过测试数据验证后评分为98.05%,这个评分虽然比之前用训练数据和测试数据搜寻到的最佳评分低一些,但是这个分数不会对验证数据集过拟合,是泛化能力更好的模型。
第三篇笔记中,讲过KNN通过网格搜索搜寻最佳超参数的方法,其实当时GridSearchCV中的CV就是Cross Validation的意思,也就是网格搜索本身就使用的交叉验证的方式搜寻超参数,我们再来回顾一下:
|
可以看到执行fit
函数后,会打印出一句话来,意思就是超参数组合一共有45个,每个组合会将训练数据分为三份,一共会训练出135个模型,最后求出一个泛化能力最好的模型。
|
通过上面的结果可以看到,和我们之前的结果是一致的。
偏差(Bias)与方差(Variance)
在机器学习算法中,模型的好坏有一个统称就是预测结果的误差大小。那么这个误差具体可分为偏差和方差。
上面这幅图有四个靶子,可以很好的诠释方差和偏差的概念。红色靶心就相当于我们目标,灰色弹孔就相当于模型预测的值。我们来解读一下这四幅图:
左上:模型预测的值基本都在目标值上,并且每次预测的都很集中,说明偏差和方差都很小。
左下:模型预测的值虽然每次都很集中,但是整体和目标值差的很远,说明偏差很大,方差比较小。
右上:模型预测的值基本都围绕着目标值,但是每次预测的值之间差距较大,说明偏差较小,方差比较大。
右下:模型预测的值离目标值都很远,并且每次预测的值之间差距也比较打,说明偏差和方差都很大。
通常情况下,我们训练出的模型误差指的是偏差和方差的总和,再加上一些不可避免的误差,比如训练数据本身噪音比较大等。
通常导致偏差的主要原因是对问题本身的假设不正确,比如本身训练数据并没有线性关系,但我们还是使用线性回归去训练模型,那么模型的偏差肯定会很大,也就是欠拟合的情况。
通常导致方差的主要原因是因为我们的模型太过复杂,学习到太多的噪音,比如多项式回归,当degree
参数非常大的时候,也就是过拟合的情况。
非参数学习通常都是高方差算法,比如分类的算法,因为不会对数据进行任何假设。参数学习通常都是高偏差算法,因为会对数据有极强的假设,一旦训练数据有问题,那么就会导致模型整体偏离真实情况。
在使用机器学习解决问题的实践中,通常我们的挑战都是降低模型的方差,一般有以下几种手段:
降低模型复杂度。比如降低多项式回归的degree
参数。
减少数据维度,降噪。比如使用PCA。
增加样本数量。让训练数据足以支撑复杂的模型,从而能计算出合适的参数。
使用交叉验证。避免过拟合情况。
END
往期精彩文章回顾
机器学习笔记(十六):多项式回归、拟合程度、模型泛化
机器学习笔记(十五):人脸识别
机器学习笔记(十四):主成分分析法(PCA)(2)
机器学习笔记(十三):主成分分析法(PCA)
机器学习笔记(十二):随机梯度下降
机器学习笔记(十):梯度下降
机器学习笔记(九):多元线性回归
机器学习笔记(七):线性回归
长按扫描二维码关注凌云时刻
每日收获前沿技术与科技洞见