运用XGBoost识别鸢尾花所属类别

运用XGBoost识别鸢尾花所属类别_第1张图片
图片.png

一、数据集介绍

在scikit-learn中内置有一些小型标准数据集,运用这些数据集我们就不需要从某个外部网站或者本地目录下加载任何文件了。下面我们要通过sklearn自带的鸢尾花数据集训练一个基于xgboost的鸢尾花分类模型。

二、运用xgboost对鸢尾花数据集进行分类

  • 1、加载相关包


    运用XGBoost识别鸢尾花所属类别_第2张图片
    图片.png

    这里我们主要用到xgboost自带的接口,当然你也可以通过如下命令使用sklearn接口:from xgboost.sklearn import XGBClassifier 。
    XGBClassifier是xgboost的sklearn包。通过这个包我们可以使用栅格搜索(Grid Search)和并行处理。

  • 2、拆分数据集
    加载完相关的包我们需要从sklearn中加载鸢尾花数据集,并将数据集按80%为训练集,20%为测试集进行拆分。


    运用XGBoost识别鸢尾花所属类别_第3张图片
    图片.png

    当我们执行程序是会发现在右上角显示变量的窗口中出现了iris变量


    运用XGBoost识别鸢尾花所属类别_第4张图片
    图片.png

    点击变量得value我们就可以看到关于iris的相关信息,该窗口展示了变量的5个元素
    运用XGBoost识别鸢尾花所属类别_第5张图片
    图片.png

    可以看到这5个元素分别是:

    1、DESCR(数据集的相关描述)


    运用XGBoost识别鸢尾花所属类别_第6张图片
    图片.png

    2、data(数据集中特征列的取值)
    运用XGBoost识别鸢尾花所属类别_第7张图片
    图片.png

    3、feature_names(特征的名称)
    运用XGBoost识别鸢尾花所属类别_第8张图片
    图片.png

    4、target(标签列的取值)
    运用XGBoost识别鸢尾花所属类别_第9张图片
    图片.png

    5、target_names(标签列每种类别所指代的的名称)
    运用XGBoost识别鸢尾花所属类别_第10张图片
    图片.png

    我们通过变量窗口显示的变量信息可以对数据集有一个很好的了解。
  • 3、模型参数设置


    运用XGBoost识别鸢尾花所属类别_第11张图片
    图片.png

    运用XGBoost识别鸢尾花所属类别_第12张图片
    图片.png
  • 4、将特征和标签转化为DMtrix的格式
    我们数据集的特征列和标签传入xgb.DMtrix()中,会将数据变成DMtrix的格式,这是一个XGBoost自己定义的数据格式(就像numpy中有ndarray, pandas中有dataframe数据格式一样)。这个格式会将第一列作为label,其余的列作为features。


    图片.png
  • 5、训练模型


    运用XGBoost识别鸢尾花所属类别_第13张图片
    图片.png

XGBoost基本方法和默认参数

在训练过程中主要用到两个方法:xgboost.train()和xgboost.cv()
这里我们用到了xgboost.train()

xgboost.train(params,dtrain,num_boost_round=10,evals(),obj=None,
feval=None,maximize=False,early_stopping_rounds=None,
evals_result=None,verbose_eval=True,learning_rates=None,
xgb_model=None)

参数说明:

params:这是个字典,里面包含着训练中的参数关键字和对应的值
dtrain:用于训练的数据
num_boost_round:这个参数用于指定提升树迭代的个数
evals:这是一个列表,用于对训练过程中进行评估列表中的元素。形式是evals=[(dtrain,”train”),(dval,”val”)],或者是evals=[(dtrain,”train”)],对于第一种情况,它使得我们可以在训练过程中观察验证集的效果。
obj:该参数用于自定义目标函数
feval:该参数用于自定义评估函数
maximize:是否对评估函数进行最大化
early_stopping_rounds:该参数用于指定早期停止次数,假设为100,验证集的误差迭代到一定程度在100次内不能再继续降低就停止迭代。这要求evals中至少有一个元素。如果有多个,就按最后一个去执行。返回的是最后的迭代次数(不是最好的)。如果该参数存在,则模型会生成三个属性,bst.best_score,bst_best_iteration,和bst.best_ntree_limit.
evals_result:字典,存储在watchlist中元素的评估结果。
verbose_eval:该参可以为布尔型或者数值型,也要求evals中至少有一个元素,如果为True,则对evals中元素的评估结果会输出在结果中;如果该参数的值为数字,假设为5,则每隔5个迭代输出一次。
learning_rates:每一次提升的学习率的列表
xgboost_model:在训练之前用于加载xgboost model

运行结果:

运用XGBoost识别鸢尾花所属类别_第14张图片
图片.png

我们可以看到当模型迭代到第18次的时候停止了,这时候我们得到的模型的三个属性输出如下所示:


图片.png
  • 6、预测


    运用XGBoost识别鸢尾花所属类别_第15张图片
    图片.png

    我们用该模型对数据集进行预测,并用准确率作为评估模型性能的指标,并输出预测结果的准确率和混淆矩阵。运行结果如下所示:


    运用XGBoost识别鸢尾花所属类别_第16张图片
    图片.png

    可以看到在测试集上模型的准确率为100%。
  • 7、绘制数据集的特征重要性图表


    图片.png

运行结果:

运用XGBoost识别鸢尾花所属类别_第17张图片
图片.png

参考资料:

  • http://xgboost.readthedocs.io/en/latest/python/python_api

你可能感兴趣的:(运用XGBoost识别鸢尾花所属类别)