Comparing the native XGBoost API with its sklearn interface

  • 1 Prepare the data
  • 2 Native xgboost
  • 3 The xgboost sklearn interface

The xgboost Python package provides both a native API and an sklearn-compatible interface.
The native API is more flexible, while the sklearn interface plugs into sklearn utilities such as GridSearchCV, so each has its advantages. Below is a simple comparison using sklearn's built-in Boston housing dataset:

1 Prepare the data

# imports
from sklearn import datasets
import pandas as pd
import xgboost as xgb
from sklearn.metrics import mean_squared_error

# use sklearn's built-in Boston dataset
# (note: load_boston was removed in scikit-learn 1.2, so this requires an older version)
boston_data = datasets.load_boston()
df_boston = pd.DataFrame(boston_data.data, columns=boston_data.feature_names)
df_boston['target'] = pd.Series(boston_data.target)

# first 500 rows for training, last 6 for testing
xtrain = df_boston.head(500).drop(['target'], axis=1)
ytrain = df_boston.head(500).target

xtest = df_boston.tail(6).drop(['target'], axis=1)
ytest = df_boston.tail(6).target

2 Native xgboost

params = {
    'eta': 0.1,                # learning rate
    'max_depth': 2,
    'min_child_weight': 3,
    'gamma': 0,
    'subsample': 0.8,
    'colsample_bytree': 0.7,
    'reg_alpha': 1,
    'objective': 'reg:linear'  # deprecated in newer xgboost; use 'reg:squarederror'
}
dtrain = xgb.DMatrix(xtrain, ytrain)
dtest = xgb.DMatrix(xtest, ytest)
watchlist1 = [(dtrain, 'train'), (dtest, 'test')]

model1 = xgb.train(params=params, dtrain=dtrain, num_boost_round=100,
                   early_stopping_rounds=10, evals=watchlist1)

Training log:

[0] train-rmse:21.7093  test-rmse:17.3474
Multiple eval metrics have been passed: 'test-rmse' will be used for early stopping.

Will train until test-rmse hasn't improved in 10 rounds.
[1] train-rmse:19.6869  test-rmse:15.0228
[2] train-rmse:17.8982  test-rmse:13.2654
[3] train-rmse:16.2475  test-rmse:11.6048
[4] train-rmse:14.7695  test-rmse:10.0835
[5] train-rmse:13.5417  test-rmse:8.49356
[6] train-rmse:12.3675  test-rmse:7.26111
[7] train-rmse:11.3112  test-rmse:6.23273
[8] train-rmse:10.3767  test-rmse:5.15305
[9] train-rmse:9.54068  test-rmse:4.24451
[10]    train-rmse:8.80674  test-rmse:3.72274
[11]    train-rmse:8.10549  test-rmse:3.28673
[12]    train-rmse:7.5088   test-rmse:3.17986
[13]    train-rmse:6.98417  test-rmse:3.18837
[14]    train-rmse:6.51555  test-rmse:3.28654
[15]    train-rmse:6.10498  test-rmse:3.4315
[16]    train-rmse:5.74395  test-rmse:3.51674
[17]    train-rmse:5.41198  test-rmse:3.65761
[18]    train-rmse:5.11477  test-rmse:3.84013
[19]    train-rmse:4.86481  test-rmse:3.97942
[20]    train-rmse:4.63408  test-rmse:4.15275
[21]    train-rmse:4.44914  test-rmse:4.26283
[22]    train-rmse:4.27886  test-rmse:4.467
Stopping. Best iteration:
[12]    train-rmse:7.5088   test-rmse:3.17986
# use best_ntree_limit here: passing best_iteration alone would drop the best round's tree
result1 = model1.predict(xgb.DMatrix(xtest), ntree_limit=model1.best_ntree_limit)
mse1 = mean_squared_error(ytest, result1)

The predictions result1 are: array([15.062089, 18.013107, 17.307507, 24.263498, 20.47144 , 17.307507], dtype=float32)

