机器学习调参神器--网格搜索

​超参数是模型中的参数中不能通过学习得到的参数。在scikit-learn中,典型的例子有支持向量分类器的参数C,kernel和gamma,Lasso的参数alpha等。在超参数集中搜索以获得最佳交叉验证分数的方法是可实现并且推荐的,网格搜索GridSearchCV应运而生!

实例

以支持向量机模型为例,训练鸢尾花数据集,搜索最优参数组合C和gamma。

from sklearn.svm import SVC
from sklearn import datasets
from sklearn.model_selection import GridSearchCV,train_test_split

# 导入鸢尾花数据
iris = datasets.load_iris()

# 设置待搜索参数及其参数值
param_grid = {'C':[0.001,0.01,0.1,1,10,100],
              'gamma':[0.001,0.01,0.1,1,10,100]}

# 建立网格搜索模型
grid = GridSearchCV(SVC(),param_grid,cv=5,return_train_score=True)

# 划分数据集
X_train,X_test,y_train,y_test = train_test_split(iris.data,iris.target,random_state=0)

# 拟合训练集
grid.fit(X_train,y_train)

# 最优模型在测试集的得分
grid.score(X_test,y_test)

# 最优参数模型
grid.best_estimator_

# 最优参数组合
grid.best_params_

分析

上述的参数C和gamma分别有6个不同的取值,所以有36种不同的参数组合,利用GridSearchCV分别对训练集进行交叉验证评估,筛选出最优的参数组合,并且自动生成效果最好的最优模型,此乃神器也~~~~

哈哈哈哈哈哈哈哈哈哈哈

你可能感兴趣的:(XGBoost入门与实践,机器学习,sklearn,人工智能,python,大数据)