XGBoost模型保存与读取(多分类问题)

 在用XGBClassifier做多分类问题模型存取时,采用save_model与load_model函数发现并不是很好用,因此通过pickle进行模型的存取工作,在此记录,以备后用。

import pickle
from xgboost import XGBClassifier

#train

model_xg = XGBClassifier(
        n_estimators=20,
        learning_rate=0.1,
        max_depth=8,
        subsample=0.8,
        early_stopping_rounds = 50,
        objective='multi:softmax',
        eval_metric = 'mlogloss')
model_xg.fit(x_train, y_train,verbose=True)

# save
pickle.dump(model_xg, open("xgb.pkl", "wb"))

# load
xgb_model_loaded = pickle.load(open("xgb.pkl", "rb"))

# test
xgb_model_loaded.predict(test)

 

你可能感兴趣的:(机器学习,XGBoost)