Machine Learning: LightGBM callback example

LightGBM's train function exposes a callbacks interface, and this interface is what we need if we want to record how the loss falls over the training iterations. Starting from that goal of logging the per-iteration loss, this post walks through a simple way to use callbacks in LightGBM.

1. The callbacks interface

Input requirements for the callbacks parameter

(function) train(params, …, callbacks: List[(…) -> Any] | None = None) -> Booster
callbacks : list of callable, or None, optional (default=None)
The parameter takes a list whose elements are callback functions. The official documentation provides four main built-in callbacks:

  • early_stopping
    • Stops training when the validation metric has not improved for the given number of rounds
    • lightgbm.early_stopping(stopping_rounds, first_metric_only=False, verbose=True, min_delta=0.0)
  • log_evaluation
    • Logs the evaluation metrics of each iteration, so they show up in the training log
    • lightgbm.log_evaluation(period=1, show_stdv=True)
  • record_evaluation(eval_result)
    • Records the metrics of each iteration into the dictionary passed in
    • lightgbm.record_evaluation(eval_result); eval_result can be an empty dict, e.g. eval_result = {}
  • reset_parameter(**kwargs)
    • Resets parameters after each boosting round (see the sketch after this list)
    • List of parameters for each boosting round, or a callable that calculates the parameter in terms of the current number of round
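
A minimal sketch of reset_parameter, assuming lgb_param, tr_lgb_dt, and te_lgb_dt are defined as in the full example below: a callable of the current (0-based) round index decays the learning rate every boosting round.

# hypothetical decay schedule: the callable receives the current round index
lr_decay = lgb.reset_parameter(learning_rate=lambda cur_round: 0.05 * (0.99 ** cur_round))
lgb_model = lgb.train(lgb_param, train_set=tr_lgb_dt, valid_sets=[te_lgb_dt],
                      callbacks=[lr_decay])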

Example

eval_result = {}
lgb_model = lgb.train(lgb_param, train_set=tr_lgb_dt, valid_sets=[tr_lgb_dt, te_lgb_dt],
          callbacks=[lgb.log_evaluation(period=20),
                     lgb.early_stopping(50, first_metric_only=True),
                     lgb.record_evaluation(eval_result)]
          )
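
After training, record_evaluation has filled eval_result as a nested dict keyed first by dataset name and then by metric name. With the valid_sets above and no valid_names, LightGBM names the two sets 'training' and 'valid_1', so (with the multiclass metrics used in the full example below) the structure looks roughly like:

# eval_result == {
#     'training': {'multi_logloss': [...], 'multi_error': [...]},
#     'valid_1':  {'multi_logloss': [...], 'multi_error': [...]},
# }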

2. Complete iris example


from sklearn.datasets import load_iris
import lightgbm as lgb
import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
plt.style.use('ggplot')
import warnings
warnings.filterwarnings('ignore')


iris = load_iris()
# strip the ' (cm)' suffix from the feature names and replace spaces with underscores
df = pd.DataFrame(data=iris.data, columns=[i[:-5].replace(' ', '_') for i in iris.feature_names])
df['target'] = iris.target

tr_x, te_x, tr_y, te_y = train_test_split(df.drop(columns='target'), df['target'], test_size=0.2)
tr_lgb_dt = lgb.Dataset(tr_x, label=tr_y.values)
# build the validation Dataset with reference to the training set so feature bins are aligned
te_lgb_dt = lgb.Dataset(te_x, label=te_y.values, reference=tr_lgb_dt)


lgb_param = {
    'objective': 'multiclass',
    'metric': ['multi_logloss', 'multi_error'],
    'num_class': 3,
    'n_jobs': 4,
    'num_iterations': 300,
    'learning_rate': 0.02,
    'max_depth': 4,
    'lambda_l2': 0.8,
    'verbose': -1
}
eval_result = {}
lgb_model = lgb.train(lgb_param, train_set=tr_lgb_dt, valid_sets=[tr_lgb_dt, te_lgb_dt],
          callbacks=[lgb.log_evaluation(period=20),
                     lgb.early_stopping(50, first_metric_only=True),
                     lgb.record_evaluation(eval_result)]
          )
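
If early stopping actually fires, the returned Booster records the best round in best_iteration; a small sketch of reading it and predicting with only that many trees (te_x is the raw validation feature frame from above):

print('best iteration:', lgb_model.best_iteration)
# predict class probabilities using only the trees up to the best round
pred_proba = lgb_model.predict(te_x, num_iteration=lgb_model.best_iteration)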


# plot every metric curve recorded in eval_result:
# training curves in blue (solid), validation curves in red (dash-dot)
plt.title('training / validation metrics')
for data_name, metric_res in eval_result.items():
    for metric_name, log_ in metric_res.items():
        plt.plot(log_, label=f'{data_name}-{metric_name}',
                 color='steelblue' if 'train' in data_name else 'darkred',
                 linestyle='-' if 'train' in data_name else '-.',
                 alpha=0.7)

plt.legend()
plt.show()
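
Because the callbacks argument simply takes callables, a hand-written callback can also sit alongside the built-in ones. The sketch below leans on LightGBM's callback protocol, where each callback is invoked once per round with a CallbackEnv namedtuple; only its iteration and evaluation_result_list fields are used, and the name print_every_50 is made up for illustration.

def print_every_50(env):
    # each entry of evaluation_result_list is (dataset_name, metric_name, value, is_higher_better)
    if (env.iteration + 1) % 50 == 0:
        msg = ', '.join(f'{d}-{m}: {v:.4f}' for d, m, v, _ in env.evaluation_result_list)
        print(f'round {env.iteration + 1}: {msg}')

# e.g. callbacks=[print_every_50, lgb.record_evaluation(eval_result)]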

(Figure 1: training and validation metric curves plotted from eval_result)
