LightGBM 的 train 函数提供了 callbacks 接口;如果想记录训练过程中损失的下降情况,就需要用到这个接口。本文以记录训练迭代过程中的损失为出发点,简单介绍 LightGBM 中 callback 的使用方法。
(function) train: (params:
… callbacks: List[(…) -> Any] | None = None) -> Booster
callbacks : list of callable, or None, optional (default=None)
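入参是一个 list,list 中的每个元素都是一个回调(callback)对象。官方文档中提供的 callback 主要有以下四种:
lightgbm.early_stopping(stopping_rounds, first_metric_only=False, verbose=True, min_delta=0.0)
lightgbm.log_evaluation(period=1, show_stdv=True)
lightgbm.record_evaluation(eval_result)
lightgbm.reset_parameter(**kwargs)
注意:这些函数都是"工厂函数",需要先调用(例如 lgb.log_evaluation(period=20))得到回调对象,再放入 callbacks 列表;直接传入 lgb.log_evaluation 本身不会输出任何日志。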
其中 eval_result 可以是一个空字典(eval_result = {}),训练过程中 record_evaluation 会把每一轮的评估结果写入其中,结构为 {数据集名: {指标名: [各轮取值]}}。
eval_result = {}  # filled in place by record_evaluation during training
lgb_model = lgb.train(
    lgb_param,
    train_set=tr_lgb_dt,
    valid_sets=[tr_lgb_dt, te_lgb_dt],
    callbacks=[
        # log_evaluation is a factory and must be called to obtain the callback;
        # passing the bare function logs nothing. period=20 also replaces the
        # verbose_eval=20 argument, which LightGBM 4.0 removed from train().
        lgb.log_evaluation(period=20),
        # Stop when the first metric in lgb_param['metric'] has not improved
        # on the validation set(s) for 50 rounds.
        lgb.early_stopping(50, first_metric_only=True),
        lgb.record_evaluation(eval_result),
    ],
)
import warnings

import lightgbm as lgb
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

plt.style.use('ggplot')
# NOTE(review): this blanket filter may also silence library deprecation
# warnings (e.g. about train() arguments); consider narrowing it.
warnings.filterwarnings('ignore')

# Load iris into a DataFrame; feature names such as 'sepal length (cm)'
# become snake_case column names such as 'sepal_length'.
iris = load_iris()
df = pd.DataFrame(
    data=iris.data,
    columns=[name.replace(' (cm)', '').replace(' ', '_') for name in iris.feature_names],
)
df['target'] = iris.target

# Fixed seed so the split (and therefore the plotted loss curves) is
# reproducible; stratify keeps the class proportions of the full data in
# both the train and the test split.
tr_x, te_x, tr_y, te_y = train_test_split(
    df.drop(columns='target'),
    df['target'],
    test_size=0.2,
    random_state=42,
    stratify=df['target'],
)
tr_lgb_dt = lgb.Dataset(tr_x, label=tr_y.values)
te_lgb_dt = lgb.Dataset(te_x, label=te_y.values)
# LightGBM training parameters for the 3-class iris problem.
lgb_param = dict(
    objective='multiclass',
    # Two evaluation metrics; the training call uses first_metric_only=True,
    # so the first one listed (multi_logloss) drives early stopping.
    metric=['multi_logloss', 'multi_error'],
    num_class=3,
    n_jobs=4,
    num_iterations=300,  # upper bound on boosting rounds; early stopping may end sooner
    learning_rate=0.02,
    max_depth=4,
    lambda_l2=0.8,
    verbose=-1,  # silence LightGBM's own log output
)
eval_result = {}  # filled in place by record_evaluation during training
lgb_model = lgb.train(
    lgb_param,
    train_set=tr_lgb_dt,
    valid_sets=[tr_lgb_dt, te_lgb_dt],
    callbacks=[
        # log_evaluation is a factory and must be called to obtain the callback;
        # passing the bare function logs nothing. period=20 also replaces the
        # verbose_eval=20 argument, which LightGBM 4.0 removed from train().
        lgb.log_evaluation(period=20),
        # Stop when the first metric in lgb_param['metric'] has not improved
        # on the validation set(s) for 50 rounds.
        lgb.early_stopping(50, first_metric_only=True),
        lgb.record_evaluation(eval_result),
    ],
)
# Plot the per-iteration evaluation history captured by record_evaluation,
# stored as {dataset_name: {metric_name: [value per iteration]}}.
# Each metric gets its own panel: multi_logloss and multi_error live on very
# different scales, and drawing them on one axes with the same style per
# dataset made the two curves of a dataset impossible to tell apart.
metric_names = list(dict.fromkeys(
    metric_name for metric_res in eval_result.values() for metric_name in metric_res
))
n_panels = max(len(metric_names), 1)  # plt.subplots needs at least one column
fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 4), squeeze=False)
for ax, metric_name in zip(axes[0], metric_names):
    for data_name, metric_res in eval_result.items():
        if metric_name not in metric_res:
            continue
        # LightGBM names the train_set 'training' by default when it is also
        # passed in valid_sets; other validation sets are 'valid_<i>'.
        is_train = 'train' in data_name
        ax.plot(metric_res[metric_name], label=data_name,
                color='steelblue' if is_train else 'darkred',
                linestyle='-' if is_train else '-.',
                alpha=0.7)
    ax.set_title(metric_name)
    ax.set_xlabel('iteration')
    ax.legend()
fig.tight_layout()
plt.show()