# LightGBM建模，sklearn评估 / LightGBM modeling with sklearn evaluation
import lightgbm as lgb
import pandas as pd
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import GridSearchCV
print('加载数据...')
# Both files are tab-separated with no header row, so pandas labels the
# columns 0..k. Column 0 is the regression target; the rest are features.
df_train = pd.read_csv('./data/regression.train.txt', sep='\t', header=None)
df_test = pd.read_csv('./data/regression.test.txt', sep='\t', header=None)
y_train, X_train = df_train[0].values, df_train.drop(columns=0).values
y_test, X_test = df_test[0].values, df_test.drop(columns=0).values
print('开始训练...')
# scikit-learn style LightGBM regressor with fixed hyper-parameters.
gbm = lgb.LGBMRegressor(objective='regression',
                        num_leaves=31,
                        learning_rate=0.05,
                        n_estimators=20)
# Early stopping is configured through a callback: the
# `early_stopping_rounds=` keyword of fit() was removed in LightGBM 4.0.
# Training stops once the eval-set metrics have not improved for 5 rounds.
# NOTE(review): the test set drives early stopping AND the RMSE printed
# below, so that RMSE is optimistically biased; a separate validation split
# would give an honest estimate.
gbm.fit(X_train, y_train,
        eval_set=[(X_test, y_test)],
        eval_metric='l1',
        callbacks=[lgb.early_stopping(stopping_rounds=5)])
print('开始预测...')
# Predict with the iteration selected by early stopping.
y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration_)
print('预测结果的rmse是:')
# RMSE = sqrt(MSE).
print(mean_squared_error(y_test, y_pred) ** 0.5)
# 网格搜索查找最优超参数 / Grid search for the best hyper-parameters
# Cross-validate every (learning_rate, n_estimators) combination on the
# training data, using GridSearchCV's default splitter and scorer.
estimator = lgb.LGBMRegressor(num_leaves=31)
param_grid = dict(
    learning_rate=[0.01, 0.1, 1],
    n_estimators=[20, 40],
)
# fit() returns the search object itself, so the call can be chained.
gbm = GridSearchCV(estimator=estimator, param_grid=param_grid).fit(X_train, y_train)
print('用网格搜索找到的最优超参数为:')
print(gbm.best_params_)
# 绘图解释 / Plotting for model interpretation
import lightgbm as lgb
import pandas as pd
# matplotlib is required for the lgb.plot_* calls and plt.show() in this
# section; fail early with an actionable message if it is missing.
try:
    import matplotlib.pyplot as plt
except ImportError as err:
    # Chain the original error so the underlying cause stays visible.
    raise ImportError('You need to install matplotlib for plotting.') from err
print('加载数据...')
# Tab-separated files without a header row: the first column (position 0)
# holds the regression target and every following column is a feature.
df_train = pd.read_csv('./data/regression.train.txt', header=None, sep='\t')
df_test = pd.read_csv('./data/regression.test.txt', header=None, sep='\t')
y_train, X_train = df_train.iloc[:, 0].values, df_train.iloc[:, 1:].values
y_test, X_test = df_test.iloc[:, 0].values, df_test.iloc[:, 1:].values
# Feature names ("f1", "f2", ...) and the categorical column (index 21) are
# declared on the training Dataset. The name count follows the data instead
# of a hard-coded 28.
lgb_train = lgb.Dataset(X_train, y_train,
                        feature_name=['f' + str(i + 1) for i in range(X_train.shape[1])],
                        categorical_feature=[21])
# Validation data uses the training Dataset as its reference, as LightGBM
# requires for validation sets.
lgb_test = lgb.Dataset(X_test, y_test, reference=lgb_train)
params = {
    'num_leaves': 5,
    'metric': ('l1', 'l2'),  # both metrics are evaluated on every valid set
    'verbose': 0
}
# Filled in place by the record_evaluation callback with per-round scores,
# keyed by validation-set name and then metric name; read by plot_metric later.
evals_result = {}
print('开始训练...')
# `evals_result=` and `verbose_eval=` were removed from lgb.train() in
# LightGBM 4.0; these callbacks are the replacement (available since 3.3).
gbm = lgb.train(params,
                lgb_train,
                num_boost_round=100,
                valid_sets=[lgb_train, lgb_test],
                callbacks=[lgb.log_evaluation(period=10),
                           lgb.record_evaluation(evals_result)])
print('在训练过程中绘图...')
# Draw the per-round l1 curves stored in `evals_result` during training.
ax = lgb.plot_metric(
    evals_result,
    metric='l1',
)
plt.show()
print('画出特征重要度...')
# Bar chart limited to the 10 highest-ranked features, using
# plot_importance's default importance type.
ax = lgb.plot_importance(
    gbm,
    max_num_features=10,
)
plt.show()
# Fix: trees take the measure word 棵, not 颗, in the progress message.
print('画出第84棵树...')
# tree_index is 0-based, so 83 is the 84th tree; the booster above trains
# 100 rounds, so it exists. plot_tree additionally needs the graphviz package.
ax = lgb.plot_tree(gbm, tree_index=83, figsize=(20, 8), show_info=['split_gain'])
plt.show()