The corresponding mean squared error mse1 is 10.802571870358932.

3 The xgboost sklearn interface

from xgboost.sklearn import XGBRegressor

model2 = XGBRegressor(
    learning_rate=0.1,         # 'eta' in the native API (note: 'learn_rate' is not a valid name)
    max_depth=2,
    min_child_weight=3,
    gamma=0,
    subsample=0.8,
    colsample_bytree=0.7,
    reg_alpha=1,
    objective='reg:linear',    # deprecated in newer xgboost; use 'reg:squarederror'
    n_estimators=100           # 'num_boost_round' in the native API
)
watchlist2 = [(xtrain, ytrain), (xtest, ytest)]
model2.fit(xtrain, ytrain, eval_set=watchlist2, early_stopping_rounds=10)

Training log (identical to the native interface's):

[0] validation_0-rmse:21.7093   validation_1-rmse:17.3474
Multiple eval metrics have been passed: 'validation_1-rmse' will be used for early stopping.

Will train until validation_1-rmse hasn't improved in 10 rounds.
[1] validation_0-rmse:19.6869   validation_1-rmse:15.0228
[2] validation_0-rmse:17.8982   validation_1-rmse:13.2654
[3] validation_0-rmse:16.2475   validation_1-rmse:11.6048
[4] validation_0-rmse:14.7695   validation_1-rmse:10.0835
[5] validation_0-rmse:13.5417   validation_1-rmse:8.49356
[6] validation_0-rmse:12.3675   validation_1-rmse:7.26111
[7] validation_0-rmse:11.3112   validation_1-rmse:6.23273
[8] validation_0-rmse:10.3767   validation_1-rmse:5.15305
[9] validation_0-rmse:9.54068   validation_1-rmse:4.24451
[10]    validation_0-rmse:8.80674   validation_1-rmse:3.72274
[11]    validation_0-rmse:8.10549   validation_1-rmse:3.28673
[12]    validation_0-rmse:7.5088    validation_1-rmse:3.17986
[13]    validation_0-rmse:6.98417   validation_1-rmse:3.18837
[14]    validation_0-rmse:6.51555   validation_1-rmse:3.28654
[15]    validation_0-rmse:6.10498   validation_1-rmse:3.4315
[16]    validation_0-rmse:5.74395   validation_1-rmse:3.51674
[17]    validation_0-rmse:5.41198   validation_1-rmse:3.65761
[18]    validation_0-rmse:5.11477   validation_1-rmse:3.84013
[19]    validation_0-rmse:4.86481   validation_1-rmse:3.97942
[20]    validation_0-rmse:4.63408   validation_1-rmse:4.15275
[21]    validation_0-rmse:4.44914   validation_1-rmse:4.26283
[22]    validation_0-rmse:4.27886   validation_1-rmse:4.467
Stopping. Best iteration:
[12]    validation_0-rmse:7.5088    validation_1-rmse:3.17986
# as above, best_ntree_limit keeps the best round's tree in the prediction
result2 = model2.predict(xtest, ntree_limit=model2.best_ntree_limit)
mse2 = mean_squared_error(ytest, result2)

The predictions result2 are: array([15.062089, 18.013107, 17.307507, 24.263498, 20.47144 , 17.307507], dtype=float32)

The corresponding mean squared error mse2 is 10.802571870358932.

Comparing the results: the native xgboost API and the sklearn interface produce identical training logs and identical predictions.
The differences are:
1. The native API trains with xgb.train(), while the sklearn interface uses model.fit().
2. The sklearn parameter n_estimators corresponds to the num_boost_round argument of xgb.train() in the native API.
3. The sklearn watchlist has the form [(xtrain,ytrain),(xtest,ytest)], while the native one is [(dtrain,'train'),(dtest,'test')]: the data and labels are packed into the DMatrix, and the string in each tuple names that set in the training log.
