1.np.random.permutation法
import numpy as np
def split_train_test(data, test_ratio):
    """Randomly split ``data`` into a training set and a test set.

    Row positions are shuffled with ``np.random.permutation``, which draws
    from NumPy's global random state, so the split differs between runs
    unless that state is seeded beforehand (e.g. ``np.random.seed(42)``).

    Args:
        data: Object supporting ``len()`` and positional ``.iloc`` indexing,
            such as a pandas DataFrame.
        test_ratio: Fraction of the rows to place in the test set. The test
            set holds ``int(len(data) * test_ratio)`` rows (truncated).

    Returns:
        A ``(train_set, test_set)`` tuple of row subsets of ``data``.
    """
    shuffled_indices = np.random.permutation(len(data))
    test_set_size = int(len(data) * test_ratio)
    # The first test_set_size shuffled positions form the test set;
    # everything after them is the training set.
    test_indices = shuffled_indices[:test_set_size]
    train_indices = shuffled_indices[test_set_size:]
    return data.iloc[train_indices], data.iloc[test_indices]
# Hold out 20% of the rows as the test set and report both subset sizes.
train_set, test_set = split_train_test(data=housing, test_ratio=0.2)
print('len_train_set:', train_set.shape[0])
print('len_test_set:', test_set.shape[0])
#输出
len_train_set: 16512
len_test_set: 4128
缺点:
a.每次运行程序,数据集的分割结果都会变化(随机排列没有固定种子)
b.使用np.random.seed(42)可以使每次运行得到相同的分割,但是数据集更新(增删样本)后,分割仍会变化
2.crc32法
from zlib import crc32
def test_set_check(identifier, test_ratio):
    """Decide from its identifier whether an instance belongs to the test set.

    The identifier is converted to ``np.int64`` and hashed with CRC-32. The
    instance is assigned to the test set when the 32-bit hash falls below
    ``test_ratio * 2**32``, so the same identifier always gets the same
    answer for a given ``test_ratio``.

    Args:
        identifier: Value convertible to ``np.int64``.
        test_ratio: Fraction of the 32-bit hash range mapped to the test set.

    Returns:
        True if the instance should go into the test set, otherwise False.
    """
    # "&" binds tighter than "<": the hash is masked to 32 bits first and
    # only then compared with the threshold.
    return crc32(np.int64(identifier)) & 0xffffffff < test_ratio * 2**32


# The "test_" prefix matches pytest's default discovery pattern; mark the
# helper so it is not collected as a test when imported into a test module.
test_set_check.__test__ = False
def split_train_test_by_id(data, test_ratio, id_column):
    """Split ``data`` into train and test sets using a hash of each row's id.

    Each value in ``data[id_column]`` is passed to ``test_set_check``; rows
    for which it returns True form the test set and the remaining rows form
    the training set.

    Args:
        data: Table supporting ``data[id_column]`` and boolean ``.loc``
            selection, such as a pandas DataFrame.
        test_ratio: Ratio forwarded to ``test_set_check``.
        id_column: Name of the column holding the row identifiers.

    Returns:
        A ``(train_set, test_set)`` tuple of row subsets of ``data``.
    """
    ids = data[id_column]
    # Boolean mask, one entry per row: True means "belongs to the test set".
    in_test_set = ids.apply(lambda id_: test_set_check(id_, test_ratio))
    return data.loc[~in_test_set], data.loc[in_test_set]
# reset_index() gives housing_with_id a fresh index, so build the id from its
# own columns; mixing in columns of `housing` would rely on both frames
# happening to share the same index labels.
housing_with_id = housing.reset_index()
housing_with_id["id"] = housing_with_id["longitude"] * 1000 + housing_with_id["latitude"]
# Unpack into train_set/test_set so the lengths printed below describe this
# split rather than a train_set left over from an earlier one.
train_set, test_set = split_train_test_by_id(housing_with_id, 0.2, "id")
print('len_train_set:', len(train_set))
print('len_test_set:', len(test_set))
# Output (the train size is the 20640 total rows minus the 4318 test rows):
# len_train_set: 16322
# len_test_set: 4318
#缺点
依赖于每个样本都有唯一且不变的id(此处id由经度和纬度组合而成,坐标相同的样本会得到相同的id)
3.sklearn法
from sklearn.model_selection import train_test_split
# Let scikit-learn do the split; random_state=42 makes the result repeatable.
train_set, test_set = train_test_split(housing, random_state=42, test_size=0.2)
print('len_train_set:', train_set.shape[0])
print('len_test_set:', test_set.shape[0])
#输出
len_train_set: 16512
len_test_set: 4128