最近看机器学习的教学视频,老师反复提到了一个函数GridSearchCV()。举个例子,在python中用一个模型的时候,可能会涉及一些需要人为指定的参数,比如随机森林模型需要指定min_samples_split=?、n_estimators=?,在我们缺乏先验知识的时候,我们是不知道取什么样的值才是合适的,这个时候GridSearchCV()函数就派上了用场。
#简单的例子来看看GridSearchCV函数的用处
from sklearn.grid_search import GridSearchCV
tree_param_grid = { 'min_samples_split': list((3,6,9)),'n_estimators':list((10,50,100))}
grid = GridSearchCV(RandomForestRegressor(),param_grid=tree_param_grid, cv=5)
grid.fit(data_train, target_train)
grid.grid_scores_, grid.best_params_, grid.best_score_
输出结果:
([mean: 0.78656, std: 0.00429, params: {'min_samples_split': 3, 'n_estimators': 10}, mean: 0.80391, std: 0.00372, params: {'min_samples_split': 3, 'n_estimators': 50}, mean: 0.80843, std: 0.00348, params: {'min_samples_split': 3, 'n_estimators': 100}, mean: 0.78668, std: 0.00335, params: {'min_samples_split': 6, 'n_estimators': 10}, mean: 0.80592, std: 0.00324, params: {'min_samples_split': 6, 'n_estimators': 50}, mean: 0.80724, std: 0.00401, params: {'min_samples_split': 6, 'n_estimators': 100}, mean: 0.79123, std: 0.00194, params: {'min_samples_split': 9, 'n_estimators': 10}, mean: 0.80344, std: 0.00553, params: {'min_samples_split': 9, 'n_estimators': 50}, mean: 0.80576, std: 0.00450, params: {'min_samples_split': 9, 'n_estimators': 100}], {'min_samples_split': 3, 'n_estimators': 100}, 0.8084252204574226)
可以看到,首先我们是在sklearn的库中导入GridSearchCV()函数,然后生成一个字典供后续的调用,这个字典包括各种参数名、候选参数值列表。然后就可以用这个函数啦,在这个函数中传入你想拟合的模型,以及前面得到的字典,cv指定的是交叉验证次数,为什么要进行交叉验证,这是因为交叉验证的结果就是评判选择最好的参数的理由,这里涉及到分数,得分越高,系统会偏向选择这组参数,上面小例子中,系统就是选择'min_samples_split': 3, 'n_estimators': 100,因为得分最高,是0.808425
这里附上这个函数的官方文档,里面的介绍和用法非常全面,还有不同的例子,能够全面直观的学习到这个函数:http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html#sklearn.model_selection.GridSearchCV
至于交叉验证这个内容,本文章就不细说了,明白GridSearchCV()函数的用法是本文章的目的。