对CART决策树剪枝过程的理解

对CART决策树剪枝过程的理解

前言:CART决策树生成的过程比较好理解,但是剪枝的过程看了好几遍才看明白,故写出下文,供同样困惑的朋友参考。下文不涉及复杂严密的数学推导,以辅助理解为主。

一. 损失函数的定义方法

CART的损失函数用的是下式:
C α ( T ) = C ( T ) + α ∣ T ∣ (1) C_\alpha(T)=C(T)+\alpha |T| \tag{1} Cα(T)=C(T)+αT(1)
损失函数表征的是模型预测错误的程度,所以它越小越好。

上式中 C α ( T ) C_\alpha (T) Cα(T) 是关于 T T T α \alpha α 的函数, T T T 表示一个决策树, C ( T ) C(T) C(T) 是对训练数据的预测误差(分类用基尼指数表示,回归用均方误差表示), ∣ T ∣ |T| T 表示树 T T T 的叶节点个数。$\alpha $ 是一个常数,用来平衡模型对数据的拟合程度(由 C ( T ) C(T) C(T)项决定)和 模型的复杂度( α ∣ T ∣ \alpha|T| αT项决定,复杂度也就是树的分支多不多)。

如果 α \alpha α 非常小,那么损失函数 C α ( T ) C_\alpha(T) Cα(T) 的值大小由 C ( T ) C(T) C(T) 决定,为了使损失函数的值小, C ( T ) C(T) C(T) 也就会趋于小,也就是多分枝,充分延展树(因为我们生成树时,选择属性的标准就是使基尼指数或者均方误差减小的最多,所以充分分枝意味着更小的 C ( T ) C(T) C(T));

反之,如果 α \alpha α 充分大,那么损失函数 C α ( T ) C_\alpha(T) Cα(T) 的值大小由 α ∣ T ∣ \alpha |T| αT 决定,为了使损失函数的值小, ∣ T ∣ |T| T 也就会趋于小,而最小的树就是只有一个节点,所以此时剪枝成一个单节点树, ∣ T ∣ = 1 |T|=1 T=1

总而言之, α \alpha α 越大,在损失函数的影响下,模型趋向于少分枝。 α \alpha α越小,模型越趋向于多分枝。

二. 剪枝的过程

假设通过CART生成一个完整的树 T 0 T_0 T0,如下:

a

剪枝的整体思路是:

  1. 每次树所有的內结点(不是叶结点的结点,如上示树的N4,N2,N3,N7,N1),得出最适合剪枝的结点并对其剪枝,得到一个子树 T i T_i Ti ,然后再分析 T i T_i Ti 的所有內结点,找出 T i T_i Ti 最适合剪枝的结点并对其剪枝,得到子树 T i + 1 T_{i+1} Ti+1

    ⋯ \cdots

  2. 重复至最终得到的子树只剩下三个结点(一个根结点连着两个叶结点),如果这个过程中,我们得到了 k+1 个子树(注意,每次剪完枝得到的子树都要存储起来),不妨记作 { T 0 , T 1 , ⋯   , T k T_0,T_1,\cdots,T_k T0,T1,,Tk};

  3. 最后使用交叉验证,看看哪个树的性能最好,我们就选择哪个树。

核心步骤是第一步,以下给出具体解释和方法:


第一部分我们分析过: α \alpha α 越大,越趋向于多分枝; α \alpha α 越小,越趋向于少分枝。所以,必定存在一个 α \alpha α,使得分不分枝都可以(分枝与不分枝的损失函数值相同),我们记这个 α \alpha α α 0 \alpha_0 α0。所以,我们只需要依次将树的內结点和它的子节点组成的子树拿出来(比如上示树中标示出来的以 N 3 N3 N3 为根节点和以 N 4 N4 N4 为根节点的子树),计算它的 α 0 \alpha_0 α0 。对于全部的內结点,我们得到一组 α 0 \alpha_0 α0 值,然后选择其中最小的 α 0 \alpha_0 α0 对应內结点,并对其剪枝。

这句话需要稍微转个弯才能理解,为什么要选择 α 0 \alpha_0 α0 最小的结点剪枝呢?假设我们选择了一个大于 m i n ( α 0 ) min(\alpha_0) min(α0) 的值 α ′ \alpha' α 作为阈值,那么对于 剪枝阈值α0 小于 α′ 的结点,他们都处于 “趋向于不分枝“ 的状态,也就是需要剪枝,这样就会有多个结点需要剪枝,但是我们不能确保这些需要被剪枝的结点都是不相关的(剪掉一个后对另一个结点没有影响),所以我们需要控制每次只剪一个结点的枝,选择最小的 α 0 \alpha_0 α0对应的结点剪枝,就是为了控制每次只剪掉一个结点的枝,因为在损失函数是 C α ( T ) = C ( T ) + α 0 ∣ T ∣ C_\alpha(T)=C(T)+\alpha_0 |T| Cα(T)=C(T)+α0T的情况下,其他结点都处于 ”趋向于多分枝的状态“ 。

