机器学习-通俗易懂推导XGBoost公式

1、前言

    xgboost是在gbdt基础上进行了升级,所以xgboost也是通过每次拟合上次的残差(上次实际值与目标值之差),从而每次生成一棵树(CART回归树),最终将所有的树加起来得到最终目标。与gbdt的区别这里先不做介绍,后续文章会涉及。

注意:XGBoost中基模型是CART树,这里的意思只是说基模型中的树是二叉树,与决策树中CART树的划分属性及划分节点的选择毫无关系,心里先有这个概念,不然后面不好理解。

2、模型推导

2.1 模型预测结果

    先给出模型最终预测结果,假设第k次生成的CART树(也可以称为残差树)为,则经过T轮之后(也就是一共有T棵树),最终模型对于样本i的预测值为(为第i个样本的输入值,T代表树的数量)

2.2 模型训练过程

    从上述公式可以看出,最终结果是对所有CART树求和。xgboost是增量学习方法,即,每一棵树都必须在上一棵树生成之后才可以继续求得,所以这里在代码设计上也只能是串行执行。
  针对每一棵树是如何生成的呢?接着分析。
  首先,每次的训练的目标是使预测值最接近真实值(即损失函数最小化)。损失函数常用下式表示:

公式说明:其中表示对n个样本的损失值求和;表示与预测值及真实值有关的损失函数(该函数中真实值是已知的,但是预测值随着树模型的不同而在变换,所以变量是)。

  但是xgboost不是直接最小化上述损失函数作为训练的目标,而是要在上述公式基础上加上树的复杂度。为什么要加树的复杂度?就是为了避免过拟合(在训练集上表现很好,但是在测试集或待预测数据上表现很差),这里的复杂度就是常说的正则项
所以,最后的目标函数就变为:

注:其中树复杂度是k棵树复杂度的求和。在模型求第k棵树时,其实前k-1棵树的复杂度已经知道了。
所以,得到第t棵树之后,总的损失值变为:

公式说明:
其中,第t棵树时,总模型的预测值是前t-1棵树的预测值加上第t棵树的预测值(代表第t棵树):

第t棵树时,模型总复杂度的值是前t-1棵树的复杂度加上第t棵树的复杂度

注意:由于训练第t棵树时,前面的t-1棵树已经全部训练好了,其参数已经确定,所以前面树的复杂度也已经确定,这里用const表示。
目标函数变为:

目标函数中的损失函数部分(第一部分):

  二阶泰勒公式:

  我们把当作,当作,则损失函数的第一部分变为:

\approx \sum_{i=1}^{n}\left[l\left(y_{i}, \hat{y}_{i}^{t-1}\right)+\frac{\partial l\left(y_{i}, \hat{y}_{i}^{t-1}\right)}{\partial \hat{y}_{i}^{t-1}} f_{t}\left(x_{i}\right)+\frac{1}{2} \frac{\partial^{2} l\left(y_{i}, \hat{y}_{i}^{t-1}\right)}{\partial\left(\hat{y}_{i}^{t-1}\right)^{2}}\left(f_{t}\left(x_{i}\right)\right)^{2}\right]

公式说明:
(1)针对上式,一定会有人对 产生疑惑,为什么将当作后,函数中怎么还有?
这里注意,损失函数中的是样本的真实值,所以该值是常量,函数只是与有关,但是在求第t棵树时,前面t-1棵树已经确定了,所以就是前t-1棵树与真实值的一个误差值,是一个常量;
(2)公式中,到底什么含义?
其中,也就是只是损失函数对预测值(上面已经提过,预测值是变量)求导后,将第t棵树之前的预测值(前面所有树预测值的和)带进去之后的结果;如果还不理解,借助网上一个例子:

求例子:
在一个分类任务中,假设前5棵树已经知道,现在要确定第6个树,损失函数为:

现有一个样本,真实标签为1,但是前5棵树预测的值为-1,由于,所以对于该样本,损失函数变为:

则,
就应该是上式对求导后,将前5棵树总预测值-1带入求导后公式,得到的结果值:
求导得:

令上式中=-1,带入后求得-0.27,也就是。
同理,是二次求导的结果。

通过以上分析可知:
(1)对于N个样本,则会有N个与需要求;
(2)每一个样本可以求得一个与一个,所以这里可以以并行的方式求取各个样本对应的与(这也是XGBoost快的原因之一);
(3)对于每棵树的求取过程中,涉及到对损失函数的二次求导,所以XGBoost可以自定义损失函数,只需损失函数二次可微即可。
此时,目标函数变为:

公式说明:因为就是前t-1棵树与真实值的一个误差值,是一个常量,故最小化目标函数可以将其移除。

目标函数中树复杂度,正则项(第二部分):

  上式中,到底是什么?通过前文分析,就是代表一棵树模型,设想一下,树模型就是给定一个输入样本,经过模型之后会将其划分到某个叶子节点上,并会给出该叶子节点预测值。所以可以用下式表示:

