k折交叉验证 k-fold cross-validation

文章目录

    • k折交叉验证
    • k值的确定
    • 实例
    • 使用scikit-learn进行交叉验证

交叉验证是用来评估机器学习方法的有效性的统计学方法,可以使用有限的样本数量来评估模型对于验证集或测试集数据的效果。

k折交叉验证

参数 k k k表示,将给定的样本数据分割成 k k k组。 k = 10 k=10 k=10时,称为10折交叉验证。

流程如下

  1. 将数据集随机打乱。Shuffle the dataset randomly.
  2. 将数据集随机分割为 k k k组。
  3. 对于每一个组,进行如下操作:
    1. 将这一个组的数据当做测试集
    2. 剩余的 k − 1 k-1 k1个组的数据当做训练集
    3. 使用测试集训练模型,并在测试集上进行评测
    4. 保留评测的分数,抛弃模型
  4. 使用 k k k次评测分数的平均值来总结模型的性能,有时也会统计 k k k次评测分数的方差

总结:每一个样本都做了一次测试集,做了 k − 1 k-1 k1次训练集。

k值的确定

k k k值必须仔细确定。不合适的 k k k值会导致不能准确评估模型的性能,会得出high variance或high bias的结果。There is a bias-variance trade-off associated with the choice of k in k-fold cross-validation

选择 k k k值的几点策略:

  1. 数据的代表性: k k k值必须使得每一组训练集和测试集中的样本数量都足够大,使其在统计学意义上可以代表更广泛的数据。
  2. k = 10 k=10 k=10:这是一个经过广泛的实验得到的一个经验值。所得的结果会有较低的偏差和适量的方差 (low bias and modest variance)
  3. k = n k=n k=n:其中 n n n是样本的数量。这样一来,使得一个样本作为测试集,其他样本作为训练集。因此也被称作 leave-one-out 交叉验证 (LOOCV)。

总结: k k k没有固定值,不过通常取值5或10。随着 k k k值的增大,训练集的大小和采样子集之间的差异变小,对模型评估的偏差也会减小。 k = 10 k=10 k=10总是一个较优的选择。

实例

给定一个样本数据集如下:

[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]

选定 k = 3 k=3 k=3,则样本分为3组,每组2个数据。

Fold1: [0.5, 0.2]
Fold2: [0.1, 0.3]
Fold3: [0.4, 0.6]

接下来训练三个模型,并使用对应的测试集进行评测。
Model1: Trained on Fold1 + Fold2, Tested on Fold3
Model2: Trained on Fold2 + Fold3, Tested on Fold1
Model3: Trained on Fold1 + Fold3, Tested on Fold2

使用scikit-learn进行交叉验证

使用KFold类:

kfold = KFold(3, True, 1)

其中:3表示 k = 3 k=3 k=3,True表示随机打乱数据集,1是随机数的种子seed。

接下来使用split函数将样本进行分组:

# enumerate splits
for train, test in kfold.split(data):
	print('train: %s, test: %s' % (data[train], data[test]))

其中的train和test都是原始数据在data数组中的索引值。

完整的代码如下:

# scikit-learn k-fold cross-validation
from numpy import array
from sklearn.model_selection import KFold
# data sample
data = array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
# prepare cross validation
kfold = KFold(3, True, 1)
# enumerate splits
for train, test in kfold.split(data):
	print('train: %s, test: %s' % (data[train], data[test]))

输出结果为:

# enumerate splits
train: [0.1 0.4 0.5 0.6], test: [0.2 0.3]
train: [0.2 0.3 0.4 0.6], test: [0.1 0.5]
train: [0.1 0.2 0.3 0.5], test: [0.4 0.6]

参考链接:
https://machinelearningmastery.com/k-fold-cross-validation/

你可能感兴趣的:(人工智能/深度学习/机器学习)