Breiman对此有严密的数学证明,感兴趣可以看看。

接下来就是确认每个內结点的 α 0 \alpha_0 α0注意,确认每个內结点的 α 0 \alpha_0 α0需要将该结点作为根节点的子树单独拿出来研究,以 N 4 N4 N4 结点为例,首先我们把它作为根节点的子树拿出来:

对CART决策树剪枝过程的理解_第1张图片

不剪枝,它的损失函数是:
C α ( T N 4 ) = C ( T N 4 ) + α ∣ T N 4 ∣ T N 4 表示以 N 4 为根节点的子树, ∣ T N 4 ∣ 表示 T N 4 的叶结点数,这里等于 2 ,但是为了得到通式,这里写为 ∣ T N 4 ∣ (2) C_\alpha(T_{N4})=C(T_{N4})+\alpha |T_{N4}|\\ T_{N4}表示以N4为根节点的子树,|T_{N4}|表示T_{N4}的叶结点数,这里等于2,但是为了得到通式,这里写为|T_{N4}| \tag{2} Cα(TN4)=C(TN4)+αTN4TN4表示以N4为根节点的子树,TN4表示TN4的叶结点数,这里等于2,但是为了得到通式,这里写为TN4(2)
剪枝后,它只剩下 N4 一个结点,光杆司令,这时候损失函数是:
C α ( N 4 ) = C ( N 4 ) + α , N 4 表示只有 N 4 这个节点的树 (3) C_\alpha(N4)=C(N4)+\alpha ,N4表示只有N4这个节点的树 \tag{3} Cα(N4)=C(N4)+α,N4表示只有N4这个节点的树(3)
找“剪不剪枝都可以的 α \alpha α” ,也就是找 C α ( T N 4 ) = C α ( N 4 ) C_\alpha(T_{N4})=C_\alpha(N4) Cα(TN4)=Cα(N4) α \alpha α 。故有
C ( T N 4 ) + α ∣ T N 4 ∣ = C ( N 4 ) + α 得到: α = C ( N 4 ) − C ( T N 4 ) ∣ T N 4 ∣ − 1 (4) C(T_{N4})+\alpha |T_{N4}|=C(N4)+\alpha \\ 得到:\alpha=\frac{C(N4)-C(T_{N4})}{|T_{N4}|-1} \tag{4} C(TN4)+αTN4=C(N4)+α得到:α=TN41C(N4)C(TN4)(4)
可得,对于任意结点 t t t,记以 t t t 为根节点的子树为 T t T_t Tt ,只有 t t t 一个结点的树直接记为 t t t ,则得到计算结点 t t t “剪不剪枝都可以的 α \alpha α” 的公式:
α = C ( t ) − C ( T t ) ∣ T t ∣ − 1 (5) \alpha=\frac{C(t)-C(T_t)}{|T_t|-1} \tag{5} α=Tt1C(t)C(Tt)(5)

问题得解:我们对每个內结点都用式 (5) 找出它”剪不剪枝都可以“ 的临界 α 0 \alpha_0 α0,然后筛选出最小的 α 0 \alpha_0 α0 对应的內结点剪枝。

三. CART 剪枝算法

输入:CART算法生成的决策树 T 0 T^0 T0

输出:最优决策树 T α T_\alpha Tα

  1. k = 0 k=0 k=0

  2. α t = + ∞ \alpha_t = +\infin αt=+

  3. 对树, T k T^k Tk各个内部节点 t t t 计算 C ( T t ) C(T_t) C(Tt) T t T_t Tt 以及
    α ( t ) = C ( t ) − C ( T t ) ∣ T t ∣ − 1 α t = m i n ( α , α ( t ) ) \alpha(t) = \frac{C(t)-C(T_t)}{|T_t|-1}\\ \alpha_t = min(\alpha,\alpha(t)) α(t)=Tt1C(t)C(Tt)αt=min(α,α(t))
    T t T_t Tt 是以t结点为根节点的子树, t t t代表结点t,也表示只有 t t t 一个 结点的树, C ( T t ) C(T_t) C(Tt) 是训练数据的预测误差(可以用基尼指数或者均方误差表征), ∣ T t ∣ |T_t| Tt t t t为根节点的子树的叶结点数。

  4. α ( t ) = α \alpha(t)=\alpha α(t)=α的内部结点 t t t 进行剪枝,对于剪枝后的结点 t t t 采用多数表决法确认其类别,得到树 T k + 1 T^{k+1} Tk+1

  5. k = k + 1 k=k+1 k=k+1

  6. 重复 3-5 ,直到 T k T^k Tk是一个三结点树(一个根节点两个叶结点)

  7. 对于得到的子树序列 T 0 , T 1 , ⋯   , T n {T_0,T_1,\cdots,T_n} T0,T1,,Tn,采用交叉验证法选出最优子树 T α T_\alpha Tα

Source:对CART决策树剪枝过程的理解

你可能感兴趣的:(ML,机器学习,决策树)