Scikit-learn的K-fold交叉验证类ShuffleSplit、GroupShuffleSplit用法介绍

当样本数据量比较小时,K-fold交叉验证是训练、评价模型时的常用方法,该方法的作用如下:

  • 交叉验证用于评估模型的预测性能,尤其是训练好的模型在新数据上的表现,可以在一定程度上减小过拟合
  • 交叉验证可以从有限的数据中获取尽可能多的有效信息

Scikit-learn(以下简称sklearn)是基于Numpy、Scipy的开源的Python机器学习库,提供了大量用于数据挖掘和分析的工具,包括数据预处理、交叉验证、算法与可视化算法等一系列接口。
本文介绍sklearn的可用于K-fold交叉验证的集合划分类ShuffleSplit、GroupShuffleSplit的用法。


ShuffleSplit

sklearn.model_selection.ShuffleSplit类用于将样本集合随机“打散”后划分为训练集、测试集(可理解为验证集,下同),类申明如下:

class sklearn.model_selection.ShuffleSplit(n_splits=10, test_size=’default’, train_size=None, random_state=None)

参数:

  • n_splits:int, 划分训练集、测试集的次数,默认为10
  • test_size:float, int, None, default=0.1; 测试集比例或样本数量,该值为[0.0, 1.0]内的浮点数时,表示测试集占总样本的比例;该值为整型值时,表示具体的测试集样本数量;train_size不设定具体数值时,该值取默认值0.1,train_size设定具体数值时,test_size取剩余部分
  • train_size:float, int, None; 训练集比例或样本数量,该值为[0.0, 1.0]内的浮点数时,表示训练集占总样本的比例;该值为整型值时,表示具体的训练集样本数量;该值为None(默认值)时,训练集取总体样本除去测试集的部分
  • random_state:int, RandomState instance or None;随机种子值,默认为None

ShuffleSplit类方法包括get_n_splits、split,前者用于返回划分训练集、测试集的次数,后者申明如下:

split(X, y=None, groups=None)

参数:

  • X:array-like, shape (n_samples, n_features);样本特征集合
  • y:array-like, shape (n_samples,);样本标记集合,该值设置时需与X的样本数量(n_samples)一致
  • groups:该参数在此处不生效
  • 返回值:包含训练集、测试集索引值的迭代器

ShuffleSplit应用举例

抽取2016年Kaggle竞赛——State Farm Distracted Driver Detection的10条数据作为样本,如下图所示:
Scikit-learn的K-fold交叉验证类ShuffleSplit、GroupShuffleSplit用法介绍_第1张图片

上图最左侧为自动生成的索引,subject表示驾驶者编号,classname表示图片分类,c0、c1、……、c9表示不同的图片类别,img为图片名称。在实施训练时,以img作为样本特征,以classname作为样本标记。
对于上述样本,使用ShuffleSplit划分5-fold 5次交叉验证的训练集、测试集,代码如下:

import pandas as pd
import numpy as np
from sklearn.model_selection import ShuffleSplit
sample = pd.DataFrame({'subject':['p012', 'p012', 'p014', 'p014', 'p014', 'p024', 'p024', 'p024', 'p024', 'p081'],
                         'classname':['c5', 'c0', 'c1', 'c5', 'c0','c0','c1','c1','c2','c6'],
                         'img':['img_41179.jpg','img_50749.jpg', 'img_53609.jpg','img_52213.jpg','img_72495.jpg',
                                'img_66836.jpg','img_32639.jpg','img_31777.jpg','img_97535.jpg','img_1399.jpg']})
x_train_names_all = np.array(sample['img'])
y_train_labels_all = np.array(sample['classname'])

rs = ShuffleSplit(n_splits=5, test_size=0.2, random_state=0)
n_fold = 1
for train_indices, test_indices in rs.split(sample):
    print('fold {}/5......'.format(n_fold))
    print("train_indices:", train_indices)
    x_train = x_train_names_all[train_indices, ...]
    print("x_train_names:\n", x_train)
    y_train = y_train_labels_all[train_indices, ...]
    print("y_train:\n", y_train)

    print("test_indices:", test_indices)
    x_test = x_train_names_all[test_indices, ...]
    print("x_test:\n", x_test)
    y_test = y_train_labels_all[test_indices, ...]
    print("y_test:\n", y_test)
    n_fold += 1

输出结果:
Scikit-learn的K-fold交叉验证类ShuffleSplit、GroupShuffleSplit用法介绍_第2张图片

从结果可以看出,依据n_splits值对样本集合划分5次,每次划分时打乱样本,依据比例重新划分。
说明如下:

  • 划分参数在ShuffleSplit类初始化时设置
  • 本例在实例化ShuffleSplit类时,test_size设置为0.2(5-fold),测试集样本数量为2,训练集样本数量为8
  • split函数返回的是包含该训练集、测试集索引的迭代器

GroupShuffleSplit

