决策树——预剪枝和后剪枝

目录

简析

为什么要剪枝?

剪枝的基本策略

预剪枝

后剪枝

剪枝的优缺点

预剪枝的优缺点

后剪枝的优缺点

实现

数据集

剪枝前

预剪枝

分析

代码


简析

为什么要剪枝?

“剪枝”是决策树学习算法对付 “过拟合” 的主要手段
可通过“剪枝”来一定程度避免因决策分支过多,以致于把训练集自身的一些特点当做所有数据都具有的一般性质而导致的过拟合

剪枝的基本策略

预剪枝

通过提前停止树的构建而对树剪枝,主要方法有:
1.当决策树 达到预设的高度 时就停止决策树的生长
2.达到某个节点的实例 具有相同的特征向量 ,即使这些实例不属于同一类,也可以停止决策树的生长。
3.定义一个阈值,当达到某个节点的 实例个数小于阈值 时就可以停止决策树的生长。
4.通过 计算每次扩张对系统性能的增益 ,决定是否停止决策树的生长。

后剪枝

先从训练集生成一棵完整的决策树,然后 自底向上 地对非叶结点进行分析计算,若将该结点对应的子树替换为叶结点能带来决策树泛化性能提升,则将该子树替换为叶结点。
主要方法有:

1.悲观剪枝(PEP)

2.最小误差剪枝(MEP)

3.错误率降低剪枝(REP)

4.代价复杂度剪枝(CCP)

5.OPP (Optimal Pruning)

6.CVP (Critical Value Pruning)

剪枝的优缺点

预剪枝的优缺点

•优点
降低过拟合风险
显著减少训练时间和测试时间开销。
•缺点
欠拟合风险 :有些分支的当前划分虽然不能提升泛化性能,但在其基础上进行的后续划分却有可能显著提高性能。预剪枝基于“ 贪心 ”本质禁止这些分支展开,带来了欠拟合风险。

后剪枝的优缺点

•优点
后剪枝比预剪枝保留了更多的分支, 欠拟合风险小 泛化性能往往优于预剪枝决策树
缺点
训练时间开销大 :后剪枝过程是在生成完全决策树之后进行的,需要自底向上对所有非叶结点逐一计算

实现

创建决策树部分在上篇博客中已经实现,这边不再复述。https://blog.csdn.net/qq_51994140/article/details/127850274?spm=1001.2014.3001.5501

数据集

根据学生获奖情况(0表示没有,1表示省级,2表示国家级),刷笔试题和面试题情况,和实习经历,来判断是否能找到工作(N表示不能,Y表示能)。

之前的数据集存在问题,已经修改,如下。

训练集

获奖情况 刷题情况 实习经历 工作

2

1 1 Y
2 1 0 N
2 0 1 Y
2 0 0 N
1 1 Y
1 1 0 N
1 0 1 Y
1 0 0 N
0 1 1 Y
0 1 0 N
0 0 1 N
0 0 0 N

验证集

获奖情况 刷题情况 实习经历 工作
2 1 1 Y
2 1 0 N
1 0 0 N
0 0 0 N

剪枝前

决策树——预剪枝和后剪枝_第1张图片

预剪枝

分析

基于信息增益原则,选取属性实习经历划分训练集,分别计算划分前和划分后的验证集精度,判断是否需要划分

1.结点1不划分,将其标记为叶结点,类别标记为N,验证集中仅有一条数据分类正确,验证集精度为(3/4)*100%=75%.

 结点1划分,则情况如下图

决策树——预剪枝和后剪枝_第2张图片

此时验证集中全部样例划分正确,验证集精度为100%,因此选择划分

2.结点2为叶子结点,禁止划分

3.结点3:选取‘获奖情况’进行划分,情况如下图:

决策树——预剪枝和后剪枝_第3张图片

测试集中仅有两条数据符号,验证集精度为(2/4)*100%=50%,精度没有提高,因此选择不继续划分。

4.结果:

 决策树——预剪枝和后剪枝_第4张图片 决策树——预剪枝和后剪枝_第5张图片

代码

def createTreePrePruning(dataTrain, labelTrain, dataTest, labelTest, names, method='id3'):
 
    trainData = np.asarray(dataTrain)
    labelTrain = np.asarray(labelTrain)
    labelTest = np.asarray(labelTest)
    names = np.asarray(names)

    # 如果结果为单一结果
    if len(set(labelTrain)) == 1:
        return labelTrain[0]
        # 如果没有待分类特征
    elif trainData.size == 0:
        return voteLabel(labelTrain)
    # 其他情况则选取特征
    bestFeat, bestEnt = bestFeature(dataTrain, labelTrain, method=method)
    # 取特征名称
    bestFeatName = names[bestFeat]
    # 从特征名称列表删除已取得特征名称
    names = np.delete(names, [bestFeat])
    # 根据最优特征进行分割
    dataTrainSet, labelTrainSet = splitFeatureData(dataTrain, labelTrain, bestFeat)
 
    # 预剪枝评估
    # 划分前的分类标签

    labelTrainLabelPre = voteLabel(labelTrain)
    labelTrainRatioPre = equalNums(labelTrain, labelTrainLabelPre) / labelTrain.size
    # 划分后的精度计算
    if dataTest is not None:
        dataTestSet, labelTestSet = splitFeatureData(dataTest, labelTest, bestFeat)
        # 划分前的测试标签正确比例
        labelTestRatioPre = equalNums(labelTest, labelTrainLabelPre) / labelTest.size
        # 划分后 每个特征值的分类标签正确的数量
        labelTrainEqNumPost = 0
        for val in labelTrainSet.keys():
            labelTrainEqNumPost += equalNums(labelTestSet.get(val), voteLabel(labelTrainSet.get(val))) + 0.0
        # 划分后 正确的比例
        labelTestRatioPost = labelTrainEqNumPost / labelTest.size
 
        # 如果没有评估数据 但划分前的精度等于最小值0.5 则继续划分
    if dataTest is None and labelTrainRatioPre == 0.5:
        decisionTree = {bestFeatName: {}}
        for featValue in dataTrainSet.keys():
            decisionTree[bestFeatName][featValue] = createTreePrePruning(dataTrainSet.get(featValue),
                                                                         labelTrainSet.get(featValue)
                                                                         , None, None, names, method)
    elif dataTest is None:
        return labelTrainLabelPre
        # 如果划分后的精度相比划分前的精度下降, 则直接作为叶子节点返回
    elif labelTestRatioPost < labelTestRatioPre:
        return labelTrainLabelPre
    else:
        # 根据选取的特征名称创建树节点
        decisionTree = {bestFeatName: {}}
        # 对最优特征的每个特征值所分的数据子集进行计算
        for featValue in dataTrainSet.keys():
            decisionTree[bestFeatName][featValue] = createTreePrePruning(dataTrainSet.get(featValue),
                                                                         labelTrainSet.get(featValue),
                                                                         dataTestSet.get(featValue),
                                                                         labelTestSet.get(featValue),names,method)
    return decisionTree

你可能感兴趣的:(决策树,剪枝)