Notes on LGB + K-fold Cross Validation Usage

Whenever I come across good code, I write it down for easy reuse later.

# Import required packages
import numpy as np
import pandas as pd
import lightgbm as lgb
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold

# Assume train_data has already been prepared as a pandas DataFrame containing the feature
# columns plus a 'score' label column; test_data is assumed to hold the same feature columns (no label)
train_label = train_data['score']
train_data = train_data.drop(columns=['score'])  # keep only the feature columns for training

# Initialize a k-fold generator
# Note: StratifiedKFold needs a discrete label; if 'score' is continuous, switch to KFold
NFOLDS = 5
kfold = StratifiedKFold(n_splits=NFOLDS, shuffle=True, random_state=2019)
kf = kfold.split(train_data, train_label)
cv_pred = np.zeros(test_data.shape[0])   # averaged test-set prediction across folds
valid_best_l1_all = 0                    # accumulated best validation MAE (l1) across folds
feature_importance_df = pd.DataFrame()   # per-fold feature importances
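
The training loop below uses a params dict that the original snippet never defines. A minimal sketch, assuming a regression task evaluated with MAE (LightGBM reports it as 'l1', which matches the bst.best_score['valid_0']['l1'] lookup in the loop); all hyperparameter values are illustrative placeholders:

# Hypothetical hyperparameters -- tune for your own task
params = {
    'objective': 'regression_l1',  # MAE objective (assumed)
    'metric': 'mae',               # reported as 'l1' in best_score
    'learning_rate': 0.05,
    'num_leaves': 31,
    'feature_fraction': 0.8,
    'bagging_fraction': 0.8,
    'bagging_freq': 1,
    'seed': 2019,
    'verbose': -1,
}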

# Run the training: one LightGBM model per fold
for i, (train_fold, validate) in enumerate(kf):
    X_train, X_validate = train_data.iloc[train_fold, :], train_data.iloc[validate, :]
    label_train, label_validate = train_label.iloc[train_fold], train_label.iloc[validate]

    dtrain = lgb.Dataset(X_train, label_train)
    dvalid = lgb.Dataset(X_validate, label_validate, reference=dtrain)

    # LightGBM 2.x/3.x API; in 4.x pass callbacks=[lgb.early_stopping(50), lgb.log_evaluation(0)] instead
    bst = lgb.train(params, dtrain, num_boost_round=10000, valid_sets=[dvalid],
                    verbose_eval=-1, early_stopping_rounds=50)

    # accumulate this fold's test-set prediction at the best iteration
    cv_pred += bst.predict(test_data, num_iteration=bst.best_iteration)
    valid_best_l1_all += bst.best_score['valid_0']['l1']

    # record this fold's gain-based feature importance
    fold_importance_df = pd.DataFrame()
    fold_importance_df["feature"] = list(X_train.columns)
    fold_importance_df["importance"] = bst.feature_importance(importance_type='gain', iteration=bst.best_iteration)
    fold_importance_df["fold"] = i + 1
    feature_importance_df = pd.concat([feature_importance_df, fold_importance_df], axis=0)

# Average the test predictions and validation MAE over the folds
cv_pred /= NFOLDS
valid_best_l1_all /= NFOLDS
# convert the averaged MAE into a 1 / (1 + MAE) score
print('cv score for valid is: ', 1 / (1 + valid_best_l1_all))

def display_importances(feature_importance_df_):
    # keep the 40 features with the highest mean (over folds) gain importance
    cols = (feature_importance_df_[["feature", "importance"]]
            .groupby("feature").mean()
            .sort_values(by="importance", ascending=False)[:40].index)
    best_features = feature_importance_df_.loc[feature_importance_df_.feature.isin(cols)]
    plt.figure(figsize=(8, 10))
    sns.barplot(x="importance", y="feature",
                data=best_features.sort_values(by="importance", ascending=False))
    plt.title('LightGBM Features (avg over folds)')
    plt.tight_layout()
    plt.show()
    
display_importances(feature_importance_df)
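
If you only need the numbers rather than the plot, the per-fold importances collected above can be averaged directly with pandas; a small sketch:

# Mean gain importance over the folds, strongest features first
mean_importance = (feature_importance_df
                   .groupby("feature")["importance"]
                   .mean()
                   .sort_values(ascending=False))
print(mean_importance.head(20))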