公式说明:代表一个样本输入之后会划分到哪个叶子节点上(可以理解为树的结构),就是一个样本经过树模型之后具体预测值。
正则项:XGBoost定义正则项为:

公式说明:其中代表一棵树中第个叶子节点的预测值;T代表一棵树共有T个叶子节点;与是自定义的值,在使用模型时可以设置,如果大,则树的叶子节点数越多,则惩罚越大,则会惩罚叶子节点总的预测值(理想情况下,模型是希望一步一步慢慢去逼近真实值,而不是一步逼近太多,导致后面可逼近范围太小)。
特别说明:上述公式是针对所有样本计算的结果。

目标函数再次升级(关键时刻):

  通过对两部分分析,得到:
\begin{aligned} O b j^{(t)} & \approx \sum_{i=1}^{n}\left[g_{i} w_{q\left(x_{i}\right)}+\frac{1}{2} h_{i} w_{q\left(x_{i}\right)}^{2}\right]+\gamma T+\frac{1}{2} \lambda \sum_{j=1}^{T} w_{j}^{2} \\ &=\sum_{j=1}^{T}\left[\left(\sum_{i \in I_{j}} g_{i}\right) w_{j}+\frac{1}{2}\left(\sum_{i \in I_{j}} h_{i}+\lambda\right) w_{j}^{2}\right]+\gamma T \end{aligned}
公式说明:一定有人会慌了,不明白为什么,不要慌,仔细看下面分析。(1)公式中n表示共有n个样本,T表示待求树模型共有T个叶节点;(2)表示样本经过待求树模型之后的预测值,代表一棵树中第个叶子节点的预测值;(3)表示每一个样本预测值乘以对应的之后的求和(这里在求和过程中会涉及到每一个叶子节点);
所以公式转换可以理解为:{先对每一个样本求预测值乘以对应的(即),然后对所有的样本求和(即)}转变为{先求所有样本在第个叶子节点上的预测值,然后乘以所有样本在j节点上的和,最后对所有节点求和}
本人能力有些,对于这部分暂时还没有想到更好的方法来解释,后续想到会持续更新,如果读者有什么好的解释可以告知,感谢!

对上述目标函数再一次升级,得:

公式说明:其中与就是在每个叶子节点上与在各个样本上的求和。

坚持一下,马上结束了。
  我们的最终目的就是使上述的目标函数最小化,则对上式求导(上式对变量求偏导),并使其得0,得:

所以:

带入原式,得:

公式说明:最终的只是表示一棵树的预测结果距离真实值的距离,与其他值无关(比如与每个叶子节点具体取值无关)。
  通过上式,便可以求出树的每个叶子节点的具体值,但是到这里结束了吗?还没有,还需要求每棵树到底应该以哪个属性(特征)的哪个分割点进行划分。
  在决策树中是怎么划分的呢?可以参考我的另一篇文章文章。总的来看,就是选择一个使得划分之后的数据纯度提升最大的属性及分割点。XGBoost是选择使得划分之后损失减小最大(损失值的增益Gain)的属性及分割点。
举例说明:
  例子来源陈天奇大师论文,有兴趣读者可以自行查阅相关资料,这里就不再详细介绍该例子。
  如果对于一棵树,现在要求属性为年龄的最佳分割点,如下图:

年龄的分割点

从图中可以看出,假设以竖线为分割线,则:
分割前的损失值为:

其中,T=1,划分之前只有一个叶子节点。
分割后的损失值为:

(-\frac{1}{2}*\frac{G_{L}^{2}}{H_{L}+\lambda})+(-\frac{1}{2}*\frac{G_{R}^{2}}{H_{R}+\lambda})+\gamma T =-\frac{1}{2}\left[\frac{G_{L}^{2}}{H_{L}+\lambda}+\frac{G_{R}^{2}}{H_{R}+\lambda}\right]+\gamma T

其中T=2,因为划分之后就变成两个叶子节点。
则用分割前损失减去分割后损失值,得:

公式说明:该公式代表根据划分属性及划分节点对样本划分之后其损失值减小了多少,所以每次选择最大的Gain值对应的划分属性及划分节点即可,这里求计算各个划分属性及节点是并行计算的。

3、最后说明

(1)文章开篇已经说过,这里再次强调。XGBoost中基模型是CART树,这里的意思只是说基模型中的树是二叉树,要与决策树中CART树的划分属性及划分节点的选择毫无关系;
(2)模型中既然是对每次预测结果求和,则说明每次预测结果是一个回归值,而不是类别值,不然求和是无意义的。

以上内容如有理解不当,请指出,谢谢!另,文章中有些内容来源于一些书籍或其他博客,这里就不一一列举,如有侵权,请与我联系删除。

你可能感兴趣的:(机器学习-通俗易懂推导XGBoost公式)