Leave-One-Out is a special case of S-fold cross-validation: when $S = N$, where $N$ is the size of the dataset, S-fold cross-validation reduces to leave-one-out. The method tends to give an accurate estimate, but its computational cost is high; for example, if the dataset contains 100,000 samples, 100,000 models have to be trained.
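The relationship to S-fold cross-validation can be checked directly in scikit-learn. The sketch below is a minimal illustration on a small toy array (not part of the original example): it shows that LeaveOneOut produces exactly the same splits as KFold with n_splits equal to the number of samples.

import numpy as np
from sklearn.model_selection import LeaveOneOut, KFold

X = np.arange(20).reshape(5, 4)   # 5 toy samples with 4 features each
loo_splits = list(LeaveOneOut().split(X))
kfold_splits = list(KFold(n_splits=len(X)).split(X))

# With n_splits=N every fold holds out a single sample, so the two match
for (tr_a, te_a), (tr_b, te_b) in zip(loo_splits, kfold_splits):
    assert np.array_equal(tr_a, tr_b) and np.array_equal(te_a, te_b)
print("number of folds:", len(loo_splits))   # N folds, one per sample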
Given a dataset $T$ containing $N$ samples, draw $N$ samples from it with replacement to obtain the bootstrap sample set $T_s$. A sample from $T$ may appear several times in $T_s$, or may never appear in $T_s$ at all. The probability that a particular sample is never drawn in any of the $N$ draws is $(1-\frac{1}{N})^N$. Since $\lim_{N\to\infty}(1-\frac{1}{N})^N = \frac{1}{e} \approx 0.368$, about 36.8% of the samples never appear in $T_s$, so roughly 63.2% of the samples in $T$ do appear in $T_s$. Use $T_s$ as the training set and $T - T_s$ as the test set.
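The limit is easy to verify numerically. The short sketch below (not part of the original code) simply evaluates $(1-\frac{1}{N})^N$ for growing $N$ and compares it with $1/e$.

import math

# (1 - 1/N)^N approaches 1/e ≈ 0.368 as N grows
for N in (10, 100, 10000, 1000000):
    print(N, (1 - 1 / N) ** N)
print("1/e =", 1 / math.e)   # 0.36787944117144233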
sklearn.model_selection.LeaveOneOut()
from sklearn.model_selection import LeaveOneOut
import numpy as np

# 4 samples with 4 features each
X = np.array([[1, 2, 3, 4],
              [11, 12, 13, 14],
              [21, 22, 23, 24],
              [31, 32, 33, 34]])
y = np.array([1, 1, 0, 0])

loo = LeaveOneOut()
# Each split holds out exactly one sample as the test set
for train_index, test_index in loo.split(X, y):
    print("Train Index:", train_index)
    print("Test Index:", test_index)
    print("X_train:", X[train_index])
    print("X_test:", X[test_index])
    print("")
Train Index: [1 2 3]
Test Index: [0]
X_train: [[11 12 13 14]
[21 22 23 24]
[31 32 33 34]]
X_test: [[1 2 3 4]]
Train Index: [0 2 3]
Test Index: [1]
X_train: [[ 1 2 3 4]
[21 22 23 24]
[31 32 33 34]]
X_test: [[11 12 13 14]]
Train Index: [0 1 3]
Test Index: [2]
X_train: [[ 1 2 3 4]
[11 12 13 14]
[31 32 33 34]]
X_test: [[21 22 23 24]]
Train Index: [0 1 2]
Test Index: [3]
X_train: [[ 1 2 3 4]
[11 12 13 14]
[21 22 23 24]]
X_test: [[31 32 33 34]]
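Printing the indices only illustrates the mechanics; in practice the splitter is usually passed to an evaluation helper. The following sketch is one way to do that (the choice of LogisticRegression is an assumption, not part of the original text): cross_val_score with cv=LeaveOneOut() fits one model per held-out sample and returns one score per sample.

from sklearn.model_selection import cross_val_score, LeaveOneOut
from sklearn.linear_model import LogisticRegression
import numpy as np

# Same toy data as above
X = np.array([[1, 2, 3, 4],
              [11, 12, 13, 14],
              [21, 22, 23, 24],
              [31, 32, 33, 34]])
y = np.array([1, 1, 0, 0])

# N models are trained; each score is the accuracy on a single held-out sample
scores = cross_val_score(LogisticRegression(), X, y, cv=LeaveOneOut())
print(scores, scores.mean())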
Constructing the data
import numpy as np
import pandas as pd
import random

# 10 rows of random features A-D plus a random binary label y
# (no random seed is set, so the values differ from run to run)
data = pd.DataFrame(np.random.rand(10, 4), columns=list('ABCD'))
data['y'] = [random.choice([0, 1]) for i in range(10)]
print(data)
A B C D y
0 0.704388 0.586751 0.931694 0.992170 0
1 0.887570 0.648561 0.959472 0.573279 0
2 0.403431 0.356454 0.780375 0.987747 1
3 0.793327 0.377768 0.008651 0.583467 1
4 0.493081 0.455021 0.352437 0.971354 1
5 0.706505 0.650936 0.532032 0.791598 1
6 0.414343 0.424478 0.802620 0.584577 1
7 0.017409 0.267225 0.740127 0.050121 1
8 0.654652 0.673884 0.582909 0.428070 0
9 0.515559 0.262651 0.339282 0.394977 1
Splitting the dataset with the bootstrap method
# Draw 10 rows with replacement: the bootstrap training set (duplicates allowed)
train = data.sample(frac=1.0, replace=True)
# Rows that were never drawn (out-of-bag samples) become the test set
test = data.loc[data.index.difference(train.index)].copy()
print(train)
A B C D y
8 0.654652 0.673884 0.582909 0.428070 0
1 0.887570 0.648561 0.959472 0.573279 0
1 0.887570 0.648561 0.959472 0.573279 0
4 0.493081 0.455021 0.352437 0.971354 1
0 0.704388 0.586751 0.931694 0.992170 0
3 0.793327 0.377768 0.008651 0.583467 1
8 0.654652 0.673884 0.582909 0.428070 0
8 0.654652 0.673884 0.582909 0.428070 0
4 0.493081 0.455021 0.352437 0.971354 1
3 0.793327 0.377768 0.008651 0.583467 1
print(test)
A B C D y
2 0.403431 0.356454 0.780375 0.987747 1
5 0.706505 0.650936 0.532032 0.791598 1
6 0.414343 0.424478 0.802620 0.584577 1
7 0.017409 0.267225 0.740127 0.050121 1
9 0.515559 0.262651 0.339282 0.394977 1
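With only 10 rows the out-of-bag test set can easily deviate from the theoretical 36.8% (here it happens to hold 5 of the 10 rows). A quick sketch on a larger, purely illustrative DataFrame shows the fraction settling near $1/e$.

import numpy as np
import pandas as pd

big = pd.DataFrame(np.random.rand(100000, 4), columns=list('ABCD'))
train = big.sample(frac=1.0, replace=True)              # bootstrap training set
test = big.loc[big.index.difference(train.index)]       # out-of-bag rows
print("out-of-bag fraction:", len(test) / len(big))     # ≈ 0.368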