摘要:GBDT是一种性能非常好的机器学习模型框架,在产业界应用十分广泛。本文介绍了集成模型的3种常见集成策略,即bagging、stacking和boosting,并对boosting模型的杰出代表GBDT的基本原理进行了介绍。用于实战的完整版GBDT是比较复杂的,为了降低学习曲线,这里首先介绍了GBDT的核心,即如何基于目标函数得到每棵CART的拟合目标,然后列举了常见的用于提升GBDT效果的策略。
1. 引言
大概是在2018年,我当时的领导兰博,接了一个项目,然后让我们组的leader韩老师负责。韩老师简单盘算了几秒钟,然后然我了解一下“GBDT”。我感觉没有听清楚,就和韩老师确认了好几回,最后确认确实是“GBDT”。接下来,我就开始网上冲浪,搜索GBDT相关的资料,知道了它的全称是“梯度提升决策树树”(Gradient Boosting Decision Tree)。不过,很快这个项目就不了了之,我也没有继续学习GBDT。
2019年初有段时间,我想温习一下数据挖掘的东西,就在阿里云的天池里找练手项目,最后锁定了“快来挖掘幸福感!”(天池大数据众智平台-阿里云天池tianchi.aliyun.com
)。这个题目(看起来)比较简单,内容带有正能量,很难不选它。一开始我用的是多元线性回归,手工制造了一大堆特征。虽然累了个半死,但是模型效果很差,排名很靠后。我不服气啊,就着手优化模型——先后尝试了回归树、随机森林回归、GBDT等模型(python的第三方库sklearn里有现成的),发现GBDT效果异常的好,排名最高的时候到了前十。
按照算法工程师的职业习惯,遇到一个效果贼好的算法或者模型,一定要搞清楚原理才会收心。这是我整理这篇文章的主要原因。GBDT是一种集成模型,采用的基础模型是CART回归树,因此,这里首先简单介绍一下集成模型和CART回归树,最后展示GBDT。
在我看过的几篇资料中,最好的可能是这个了https://machinelearningmastery.com/gentle-introduction-gradient-boosting-algorithm-machine-learning/machinelearningmastery.com
2. 众人拾柴火焰高——集成模型
几乎所有的数据科学类任务都可以抽象为基于已知数据x预测未知取值的变量y的取值,即刻画一个函数f,使:
当然,暂时还没有那么厉害的模型,可以做到完全准确的预测,预测值和真实值总是存在一定的差异:
其中residual被称为残差。我们评价一个模型的预测能力,一般考察残差的两个方面:(1)偏差,即与真实值分布的偏差大小;(2)方差,体现模型预测能力的稳定性,或者说鲁棒性。
GBDT采用了多模型集成的策略,针对残差进行拟合,进而降低模型的偏差和方差。
集成模型(Ensemble Model)不是一种具体的模型,而是一种模型框架,它采用的是“三个臭皮匠顶个诸葛亮”的思想。集成模型的一般做法是,将若干个模型(俗称“弱学习器”, weak learner,基础模型,base model)按照一定的策略组合起来,共同完成一个任务——特定的组合策略,可以帮助集成模型降低预测的偏差或者方差。图 2-1 集成模型的思想示意图
常见的集成策略有bagging、stacking、boosting。
2.1. Bagging模型
就像它的名字一样,bagging模型会把(相同的)若干基础模型简单的“装起来”——基础模型独立训练,然后将它们的输出用特定的规则综合(比如求平均值)起来,形成最后的预测值。最常见的bagging模型是随机森林(Random Forest)。在回归类任务中,bagging模型假设各个基础模型的预测值“错落有致”,分布在真实值的周围——把这些预测值平均一下,就可以稳定地得到一个比较准确的预测值;而在分类任务中,bagging模型认为“每一个基础模型都判断错误”发生的概率是比较低的,基础模型中的相当一部分会做出正确的判断,因此可以基于大家的投票结果来选择最终类别。图 2-2 bagging模型的基本结构
2.2. Stacking模型
Stacking模型比bagging模型更进2步:(a)允许使用不同类型的模型作为base model;(b)使用一个机器学习模型把所有base model的输出汇总起来,形成最终的输出。(b)所述的模型被称为“元模型”。在训练的时候,base model们直接基于训练数据独立训练,而元模型会以它们的输出为输入数据、以训练数据的输出为输出数据进行训练。Stacking模型认为,各个基础模型的能力不一,投票的时候不能给以相同的权重,而需要用一个“元模型”对各个基础模型的预测值进行加权。图 2-3 stacking模型的基本结构
2.3. Boosting模型
Boosting模型采用另一种形式,把基础模型组合起来——串联。这类模型的思想是,既然一个基础模型可以做出不完美的预测,那么我们可以用第二的基础模型,把“不完美的部分”补上。我们可以使用很多的基础模型,不断地对“不完美的部分”进行完善,以得到效果足够好的集成模型。Boosting的策略非常多,以GBDT为例,它会用第K个CART拟合前k-1个CART留下的残差,从而不断的缩小整个模型的误差,如图2-4。
图 2-4 Boosting Tree的结构
3. 用于回归任务的GBDT
GBDT的名字里有“梯度”和“提升”两个词语。一般来说,我们首先接触的是“梯度下降”,看到这里的“梯度上升”时会疑惑——不用疑惑,这里的“梯度”和“提升”没有直接关系:“梯度”被用来让损失函数快速下降,进而让模型效果“提升”。
GBDT的基础模型是CART,该模型的的相关内容可以参考https://zhuanlan.zhihu.com/p/128472955zhuanlan.zhihu.com
3.1. 回归任务的描述
回归问题可以这样表述:
给定一个数据集
。其中
是任务的输入数据,为第n个样本的特征,是一个K维向量;
为第n个样本的输出值。回归任务就是,构建一个模型
,以最小的误差,基于特征预测输出值
3.2. 目标函数与构建方式
假设GBDT里有K个CART,其中第k个CART记为
。前k个CART的预测值记为
——GBDT是一种加法模型,它把所有基础模型的预测值累加起来作为最终的预测值。为了便于理解和计算,我们有时候会把前k个CART的预测值表示为一个递归形式:
在训练第k个CART的时候,我们需要最小化这样一个目标函数:
可以使用梯度下降的方法,让目标函数的取值尽快降下来。我们需要
的更新规则。
目标函数对
的梯度是:
沿着这个方向走出一小步,就可以得到一个新的函数,其相应的损失函数值更小。这里涉及了一点泛函数的内容(可以参考网友“小腹黑zju”的博客[原创]理解泛函的概念和能量泛函的梯度下降流_小腹黑zju_新浪博客blog.sina.com.cn
;更深的我也母鸡啦)。
我们可以用类似梯度下降法优化模型参数的形式,表示这一想法:
注意,由式(3-1)和式(3-2)可以得到:
这个式子的意思是,第k个CART的拟合对象为目标函数(对前k-1个CART输出值)的负梯度,可以保证目标函数较快下降。这里为了简单,假定学习率为1。
一般来说,回归任务里,我们采用残差平方和作为目标函数:
因此有:
式(3-4)的意思是,GBDT的每一棵CART的任务,是拟合之前所有CART留下的残差。
3.3. GBDT构建方式的直白表述
基础模型是串联结构,其中第一个CART的输出是
,对应的残差是
。此时,我们的强模型是
第2个CART的任务是构建
的关系,即
。其中
是该基础模型对
进行预测时的残差
。此时,我们的强模型是
以此类推,直到每一个基础模型都训练完毕。
假设最后的提升树中有K个CART,那么使用它进行预测的规则就是:
3.4. GBDT的python实现
我用python实现了一个极简版本的、用于回归任务的GBDT,可见于https://github.com/lipengyuer/DataScience/blob/master/src/algoritm/GBDTRegression.pygithub.com
里面使用的CART则来自https://github.com/lipengyuer/DataScience/blob/master/src/algoritm/CARTRegression.pygithub.com
3.5. 如何让GBDT更好
一个实用的GBDT需要在前面所述的基础上做很多改进:
(1) 在预测阶段,每个CART是独立的,因此可以并行计算。另外,得益于决策树的高效率,GBDT在预测阶段的计算速度是非常快的。
(2) 在训练阶段,GBDT里的CART之间存在依赖,无法并行,所以GBDT的训练速度是比较慢的。人们提出了一些方法,用来提升这个阶段的并行度,以提升学习速度。
(3) 这里为了简单,使用了残差平方和作为损失函数,实际上还可以使用绝对值损失函数、huber损失函数等,从而让GBDT在鲁棒性等方面得到提升。
(4) GBDT的学习能力非常强,容易过拟合。大部分时候,我们都会给目标函数添加针对模型复杂度的惩罚项,从而控制模型复杂度
(5) 等等。
4. 结语
GBDT里,我们需要将目标函数对另一个函数(即前k-1个CART组成的模型)求偏导,进而基于梯度得到一棵CART(即第k课CART)的学习目标——这是理解GBDT的主要难关,其他的都是小问题。
整理GBDT的过程让我意识到,“良好数学基础”对于理解和应用机器学习确实是十分必要的。
注意:本文为李鹏宇(知乎个人主页https://www.zhihu.com/people/py-li-34)原创作品,受到著作权相关法规的保护。如需引用、转载,请注明来源信息:(1)作者名,即“李鹏宇”;(2)原始网页链接,即当前页面地址。如有疑问,可发邮件至我的邮箱:[email protected]。
参考文献
[1] GBDT入门资料https://machinelearningmastery.com/gentle-introduction-gradient-boosting-algorithm-machine-learning/machinelearningmastery.com
[2] CART树的原理https://zhuanlan.zhihu.com/p/128472955zhuanlan.zhihu.com
[3] 泛函与泛函的导数[原创]理解泛函的概念和能量泛函的梯度下降流_小腹黑zju_新浪博客blog.sina.com.cn