sklearn.model_selection.GroupShuffleSplit作用与ShuffleSplit相同,不同之处在于GroupShuffleSplit先将待划分的样本集分组,再按照分组划分训练集、测试集。
GroupShuffleSplit类的申明如下:

class sklearn.model_selection.GroupShuffleSplit(n_splits=5, test_size=’default’, train_size=None, random_state=None)

参数个数及含义同ShuffleSplit,只是默认值有所不同:

  • n_splits:int, 划分训练集、测试集的次数,默认为5
  • test_size:float, int, None, default=0.1; 测试集比例或样本数量,该值为[0.0, 1.0]内的浮点数时,表示测试集占总样本的比例;该值为整型值时,表示具体的测试集样本数量;train_size不设定具体数值时,该值取默认值0.2,train_size设定具体数值时,test_size取剩余部分
  • train_size:float, int, None; 训练集比例或样本数量,该值为[0.0, 1.0]内的浮点数时,表示训练集占总样本的比例;该值为整型值时,表示具体的训练集样本数量;该值为None(默认值)时,训练集取总体样本除去测试集的部分
  • random_state:int, RandomState instance or None;随机种子值,默认为None

GroupShuffleSplit类的get_n_splits、split方法与ShuffleSplit类的同名方法类似,唯一的不同之处在于split方法的groups参数在此处生效,用于指定分组依据。

GroupShuffleSplit应用举例

在State Farm Distracted Driver Detection数据集上进行训练时,为减小驾驶者服饰、面貌等特征对分类模型泛化能力的不利影响,需要按照驾驶者编号对样本集进行划分,此时用到GroupShuffleSplit类并在split方法中设置groups参数,对样本进行4-fold 4次交叉验证划分的示例代码如下:

import pandas as pd
import numpy as np
from sklearn.model_selection import ShuffleSplit, GroupShuffleSplit
sample = pd.DataFrame({'subject':['p012', 'p012', 'p014', 'p014', 'p014', 'p024', 'p024', 'p024', 'p024', 'p081'],
                         'classname':['c5', 'c0', 'c1', 'c5', 'c0','c0','c1','c1','c2','c6'],
                         'img':['img_41179.jpg','img_50749.jpg', 'img_53609.jpg','img_52213.jpg','img_72495.jpg',
                                'img_66836.jpg','img_32639.jpg','img_31777.jpg','img_97535.jpg','img_1399.jpg']})
x_train_names_all = np.array(sample['img'])
y_train_labels_all = np.array(sample['classname'])
driver_ids = sample['subject']
_, driver_indices = np.unique(np.array(driver_ids), return_inverse=True)
n_fold = 1
rs = GroupShuffleSplit(n_splits=4, test_size=0.25, random_state=0)
for train_indices, test_indices in rs.split(x_train_names_all, y_train_labels_all, groups=driver_indices):
    print('fold {}/4......'.format(n_fold))
    print("train_indices:", train_indices)
    x_train = x_train_names_all[train_indices, ...]
    print("x_train_names:\n", x_train)
    y_train = y_train_labels_all[train_indices, ...]
    print("y_train:\n", y_train)

    print("test_indices:", test_indices)
    x_test = x_train_names_all[test_indices, ...]
    print("x_test:\n", x_test)
    y_test = y_train_labels_all[test_indices, ...]
    print("y_test:\n", y_test)
    n_fold += 1

输出结果:
Scikit-learn的K-fold交叉验证类ShuffleSplit、GroupShuffleSplit用法介绍_第3张图片
回顾一下样本,10个样本来自4个驾驶者,其中p012有2个样本、p014有3个样本、p024有4个样本、p081有1个样本,驾驶者编号与索引的对应关系如下:

驾驶者编号 p012 p012 p014 p014 p014 p024 p024 p024 p024 p081
索引值 0 1 2 3 4 5 6 7 8 9

即p012包含索引0、1,p014包含索引2、3、4,以此类推。
在GroupShuffleSplit类初始化时设置test_size=0.25,在split方法中设置groups参数、指定分类依据为驾驶者编号,即对4个驾驶者进行划分,测试集为1个驾驶者对应的数据,其余驾驶者对应的数据作为训练集。
从输出结果可以看出,第1次划分时,“p012、p014、p081”为训练集,“p024”为测试集;第2次划分时,“p014、p024、p081”为训练集,“p012”为测试集;第3次划分时,“p012、p014、p024”为训练集,“p081”为测试集;第4次划分时,“p012、p024、p081”为训练集,“p014”为测试集。
GroupShuffleSplit用法与ShuffleSplit类似,需要说明的是,split方法的groups参数需设置为组标记,本例中先使用np.unique函数获取旧数组元素在去重后得到的新数组中的位置,以此来作为组标记,其值为:[0 0 1 1 1 2 2 2 2 3],再将其设置为groups的实参。

你可能感兴趣的:(python及其库)