使用Xgboost自带的读取格式DMatrix()

参考:https://cloud.tencent.com/developer/article/1466793

可以接受的格式:

·LibSVM文本格式文件
·逗号分隔值(CSV)文件
·NumPy 2D阵列
·SciPy 2D稀疏阵列
·DataFrame数据框
·XGBoost二进制缓冲区文件

需要注意的是:XGBoost不支持分类功能; 如果您的数据包含分类功能,请先将其加载为NumPy阵列,然后执行onehot编码。
XGBoost无法解析带有标头的CSV文件。

参数设定

XGBoost可以使用列表或字典来设置参数,如下所示:

param = {
     'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
param['nthread'] = 4
param['eval_metric'] = 'auc'
param['eval_metric'] = ['auc', 'ams@0']

训练

bst = xgb.train(param, dtrain, num_round, evallist)

预测

dtest = xgb.DMatrix(data)
ypred = bst.predict(dtest)

你可能感兴趣的:(XGboost)