The Python package of XGBoost offers two APIs: the native API and a wrapper designed for compatibility with sklearn.
The native API is more flexible, while the sklearn wrapper can plug into sklearn tooling such as GridSearchCV; each has its pros and cons. Below is a simple comparison using sklearn's built-in boston dataset:
# import packages
from sklearn import datasets
import pandas as pd
import xgboost as xgb
from sklearn.metrics import mean_squared_error
# use sklearn's built-in boston dataset (note: load_boston was removed in scikit-learn 1.2)
boston_data = datasets.load_boston()
df_boston = pd.DataFrame(boston_data.data,columns=boston_data.feature_names)
df_boston['target'] = pd.Series(boston_data.target)
xtrain = df_boston.head(500).drop(['target'],axis=1)
ytrain = df_boston.head(500).target
xtest = df_boston.tail(6).drop(['target'],axis=1)
ytest = df_boston.tail(6).target
params = {
'eta':0.1,
'max_depth':2,
'min_child_weight':3,
'gamma':0,
'subsample':.8,
'colsample_bytree':.7,
'reg_alpha':1,
'objective':'reg:linear'  # renamed to 'reg:squarederror' in newer xgboost versions
}
dtrain = xgb.DMatrix(xtrain,ytrain)
dtest = xgb.DMatrix(xtest,ytest)
watchlist1 = [(dtrain,'train'),(dtest,'test')]
model1 = xgb.train(params=params,dtrain=dtrain,num_boost_round=100,early_stopping_rounds=10,evals=watchlist1)
Training log:
[0] train-rmse:21.7093 test-rmse:17.3474
Multiple eval metrics have been passed: 'test-rmse' will be used for early stopping.
Will train until test-rmse hasn't improved in 10 rounds.
[1] train-rmse:19.6869 test-rmse:15.0228
[2] train-rmse:17.8982 test-rmse:13.2654
[3] train-rmse:16.2475 test-rmse:11.6048
[4] train-rmse:14.7695 test-rmse:10.0835
[5] train-rmse:13.5417 test-rmse:8.49356
[6] train-rmse:12.3675 test-rmse:7.26111
[7] train-rmse:11.3112 test-rmse:6.23273
[8] train-rmse:10.3767 test-rmse:5.15305
[9] train-rmse:9.54068 test-rmse:4.24451
[10] train-rmse:8.80674 test-rmse:3.72274
[11] train-rmse:8.10549 test-rmse:3.28673
[12] train-rmse:7.5088 test-rmse:3.17986
[13] train-rmse:6.98417 test-rmse:3.18837
[14] train-rmse:6.51555 test-rmse:3.28654
[15] train-rmse:6.10498 test-rmse:3.4315
[16] train-rmse:5.74395 test-rmse:3.51674
[17] train-rmse:5.41198 test-rmse:3.65761
[18] train-rmse:5.11477 test-rmse:3.84013
[19] train-rmse:4.86481 test-rmse:3.97942
[20] train-rmse:4.63408 test-rmse:4.15275
[21] train-rmse:4.44914 test-rmse:4.26283
[22] train-rmse:4.27886 test-rmse:4.467
Stopping. Best iteration:
[12] train-rmse:7.5088 test-rmse:3.17986
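The stopping behaviour in this log follows a simple rule: training halts once the monitored metric (test-rmse) has gone 10 rounds without improving on its best value, and the best round is reported. A minimal pure-Python sketch of that rule, run against the test-rmse values logged above:

```python
# test-rmse values from rounds [0]..[22] of the log above
test_rmse = [17.3474, 15.0228, 13.2654, 11.6048, 10.0835, 8.49356,
             7.26111, 6.23273, 5.15305, 4.24451, 3.72274, 3.28673,
             3.17986, 3.18837, 3.28654, 3.4315, 3.51674, 3.65761,
             3.84013, 3.97942, 4.15275, 4.26283, 4.467]

def early_stop(metric_history, patience=10):
    """Return (best_round, stop_round) under a
    no-improvement-for-`patience`-rounds rule (lower is better)."""
    best_round, best_value = 0, float("inf")
    for i, value in enumerate(metric_history):
        if value < best_value:
            best_round, best_value = i, value
        elif i - best_round >= patience:
            return best_round, i  # patience exhausted: stop here
    return best_round, len(metric_history) - 1

best, stop = early_stop(test_rmse)  # → (12, 22), matching the log above
```

This reproduces the log exactly: the best round is 12 (test-rmse 3.17986) and training stops at round 22, ten rounds later.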
result1 = model1.predict(xgb.DMatrix(xtest),ntree_limit=model1.best_iteration)  # caution: ntree_limit counts trees, so best_iteration omits the best round; best_ntree_limit (= best_iteration + 1) includes it
mse1 = mean_squared_error(ytest,result1)
The prediction result1 is:
array([15.062089, 18.013107, 17.307507, 24.263498, 20.47144 , 17.307507], dtype=float32)
and the corresponding mean squared error mse1 is 10.802571870358932.
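One detail worth flagging: mse1 is the square of the round-11 test-rmse (3.28673² ≈ 10.8026), not of the round-12 best (3.17986² ≈ 10.11). That is because ntree_limit counts trees, so passing best_iteration (12) uses only trees 0-11 and silently drops the best round; best_ntree_limit would include it. A quick numerical check:

```python
import math

# test-rmse values logged at rounds 11 and 12 above
rmse_round11, rmse_round12 = 3.28673, 3.17986
mse1 = 10.802571870358932  # MSE reported above

# mse1 matches round 11, not the best round 12 -- evidence that
# ntree_limit=best_iteration is off by one (best_ntree_limit fixes this)
assert math.isclose(rmse_round11 ** 2, mse1, rel_tol=1e-4)
assert not math.isclose(rmse_round12 ** 2, mse1, rel_tol=1e-2)
```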
from xgboost.sklearn import XGBRegressor
model2 = XGBRegressor(
learning_rate = 0.1,  # note: the correct keyword is learning_rate, not learn_rate
max_depth = 2,
min_child_weight = 3,
gamma = 0,
subsample = 0.8,
colsample_bytree = 0.7,
reg_alpha = 1,
objective = 'reg:linear',
n_estimators = 100
)
watchlist2 = [(xtrain,ytrain),(xtest,ytest)]
model2.fit(xtrain,ytrain,eval_set=watchlist2,early_stopping_rounds=10)
Training log (identical to the native API's):
[0] validation_0-rmse:21.7093 validation_1-rmse:17.3474
Multiple eval metrics have been passed: 'validation_1-rmse' will be used for early stopping.
Will train until validation_1-rmse hasn't improved in 10 rounds.
[1] validation_0-rmse:19.6869 validation_1-rmse:15.0228
[2] validation_0-rmse:17.8982 validation_1-rmse:13.2654
[3] validation_0-rmse:16.2475 validation_1-rmse:11.6048
[4] validation_0-rmse:14.7695 validation_1-rmse:10.0835
[5] validation_0-rmse:13.5417 validation_1-rmse:8.49356
[6] validation_0-rmse:12.3675 validation_1-rmse:7.26111
[7] validation_0-rmse:11.3112 validation_1-rmse:6.23273
[8] validation_0-rmse:10.3767 validation_1-rmse:5.15305
[9] validation_0-rmse:9.54068 validation_1-rmse:4.24451
[10] validation_0-rmse:8.80674 validation_1-rmse:3.72274
[11] validation_0-rmse:8.10549 validation_1-rmse:3.28673
[12] validation_0-rmse:7.5088 validation_1-rmse:3.17986
[13] validation_0-rmse:6.98417 validation_1-rmse:3.18837
[14] validation_0-rmse:6.51555 validation_1-rmse:3.28654
[15] validation_0-rmse:6.10498 validation_1-rmse:3.4315
[16] validation_0-rmse:5.74395 validation_1-rmse:3.51674
[17] validation_0-rmse:5.41198 validation_1-rmse:3.65761
[18] validation_0-rmse:5.11477 validation_1-rmse:3.84013
[19] validation_0-rmse:4.86481 validation_1-rmse:3.97942
[20] validation_0-rmse:4.63408 validation_1-rmse:4.15275
[21] validation_0-rmse:4.44914 validation_1-rmse:4.26283
[22] validation_0-rmse:4.27886 validation_1-rmse:4.467
Stopping. Best iteration:
[12] validation_0-rmse:7.5088 validation_1-rmse:3.17986
result2 = model2.predict(xtest,ntree_limit=model2.best_iteration)  # same caveat as above: best_ntree_limit (= best_iteration + 1) would include the best round
mse2 = mean_squared_error(ytest,result2)
The prediction result2 is:
array([15.062089, 18.013107, 17.307507, 24.263498, 20.47144 , 17.307507], dtype=float32)
and the corresponding mean squared error mse2 is 10.802571870358932.
Comparing the predictions: the native xgb API and the sklearn interface go through identical training runs and produce identical results.
The differences are:
1. The native API trains with xgb.train(), while the sklearn interface trains with model.fit().
2. The sklearn parameter n_estimators corresponds to the num_boost_round argument of xgb.train() in the native API (likewise, the native eta is exposed as learning_rate in the sklearn wrapper).
3. The sklearn watchlist has the form [(xtrain,ytrain),(xtest,ytest)], whereas the native one is [(dtrain,'train'),(dtest,'test')]: since the data and labels are already packed into the DMatrix, the string in each tuple just names that dataset in the log output.