5-5 决策树的剪枝算法

树的剪枝算法

输入:
ID3或C4.5的决策树
参数a
输出:
剪枝后的决策树 T a T_a Ta

递归版本

  1. 从树的根结点开始
  2. 如果该结点的孩子中存在子树(不全是叶子结点),则先对子树做prune
  3. 所有子树都prune之后,再判断该结点的孩子是否都是叶子
  4. 如果不全是叶子,对该结点的算法结束
  5. 如果该结点的孩子都是叶子,则尝试对该结点剪枝
    5.a 计算 C a ( T B ) C_a(T_B) Ca(TB),代表该结点split后以该结点为根结点的子树的损失,其损失的计算方式与整棵树的计算方式相同
    C a ( T B ) = ∑ ∣ T ∣ N t H t ( T ) + a ∣ T ∣ C_a(T_B) = \sum^{|T|}N_tH_t(T) + a|T| Ca(TB)=TNtHt(T)+aT
    因此在生成树的时候为每个叶子结点保存它的 N t N_t Nt H t H_t Ht
    5.b 计算 C a ( T A ) C_a(T_A) Ca(TA),代表该结点split之前作为一个叶子结点时的熵。因此在生成树时记录该split之前的熵
    5.c 比较 C a ( T B ) C_a(T_B) Ca(TB) C a ( T A ) C_a(T_A) Ca(TA),如果 C a ( T A ) ≤ C a ( T B ) C_a(T_A)\le C_a(T_B) Ca(TA)Ca(TB),则将该结点修改为叶子结点。即:计算该结点的输出标记,并删除它的所有孩子结点。
  6. 对该结点的处理结束,由于这个算法是递归调用的,如果该结点有父结点,则要继续处理它的父结点。
def isTree(Node):
    return 'child' in Node.keys()

def Clip(Node):
    bestNt = 0
    for value in Node['child']:
        if Node['child'][value]['Nt'] > bestNt:
            bestNt = Node['child'][value]['Nt']
            bestLabel = Node['child'][value]['label']
    Node['label'] = bestLabel
    Node.pop('child')
    
def Merge(Node, alpha):
    # 计算CaTb
    CT_b = 0
    for value in Node['child']:
        CT_b = CT_b + Node['child'][value]['Nt'] * Node['child'][value]['entropy'] + alpha
    # 计算CaTa
    CT_a = Node['entropy'] + alpha
    # 剪枝的条件
    if CT_a <= CT_b:
        Clip(Node)
        
def prune(Node, alpha):
    # 判断子结点中是否存在树
    for value in Node['child']:
        # 如果存在树
        if isTree(Node['child'][value]):
            # 先对树子结点做prune
            prune(Node['child'][value], alpha)
    # 对所有子树都prune之后,判断是否所有子树都是叶子
    isAllLeaf = True
    for value in Node['child']:
        if isTree(Node['child'][value]):
            isAllLeaf = False
            break
    # 如果所有子树都是叶子
    if isAllLeaf:
        # 尝试对结点做剪枝
        Merge(Node, alpha)

DP版本

【?】如何通过DP实现

你可能感兴趣的:(李航,-,统计学习方法)