模型调整, 假设已经找到了一些潜在的模型,下面是几种方法用于模型调整
class sklearn.model_selection.GridSearchCV(estimator, param_grid, scoring=None, fit_params=None, n_jobs=None, iid=’warn’, refit=True, cv=’warn’, verbose=0, pre_dispatch=‘2*n_jobs’, error_score=’raise-deprecating’, return_train_score=’warn’)
属性:
best_params
cv_results_
from sklearn.model_selection import GridSearchCV
param_grid = [
{'n_estimators': [3, 10, 30], 'max_features': [2, 4, 6, 8]},
{'bootstrap': [False], 'n_estimators': [3, 10], 'max_features': [2, 3, 4]},
]
forest_reg = RandomForestRegressor()
grid_search = GridSearchCV(forest_reg, param_grid, cv=5,
scoring='neg_mean_squared_error')
grid_search.fit(housing_prepared, housing_labels)
sklearn 根据param_grid的值,首先会评估 3 × 4 = 12 3 \times 4 = 12 3×4=12种n_estimators和max_features的组合方式,接下来在会在bootstrap=False的情况下(默认该值为True),评估 2 × 3 = 6 2 \times3 =6 2×3=6种12种n_estimators和max_features的组合方式,所以最终会有 12 + 6 = 18 12+6=18 12+6=18种不同的超参数组合方式,
而每一种组合方式要在训练集上训练5次, 所以一共要训练 18 × 5 = 90 18 \times 5 = 90 18×5=90次,当训练结束后,你可以通过best_params_获得最好的组合方式
grid_search.best_params_
out:
{‘max_features’: 8, ‘n_estimators’: 30}
得到最好的模型
grid_search.best_estimator_
out:
RandomForestRegressor(bootstrap=True, criterion=‘mse’, max_depth=None,
max_features=8, max_leaf_nodes=None, min_impurity_decrease=0.0,
min_impurity_split=None, min_samples_leaf=1,
min_samples_split=2, min_weight_fraction_leaf=0.0,
n_estimators=30, n_jobs=1, oob_score=False, random_state=None,
verbose=0, warm_start=False)
如果GridSearchCV初始化时,refit=True(默认的初始化值),在交叉验证时,一旦发现最好的模型(estimator),将会在整个训练集上重新训练,这通常是一个好主意,因为使用更多的数据集会提升模型的性能。
cv_results_:将结果存在一个字典里, 可以转化为DataFrame类型,每一行为一种超参数组合方式。
cv = pd.DataFrame(grid_search.cv_results_
cv
# out:
mean_fit_time std_fit_time mean_score_time std_score_time param_max_features param_n_estimators param_bootstrap params split0_test_score split1_test_score ... mean_test_score std_test_score rank_test_score split0_train_score split1_train_score split2_train_score split3_train_score split4_train_score mean_train_score std_train_score
0 0.085940 0.025568 0.004391 0.001827 2 3 NaN {'max_features': 2, 'n_estimators': 3} -3.827812e+09 -4.092971e+09 ... -4.139394e+09 1.959892e+08 18 -1.108742e+09 -1.076285e+09 -1.151262e+09 -1.127172e+09 -1.114365e+09 -1.115565e+09 2.449443e+07
1 0.237965 0.021394 0.010079 0.000971 2 10 NaN {'max_features': 2, 'n_estimators': 10} -2.742609e+09 -3.333789e+09 ... -3.113304e+09 2.395274e+08 11 -5.701092e+08 -6.094081e+08 -5.785905e+08 -5.931616e+08 -5.756676e+08 -5.853874e+08 1.422342e+07
2 0.767676 0.092776 0.029941 0.004852 2 30 NaN {'max_features': 2, 'n_estimators': 30} -2.715244e+09 -2.902911e+09 ... -2.802893e+09 1.542347e+08 8 -4.472685e+08 -4.271909e+08 -4.250249e+08 -4.289724e+08 -4.274656e+08 -4.311845e+08 8.140110e+06
3 0.107399 0.002467 0.003590 0.000482 4 3 NaN {'max_features': 4, 'n_estimators': 3} -3.698121e+09 -3.903447e+09 ... -3.666869e+09 2.123558e+08 15 -9.640156e+08 -9.257596e+08 -1.003782e+09 -9.539664e+08 -8.817442e+08 -9.458535e+08 4.065737e+07
4 0.348854 0.003953 0.009683 0.000399 4 10 NaN {'max_features': 4, 'n_estimators': 10} -2.721300e+09 -2.868056e+09 ... -2.788953e+09 1.090418e+08 7 -5.299723e+08 -4.985359e+08 -4.892567e+08 -5.204567e+08 -5.260378e+08 -5.128519e+08 1.604050e+07
5 1.058535 0.007244 0.027150 0.000750 4 30 NaN {'max_features': 4, 'n_estimators': 30} -2.412966e+09 -2.617706e+09 ... -2.575291e+09 1.220214e+08 3 -4.002692e+08 -3.978664e+08 -3.819384e+08 -4.050417e+08 -3.929495e+08 -3.956131e+08 7.870790e+06
6 0.157606 0.014043 0.003394 0.000200 6 3 NaN {'max_features': 6, 'n_estimators': 3} -3.292618e+09 -3.698856e+09 ... -3.535907e+09 1.405519e+08 14 -8.680883e+08 -9.112255e+08 -9.112524e+08 -1.004211e+09 -9.125849e+08 -9.214724e+08 4.468549e+07
7 0.491094 0.019871 0.009585 0.000372 6 10 NaN {'max_features': 6, 'n_estimators': 10} -2.553328e+09 -2.791239e+09 ... -2.750653e+09 1.162331e+08 6 -4.944491e+08 -4.779433e+08 -4.857806e+08 -5.040733e+08 -5.053104e+08 -4.935113e+08 1.052407e+07
8 1.446224 0.011595 0.027342 0.000976 6 30 NaN {'max_features': 6, 'n_estimators': 30} -2.358673e+09 -2.551963e+09 ... -2.516006e+09 1.406579e+08 2 -3.919785e+08 -3.910130e+08 -3.912702e+08 -3.954620e+08 -3.815968e+08 -3.902641e+08 4.618501e+06
9 0.184060 0.003479 0.003398 0.000202 8 3 NaN {'max_features': 8, 'n_estimators': 3} -3.214309e+09 -3.749000e+09 ... -3.492941e+09 2.259341e+08 13 -8.754845e+08 -9.164619e+08 -9.071570e+08 -9.245905e+08 -8.557108e+08 -8.958809e+08 2.609451e+07
10 0.613070 0.003497 0.009779 0.000676 8 10 NaN {'max_features': 8, 'n_estimators': 10} -2.524796e+09 -2.749792e+09 ... -2.695701e+09 1.184566e+08 4 -4.814605e+08 -4.924878e+08 -4.953451e+08 -5.081368e+08 -5.055688e+08 -4.965998e+08 9.604441e+06
11 1.843982 0.014717 0.026444 0.000943 8 30 NaN {'max_features': 8, 'n_estimators': 30} -2.375197e+09 -2.517777e+09 ... -2.506430e+09 1.330733e+08 1 -3.884794e+08 -3.844676e+08 -3.707626e+08 -3.920550e+08 -3.907572e+08 -3.853043e+08 7.714260e+06
12 0.105807 0.003103 0.003793 0.000399 2 3 False {'bootstrap': False, 'max_features': 2, 'n_est... -3.771914e+09 -3.740538e+09 ... -3.750444e+09 9.489691e+07 17 -0.000000e+00 -0.000000e+00 -0.000000e+00 -0.000000e+00 -0.000000e+00 0.000000e+00 0.000000e+00
13 0.347154 0.003689 0.010282 0.000509 2 10 False {'bootstrap': False, 'max_features': 2, 'n_est... -2.674186e+09 -2.990817e+09 ... -2.920952e+09 1.338905e+08 10 -4.739042e+02 -1.947960e+03 -1.514005e+02 -0.000000e+00 -7.418622e-01 -5.148012e+02 7.371504e+02
14 0.137443 0.003622 0.003394 0.000200 3 3 False {'bootstrap': False, 'max_features': 3, 'n_est... -3.262696e+09 -3.673369e+09 ... -3.667646e+09 3.201594e+08 16 -0.000000e+00 -0.000000e+00 -0.000000e+00 -0.000000e+00 -0.000000e+00 0.000000e+00 0.000000e+00
15 0.458953 0.010759 0.010680 0.000400 3 10 False {'bootstrap': False, 'max_features': 3, 'n_est... -2.672303e+09 -2.797555e+09 ... -2.819182e+09 1.591897e+08 9 -0.000000e+00 -0.000000e+00 -0.000000e+00 -0.000000e+00 -0.000000e+00 0.000000e+00 0.000000e+00
16 0.170678 0.001511 0.004197 0.000249 4 3 False {'bootstrap': False, 'max_features': 4, 'n_est... -3.360936e+09 -3.505000e+09 ... -3.424422e+09 1.686945e+08 12 -0.000000e+00 -0.000000e+00 -0.000000e+00 -0.000000e+00 -0.000000e+00 0.000000e+00 0.000000e+00
17 0.566040 0.004567 0.011083 0.000581 4 10 False {'bootstrap': False, 'max_features': 4, 'n_est... -2.507133e+09 -2.778602e+09 ... -2.701759e+09 1.294327e+08 5 -0.000000e+00 -0.000000e+00 -0.000000e+00 -0.000000e+00 -0.000000e+00 0.000000e+00 0.000000e+00
18 rows × 23 columns
可以看出,一共有18行,代表18中参数的组合方式。
将超参数参数的组合方式和想应的误差值
for mean_score, params in zip(cv["mean_test_score"], cv["params"]):
print(np.sqrt(-mean_score), params)
# out:
64338.12255034114 {'max_features': 2, 'n_estimators': 3}
55796.98682826028 {'max_features': 2, 'n_estimators': 10}
52942.353620944974 {'max_features': 2, 'n_estimators': 30}
60554.68102018701 {'max_features': 4, 'n_estimators': 3}
52810.54319153441 {'max_features': 4, 'n_estimators': 10}
50747.320348497095 {'max_features': 4, 'n_estimators': 30}
59463.49112391854 {'max_features': 6, 'n_estimators': 3}
52446.67106233131 {'max_features': 6, 'n_estimators': 10}
50159.80848511047 {'max_features': 6, 'n_estimators': 30}
59101.107722861816 {'max_features': 8, 'n_estimators': 3}
51920.13895373909 {'max_features': 8, 'n_estimators': 10}
50064.25380671872 {'max_features': 8, 'n_estimators': 30}
61240.86657596463 {'bootstrap': False, 'max_features': 2, 'n_estimators': 3}
54045.83360004432 {'bootstrap': False, 'max_features': 2, 'n_estimators': 10}
60561.09106204302 {'bootstrap': False, 'max_features': 3, 'n_estimators': 3}
53095.965054469176 {'bootstrap': False, 'max_features': 3, 'n_estimators': 10}
58518.56200119758 {'bootstrap': False, 'max_features': 4, 'n_estimators': 3}
51978.44739222989 {'bootstrap': False, 'max_features': 4, 'n_estimators': 10}
最好的是’max_features’: 8, ‘n_estimators’: 30, 误差为50064,结果要好于默认的超参数,误差为52634
http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html