在典型的机器学习应用中,为进一步提高模型在预测未知数据的性能,还要对不同的参数设置进行调优和比较,该过程称为模型选择。指的是针对某一特定问题,调整参数以寻求最优超参数的过程。
假设要在10个不同次数的二项式模型之间进行选择:
具体的模型选择方法为:
1. 使用训练集训练出 10 个模型
Train error:
缺点:模型性能的皮评估对训练数据划分为训练及验证子集的方法是敏感的;评价的结果是敏感的;评价的结果会随样本的不同而发生变化。
K折交叉验证中,不重复地随机将训练数据集划分为K个,其中k-1个用于模型的训练,剩余的1个用于测试。重复此过程k次,就得到了k个模型及对模型性能的评价。该种方法对数据划分的敏感性较低。
下图展示了k折交叉验证,k=10,训练数据集被划分为10块,在10次迭代中,每次迭代中都将9块用于训练,剩余的1块用于模型的评估。10块数据集作用于某一分类器,分类器得到的性能评价指标为 Ei,i=1,2,⋯,10 E i , i = 1 , 2 , ⋯ , 10 ,可用来计算模型的估计平均性能 110∑10i=1Ei 1 10 ∑ i = 1 10 E i
K折交叉验证中k的标准值为10,对大多数应用来说都时合理的。如果训练集相对较小,可以增大k值,这样将会有更多的数据用于进行训练,这样性能的评估结果也会得到较小的偏差。但是k值得增加会导致交叉验证算法的时间延长,并使得训练块高度相似,无法发挥交叉验证的效果。如果数据集较大,可以选择较小的k值,降低在不同数据块的重复计算成本,同时训练快比例小但是依然有大量的训练数据。
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier # K最近邻(kNN,k-NearestNeighbor)分类算法
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt
# 加载iris数据集
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)
if __name__ == '__main__':
k_scores = []
# 由迭代的方式来计算不同k对模型的影响,并返回交叉验证后的平均准确率
for k in range(1, 31):
# 采用k个最近邻
knn = KNeighborsClassifier(n_neighbors=k)
'''
n_jobs为将交叉验证过程分布到的CPU核心数量,若>1,需要放到"if __name__ == '__main__'"中
'''
scores = cross_val_score(knn, X_train, y_train, cv=10, n_jobs=1)
k_scores.append(scores.mean())
# 可视化数据
plt.plot(range(1, 31), k_scores)
plt.xlabel('Value of K for KNN')
plt.ylabel('Cross-Validated Accuracy')
plt.show()
分层交叉验证中,类别比例在每个分块中保持一致,使得每个分块中的类别比例与训练数据集的整体比例一致,在sklearn中使用StratifiedKFold,其中cross_val_score默认也采用分层k折交叉验证
from sklearn.model_selection import StratifiedKFold
import numpy as np
X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [2, 3], [4, 5]])
y = np.array([0, 0, 0, 1, 1, 1])
skf = StratifiedKFold(n_splits=3)
skf.get_n_splits(X, y)
for train_index, valid_index in skf.split(X, y):
print('Train:', train_index, 'Valid:', valid_index)
Train: [1 2 4 5] Valid: [0 3]
Train: [0 2 3 5] Valid: [1 4]
Train: [0 1 3 4] Valid: [2 5]