留出法(hold-out):直接将数据切分成三个互斥的部分,即训练集、测试集和验证集。在训练集上训练模型,在测试集上选择模型,最后在测试集上评估泛化误差。数据集的划分要尽量保持数据分布的一致性,如在分类任务中至少要保持样本的类别比例相似,此时可以采用分层采样。
浮点数:介于0.0到1.0,代表测试集占原始数据集的比例。
整数:代表测试集的大小。
None:代表测试集大小就是原始数据集大小减去训练集大小。如果训练集大小也指定为None,则test_size设为0.25train_size:一个浮点数,整数或None,指定训练集的大小。
浮点数:介于0.0到1.0,代表训练集占原始数据集的比例。
整数:代表训练集的大小。
None:代表训练集大小就是原始数据集大小减去测试集大小。random_state:一个整数,或者一个RandomState实例,或者None。
如果为整数,则它指定了随机数生成器的种子。
如果为RandomState实例,则指定了随机数生成器。
如果为None,则使用默认的随机数生成器。stratify:一个数据或者None。如果它不是None,则原始数据会分层采样,采样的标记数据由该参数指定。 返回值:一个列表,依次给出一个或者多个数据集的划分的结果。每个数据集都划分为两部分:训练集和测试集。
生成数据
from sklearn.model_selection import train_test_split
import numpy as np
X = np.random.rand(8,4)
y = [1,1,0,0,1,1,0,0]
X
array([[ 0.57182586, 0.34344789, 0.62648921, 0.08838991],
[ 0.23236396, 0.45493656, 0.12884294, 0.68522353],
[ 0.19012725, 0.78536539, 0.66665145, 0.33146112],
[ 0.56584231, 0.32945912, 0.22809843, 0.79332783],
[ 0.9836845 , 0.82029146, 0.12332923, 0.93058032],
[ 0.01305442, 0.35052673, 0.40793758, 0.95430386],
[ 0.81467068, 0.8397317 , 0.11915037, 0.00317844],
[ 0.01244749, 0.24385553, 0.77887998, 0.33716389]])
进行划分
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.4,random_state=0)
print("X_train=",X_train)
print("X_test=",X_test)
print("y_train=",y_train)
print("y_test=",y_test)
X_train= [[ 0.56584231 0.32945912 0.22809843 0.79332783]
[ 0.57182586 0.34344789 0.62648921 0.08838991]
[ 0.01305442 0.35052673 0.40793758 0.95430386]
[ 0.9836845 0.82029146 0.12332923 0.93058032]]
X_test= [[ 0.81467068 0.8397317 0.11915037 0.00317844]
[ 0.19012725 0.78536539 0.66665145 0.33146112]
[ 0.23236396 0.45493656 0.12884294 0.68522353]
[ 0.01244749 0.24385553 0.77887998 0.33716389]]
y_train= [0, 1, 1, 1]
y_test= [0, 0, 1, 0]
分层采样
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.4,random_state=0,stratify=y)
print("X_train=",X_train)
print("X_test=",X_test)
print("y_train=",y_train)
print("y_test=",y_test)
X_train= [[ 0.9836845 0.82029146 0.12332923 0.93058032]
[ 0.81467068 0.8397317 0.11915037 0.00317844]
[ 0.57182586 0.34344789 0.62648921 0.08838991]
[ 0.01244749 0.24385553 0.77887998 0.33716389]]
X_test= [[ 0.19012725 0.78536539 0.66665145 0.33146112]
[ 0.56584231 0.32945912 0.22809843 0.79332783]
[ 0.23236396 0.45493656 0.12884294 0.68522353]
[ 0.01305442 0.35052673 0.40793758 0.95430386]]
y_train= [1, 0, 1, 0]
y_test= [0, 0, 1, 1]