机器学习 sklearn 中的超参数搜索方法

✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。
个人主页:小嗷犬的个人主页
个人网站:小嗷犬的技术小站
个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。


本文目录

    • 超参数搜索
    • 默认参数
    • GridSearchCV
    • RandomizedSearchCV
    • HalvingGridSearchCV
    • HalvingRandomSearchCV


超参数搜索

在建模时模型的超参数往往会对精度造成一定影响,而设置和调整超参数的取值,往往称为调参

在实践中调参往往依赖人工来进行设置调整范围,然后使用机器在超参数范围内进行搜索,找到最优的超参数组合。

在 sklearn 中,提供了四种超参数搜索方法:

  • GridSearchCV
  • RandomizedSearchCV
  • HalvingGridSearchCV
  • HalvingRandomSearchCV

默认参数

为了方便起见,我们先定义一个默认参数的模型,用于后续的超参数搜索。

# 导入相关库
import random
import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn.experimental import enable_halving_search_cv
from sklearn.model_selection import (
    train_test_split,
    GridSearchCV,
    RandomizedSearchCV,
    HalvingGridSearchCV,
    HalvingRandomSearchCV,
)
from sklearn.ensemble import RandomForestRegressor

# 设置随机种子
seed = 1
random.seed(seed)
np.random.seed(seed)

# 加载数据集
data = datasets.load_diabetes()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = pd.Series(data.target)

# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=seed)

# 使用默认参数模型进行分类并评分
reg = RandomForestRegressor(random_state=seed)
reg.fit(X_train, y_train)
print(round(reg.score(X_test, y_test), 6))

最终默认参数的模型在测试集上的 R 2 R^2 R2 分数约为 0.269413

GridSearchCV

GridSearchCV 是一种网格搜索超参数的方法,它会遍历所有的超参数组合,然后评估模型的性能,最终选择性能最好的一组超参数。

# 设置超参数搜索范围
param_grid = {
    "max_depth": [2, 4, 5, 6, 7],
    "min_samples_leaf": [1, 2, 3],
    "min_weight_fraction_leaf": [0, 0.1],
    "min_impurity_decrease": [0, 0.1, 0.2]
}

# 使用 GridSearchCV 进行超参数搜索
reg = GridSearchCV(
    RandomForestRegressor(random_state=seed),
    param_grid,
    cv=5,
    n_jobs=-1,
)
reg.fit(X_train, y_train)

# 输出模型在测试集上的 R 方分数
print(round(reg.score(X_test, y_test), 6))

# 输出最优超参数组合
print(reg.best_params_)

在这个超参数搜索空间中,一共有 5 * 3 * 2 * 3 = 90 种超参数组合。
最终超参数搜索后的模型在测试集上的 R 2 R^2 R2 分数约为 0.287919
最优超参数组合为:

{
    'max_depth': 4,
    'min_impurity_decrease': 0.2,
    'min_samples_leaf': 1,
    'min_weight_fraction_leaf': 0
}

RandomizedSearchCV

RandomizedSearchCV 是一种随机搜索超参数的方法,它的使用方法与 GridSearchCV 类似,但是它不会遍历所有的超参数组合,而是在超参数的取值范围内随机选择一组超参数进行训练,然后评估模型的性能,最终选择性能最好的一组超参数。

# 使用 RandomizedSearchCV 进行超参数搜索
reg = RandomizedSearchCV(
    RandomForestRegressor(random_state=seed),
    param_grid,
    cv=5,
    n_jobs=-1,
    n_iter=20,  # 设置迭代次数
    random_state=seed,
)
reg.fit(X_train, y_train)

# 输出模型在测试集上的 R 方分数
print(round(reg.score(X_test, y_test), 6))

# 输出最优超参数组合
print(reg.best_params_)

RandomizedSearchCV 一共进行了 20 次迭代,即尝试了 20 组超参数组合。
最终超参数搜索后的模型在测试集上的 R 2 R^2 R2 分数约为 0.26959
最优超参数组合为:

{
    'min_weight_fraction_leaf': 0,
    'min_samples_leaf': 1,
    'min_impurity_decrease': 0.1,
    'max_depth': 6
}

HalvingGridSearchCV

HalvingGridSearchCVGridSearchCV 类似,但在迭代的过程中采用减半超参数搜索空间的方法,以此来减少超参数搜索的时间。

在搜索的最开始,HalvingGridSearchCV 使用很少的数据样本来在完整的超参数搜索空间中进行搜索,筛选其中最优的超参数,之后再增加数据进行进一步筛选。

# 使用 HalvingGridSearchCV 进行超参数搜索
reg = HalvingGridSearchCV(
    RandomForestRegressor(random_state=seed),
    param_grid,
    cv=5,
    n_jobs=-1,
    random_state=seed,
)
reg.fit(X_train, y_train)

# 输出模型在测试集上的 R 方分数
print(round(reg.score(X_test, y_test), 6))

# 输出最优超参数组合
print(reg.best_params_)

最终超参数搜索后的模型在测试集上的 R 2 R^2 R2 分数约为 0.287919
最优超参数组合为:

{
    'max_depth': 4,
    'min_impurity_decrease': 0.2,
    'min_samples_leaf': 1,
    'min_weight_fraction_leaf': 0
}

可以看到,HalvingGridSearchCV 得到的最优超参数组合与 GridSearchCV 得到的最优超参数组合相同。

HalvingRandomSearchCV

HalvingRandomSearchCVHalvingGridSearchCV 类似,都是逐步增加样本数量,减少超参数组合,但是 HalvingRandomSearchCV 每次生成的超参数组合是随机的。

# 使用 HalvingRandomSearchCV 进行超参数搜索
reg = HalvingRandomSearchCV(
    RandomForestRegressor(random_state=seed),
    param_grid,
    cv=5,
    n_jobs=-1,
    random_state=seed,
)
reg.fit(X_train, y_train)

# 输出模型在测试集上的 R 方分数
print(round(reg.score(X_test, y_test), 6))

# 输出最优超参数组合
print(reg.best_params_)

最终超参数搜索后的模型在测试集上的 R 2 R^2 R2 分数约为 0.26959
最优超参数组合为:

{
    'min_weight_fraction_leaf': 0,
    'min_samples_leaf': 1,
    'min_impurity_decrease': 0.1,
    'max_depth': 6
}

你可能感兴趣的:(Python,机器学习,机器学习,sklearn,人工智能)