决策树 基于python实现ID3,C4.5,CART算法

实验目录

  • 实验环境
  • 简介
    • 决策树(decision tree)
    • 信息熵
    • 信息增益(应用于ID3算法)
    • 信息增益率(在C4.5算法中使用)
    • 基尼指数(被用于CART算法)
  • 实验准备
    • 数据集
    • 算法大体流程
  • 实验代码
    • 训练集数据读入
    • 信息熵代码
    • 算法流程结构(ID3和C4.5部分)
    • CART算法
    • 可视化
    • 剪枝

实验环境

Python:3.7.0
Anconda:3-5.3.1 64位
操作系统:win10
开发工具:sublime text(非必要)

简介

决策树(decision tree)

:是一种基本的分类与回归方法,此处主要讨论分类的决策树。

在分类问题中,表示基于特征对实例进行分类的过程,可以认为是if-then的集合,也可以认为是定义在特征空间与类空间上的条件概率分布。

决策树通常有三个步骤:特征选择决策树的生成决策树的修剪

用决策树分类:从根节点开始,对实例的某一特征进行测试,根据测试结果将实例分配到其子节点,此时每个子节点对应着该特征的一个取值,如此递归的对实例进行测试并分配,直到到达叶节点,最后将实例分到叶节点的类中。

下图即是一个经典的决策树模型示意图——关于是否出去玩

决策树 基于python实现ID3,C4.5,CART算法_第1张图片
事实上决策树原理和近些年网络流行的网络天才一类的问答猜测结果游戏相似,根据一系列数据,然后给出游戏的答案。
决策树 基于python实现ID3,C4.5,CART算法_第2张图片

与上次实验所涉及的k-近邻算法相比,knn可以完成很多分类任务,但是其最大的缺点是无法给出数据的内在含义。

信息熵

在决策树生成节点的过程中,势必要经过特征的选择这一步;而这里会涉及到一个新的概念用以区分将哪些更有价值的信息放置在树的上层(以求使得树的结构相对更加简便)。

信息熵本身的意义为对信息的一种度量。物品可以用重量度量,长度可以用尺子度量。那信息用什么度量呢?《机器学习实战》这本书的信息量是多少呢?用什么度量呢?直到1948年香农提出了“信息熵”的概念,才解决了对信息的量化度量问题。信息熵是消除不确定性所需信息量的度量。一件事情的信息熵越高说明它需要的信息越多,来消除它的不确定性,附公式如下:
在这里插入图片描述

信息增益(应用于ID3算法)

介绍过信息熵后,得以进一步追加一个新的概念用以构造决策树,即信息增益。顾名思义添加了信息之后能增加多少收益。也就是说增加信息之后能减少多少不确定性。

信息增益的计算方式为信息增益=熵-条件熵
在这里插入图片描述

g(X,A)=H(X)-H(X|A)。由于特征A而使得对数据D的分类的不确定性减少的程度。显然,对于数据集而言,信息增益依赖于特征,不同的特征往往具有不同的信息增益,信息增益大的特征具有更强的分类能力。

信息增益率(在C4.5算法中使用)

上面的信息增益存在着一个重大的问题,即其本身对于那些可以取到更多可能值的属性有着更高的偏好,假设存在一个类似于号码的属性,那么由于每个实例的号码属性都不相同,那么信息增益的方式就会倾向于直接以号码为判断条件;但这样的过程显然是不合逻辑的,因为号码会随着数据的更换失去意义,也就是这时结构发生了严重的过拟合现象。

那么这时就引入信息增益的改进算法,即信息增益率;其定义式如下:
在这里插入图片描述
其中IV(a)的计算方式如下:
在这里插入图片描述
其被称为属性a的“固有值”,属性a的可能取值数目越多,则IV(a)的值通常会越大。但此时也任然需要注意的是,在信息增益率的算法影响下,可选数目小的属性会反过来受到青睐,因此在C4.5算法中使用了一个启发式:先从候选划分属性中找出信息增益高于平均水平的属性,再从中选择信息增益率最高的一个。

基尼指数(被用于CART算法)

与上面不同的是,CART决策树使用一种特殊的基尼指数来选择划分属性,基尼值定义如下:
在这里插入图片描述
简单来说,基尼值反应了从数据集p中随机抽取两个样本,其类别标记不一致的概率。因此,基尼值越小,则数据集的纯度越高。
在这里插入图片描述
而上图所展示的为基尼指数的定义,其中a为某一属性。

实验准备

本次实验代码部分有参考人民邮电出版社的《机器学习实战》的代码;且在此基础上重新二次修改和重写而成。

数据集

依照实验要求,本次实验所用的数据集为笔者手动编写;题材是关于今天晚上要不要玩游戏,包含待游玩的游戏类型当天的学业任务剩余多少是否吃过晚饭今天天气如何今天是周几这五个属性。数据集的截图示意如下
决策树 基于python实现ID3,C4.5,CART算法_第3张图片
总计含有数据记录109条。

算法大体流程

1.将所有的数据放于一个创建的根节点。
2.根据所用算法,分析生成一个现下最优的划分属性,并且根据这个属性进行对所有数据的划分。
3.上一步所得结果为空或者大多数为同一类别时,生成对应的叶节点;若不满足,则递归执行上一步。

实验代码

训练集数据读入

def createDataSet1():    # 读入训练集数据
    dataSet=[]
    with open('dataset.txt','r') as f:
        line = f.readline().strip()
        lineStr = line.split(" ")
        while line:
            lineStr = line.split(" ")
            dataSet.append(lineStr)
            line = f.readline().strip()
    labels_t=[]
    labels_t.append(dataSet[0])
    labels=sum(labels_t,[])
    dataSet.pop(0)
    return dataSet,labels

上面的代码较为简单,即格式化的读入文本中的数据,将第一行保存为列表格式的标签;剩下的行以每一行为一个列表保存为一个含有所有记录的大列表

信息熵代码

def calcShannonEnt(dataSet):  # 计算数据的熵(entropy)
    numEntries=len(dataSet)  # 数据条数
    labelCounts={}
    for featVec in dataSet:
        currentLabel=featVec[-1] # 每行数据的最后一个字(类别)
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1  # 统计有多少个类以及每个类的数量
    shannonEnt=0
    for key in labelCounts:
        prob=float(labelCounts[key])/numEntries # 计算单个类的熵值
        shannonEnt-=prob*log(prob,2) # 累加每个类的熵值
    return shannonEnt

无它,关于信息熵的计算代码,需要注意的地方已经写再上面的代码行里。这一部分的代码时最重要的代码部分,其中传入的dataset是上面提到的大列表

算法流程结构(ID3和C4.5部分)

def createTree(dataSet,labels):
    classList=[example[-1] for example in dataSet]  # 类别:玩或不玩
    if classList.count(classList[0])==len(classList):
        return classList[0]
    if len(dataSet[0])==1:
        return majorityCnt(classList)
    bestFeat=chooseBestFeatureToSplit(dataSet) #选择最优特征
    bestFeatLabel=labels[bestFeat]
    myTree={bestFeatLabel:{}} #分类结果以字典形式保存
    del(labels[bestFeat])
    featValues=[example[bestFeat] for example in dataSet]
    uniqueVals=set(featValues)
    for value in uniqueVals:
        subLabels=labels[:]
        myTree[bestFeatLabel][value]=createTree(splitDataSet\
                            (dataSet,bestFeat,value),subLabels)
    return myTree

上面的代码就是构建决策树的核心代码,可以看出来这是一个递归形式的函数;其中本函数调用的majorityCnt和chooseBestFeatureToSplit函数将在下面给出。可以看到最终结果会被以字典的形式储存起来。

def majorityCnt(classList):    #按分类后类别数量排序,比如:最后分类为10个玩游戏6个不玩游戏,则判定为玩游戏;
    classCount={}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote]=0
        classCount[vote]+=1
    sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]
def chooseBestFeatureToSplit(dataSet):  # 选择最优的分类特征
    numFeatures = len(dataSet[0])-1
    baseEntropy = calcShannonEnt(dataSet)  # 原始的熵
    bestInfoGain = 0
    bestFeature = -1
    for i in range(numFeatures):    #求所有属性的信息增益
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)  #第i列属性的取值(不同值)数集合
        newEntropy = 0
        for value in uniqueVals:    #求第i列属性每个不同值的熵*他们的概率
            subDataSet = splitDataSet(dataSet,i,value)
            prob =len(subDataSet)/float(len(dataSet))   #求出该值在i列属性中的概率
            newEntropy +=prob*calcShannonEnt(subDataSet)  #求i列属性各值对于的熵求和
        infoGain = baseEntropy - newEntropy  # 原始熵与按特征分类后的熵的差值
        if (infoGain>bestInfoGain):   # 若按某特征划分后,熵值减少的最大,则次特征为最优分类特征
            bestInfoGain=infoGain
            bestFeature = i
    return bestFeature

majorityCnt是一个简单的函数,注释已有写明;而chooseBestFeatureToSplit函数则是实现ID3算法的方法,其核心思想即是像前面所提到的那样,先通过信息熵的函数求出所有属性的信息增益,在从中选取出信息增益值最高的那个属性作为划分属性,并以此流程递归往下生成决策树。而与之相对的,C4.5算法在这一步的步骤与其有所不同

def chooseBestFeatureToSplit2(dataSet):     # 选择最优的分类特征
    numFeatures = len(dataSet[0])-1  
    baseEntropy = calcShannonEnt(dataSet)   # 原始的熵
    bestInfoGain = 0
    bestFeature = -1  
    for i in range(numFeatures):  #求所有属性的信息增益
        featList = [example[i] for example in dataSet]  
        uniqueVals = set(featList)  #第i列属性的取值(不同值)数集合
        newEntropy = 0  
        splitInfo = 0
        for value in uniqueVals:  #求第i列属性每个不同值的熵*他们的概率
            subDataSet = splitDataSet(dataSet, i , value)  
            prob = len(subDataSet)/float(len(dataSet))  #求出该值在i列属性中的概率
            newEntropy += prob * calcShannonEnt(subDataSet)  #求i列属性各值对于的熵求和
            splitInfo -= prob * log(prob, 2)
        infoGain = (baseEntropy - newEntropy) / splitInfo #求出第i列属性的信息增益率    
        if(infoGain > bestInfoGain):  # 若按某特征划分后,熵值减少的最大,则次特征为最优分类特征
            bestInfoGain = infoGain  
            bestFeature = i  
    return bestFeature

几乎两者是类似的,但区别在于C4.5算法在上面需要多求一步上文提到的固有值并用于被信息增益除去。故由上面的代码修改而来。

def splitDataSet(dataSet,axis,value): # 按某个特征分类后的数据
    retDataSet=[]
    for featVec in dataSet:
        if featVec[axis]==value:
            reducedFeatVec =featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

这里补充的是上面两种算法用到的按照给定特征划分数据集的操作函数。
决策树 基于python实现ID3,C4.5,CART算法_第4张图片
决策树 基于python实现ID3,C4.5,CART算法_第5张图片

上面两张图片分别展示了此时使用ID3算法和C4.5算法生成决策树的结果。

CART算法

这里之所以将CART算法单列,就是由于CART算法涉及到基尼值的计算,所以没有办法由上面的代码简单修改得来,进行了一些比较大幅度的修改后,得到的基尼值算法如下:

def calcGini(dataSet):
    numEntries=len(dataSet)
    labelCounts={}
    #给所有可能分类创建字典
    for featVec in dataSet:
        currentLabel=featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1
    Gini=1.0
    #以2为底数计算香农熵
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        Gini-=prob*prob
    return Gini

其它余下的部分会在后面的剪枝一节中提到。

可视化

本实验采用的可视化代码为人民邮电出版社的《机器学习实战》的代码,不做过多的注解,有需要了解的朋友可以自行搜索,相关的替代代码也很多;可视化的代码并不难理解,只要会用就可以。


import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
            numLeafs += getNumLeafs(secondDict[key])
        else:   numLeafs +=1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
    
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
    numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]     #the text label for this node should be this
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes   
            plotTree(secondDict[key],cntrPt,str(key))        #recursion
        else:   #it's a leaf node print the leaf node
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#if you do get a dictonary you know it's a tree, and the first element will be another dict

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()


def retrieveTree(i):
    listOfTrees =[{'homework': {'little': {'day': {'weeksdays': {'weather': {'windy': 'no', 'sunny': 'no', 'rainy': 'no', 'cloudy': {'supper': {'none': 'no', 'eaten': 'yes'}}}}, 'weekend': {'weather': {'windy': 'yes', 'sunny': {'type': {'slg': 'no', 'rpg': 'no', 'rts': 'yes'}}, 'rainy': 'yes', 'cloudy': 'yes'}}, 'friday': {'type': {'slg': 'yes', 'rpg': 'no', 'rts': 'yes', 'fps': 'yes'}}}}, 'lot': {'day': {'weekend': {'type': {'slg': {'weather': {'sunny': 'no', 'cloudy': 'yes'}}, 'fps': 'no', 'rpg': 'no', 'rts': 'no'}}, 'friday': 'yes', 'weeksdays': {'supper': {'none': 'no', 'eaten': {'weather': {'windy': 'no', 'sunny': {'type': {'slg': 'no', 'rpg': 'no', 'rts': 'yes'}}, 'rainy': 'yes', 'cloudy': 'no'}}}}}}, 'none': {'type': {'fps': {'day': {'weeksdays': 'yes', 'weekend': 'yes', 'friday': 'no'}}, 'slg': 'yes', 'rpg': {'weather': {'windy': 'yes', 'rainy': {'supper': {'none': 'no', 'eaten': 'friday'}}, 'cloudy': 'yes'}}, 'rts': 'yes'}}}}
                  ]
    return listOfTrees[i]

需要注意的是,书上的代码为了方便选择直接将决策树导入生成,这一步应该是读取上一步保存的文本文件生成的。
可视化截图如下所示(以ID3算法为例):
决策树 基于python实现ID3,C4.5,CART算法_第6张图片

剪枝

什么是剪枝?顾名思义,即是减去多余的枝条;在决策树学习算法中,剪枝策略常常用于防止算法产生过拟合现象;由于决策树的学习过程中会有不断重复的节点划分过程,这样的特性使得决策树总是难逃过拟合的命运,使模型的泛化性能下降。那么为了保持模型得性能,会进行两种不同的剪枝策略,分别为预剪枝后剪枝

简单来说,预剪枝是指在决策树的生成过程中,抢先一步对每个节点进行划分前的估计,若不能带来泛化性能的提示,则拒绝继续产生划分,转而直接生成叶节点。

但显而易见的是,这么做会丢失一些可能存在的高价值深层节点(被提前减去了),导致其可能存在欠拟合的风险。

后剪枝则是先生成一颗完整的决策树,再自底向上的考察每个节点进行减去是否会带来泛化性能的提升。但相对的,这样会使这种策略的时间成本极大的超过预剪枝。

下面会给出笔者参考了他人代码后得到的CART算法的预剪枝实现,含大量代码,不感兴趣的朋友可以跳过

#计算数据集的基尼指数
def calcGini(dataSet):
    numEntries=len(dataSet)
    labelCounts={}
    #给所有可能分类创建字典
    for featVec in dataSet:
        currentLabel=featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1
    Gini=1.0
    #以2为底数计算香农熵
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        Gini-=prob*prob
    return Gini

def splitDataSet(dataSet,axis,value):
    retDataSet=[]
    for featVec in dataSet:
        if featVec[axis]==value:
            reducedFeatVec=featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
    
def splitContinuousDataSet(dataSet,axis,value,direction):
    retDataSet=[]
    for featVec in dataSet:
        if direction==0:
            if featVec[axis]>value:
                reducedFeatVec=featVec[:axis]
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
        else:
            if featVec[axis]<=value:
                reducedFeatVec=featVec[:axis]
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
    return retDataSet
 
 
#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet,labels):
    numFeatures=len(dataSet[0])-1
    bestGiniIndex=100000.0
    bestFeature=-1
    bestSplitDict={}
    for i in range(numFeatures):
        featList=[example[i] for example in dataSet]
        #对连续型特征进行处理
        if type(featList[0]).__name__=='float' or type(featList[0]).__name__=='int':
            #产生n-1个候选划分点
            sortfeatList=sorted(featList)
            splitList=[]
            for j in range(len(sortfeatList)-1):
                splitList.append((sortfeatList[j]+sortfeatList[j+1])/2.0)
            
            bestSplitGini=10000
            slen=len(splitList)
            #求用第j个候选划分点划分时,得到的信息熵,并记录最佳划分点
            for j in range(slen):
                value=splitList[j]
                newGiniIndex=0.0
                subDataSet0=splitContinuousDataSet(dataSet,i,value,0)
                subDataSet1=splitContinuousDataSet(dataSet,i,value,1)
                prob0=len(subDataSet0)/float(len(dataSet))
                newGiniIndex+=prob0*calcGini(subDataSet0)
                prob1=len(subDataSet1)/float(len(dataSet))
                newGiniIndex+=prob1*calcGini(subDataSet1)
                if newGiniIndex<bestSplitGini:
                    bestSplitGini=newGiniIndex
                    bestSplit=j
            #用字典记录当前特征的最佳划分点
            bestSplitDict[labels[i]]=splitList[bestSplit]
            
            GiniIndex=bestSplitGini
        #对离散型特征进行处理
        else:
            uniqueVals=set(featList)
            newGiniIndex=0.0
            #计算该特征下每种划分的信息熵
            for value in uniqueVals:
                subDataSet=splitDataSet(dataSet,i,value)
                prob=len(subDataSet)/float(len(dataSet))
                newGiniIndex+=prob*calcGini(subDataSet)
            GiniIndex=newGiniIndex
        if GiniIndex<bestGiniIndex:
            bestGiniIndex=GiniIndex
            bestFeature=i
    #若当前节点的最佳划分特征为连续特征,则将其以之前记录的划分点为界进行二值化处理
    #即是否小于等于bestSplitValue
    #并将特征名改为 name<=value的格式
    if type(dataSet[0][bestFeature]).__name__=='float' or type(dataSet[0][bestFeature]).__name__=='int':      
        bestSplitValue=bestSplitDict[labels[bestFeature]]        
        labels[bestFeature]=labels[bestFeature]+'<='+str(bestSplitValue)
        for i in range(shape(dataSet)[0]):
            if dataSet[i][bestFeature]<=bestSplitValue:
                dataSet[i][bestFeature]=1
            else:
                dataSet[i][bestFeature]=0
    return bestFeature

def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote]=0
        classCount[vote]+=1
    return max(classCount)
    
def createTree(dataSet,labels,data_full,labels_full,data_test):
    classList=[example[-1] for example in dataSet]
    if classList.count(classList[0])==len(classList):
        return classList[0]
    if len(dataSet[0])==1:
        return majorityCnt(classList)
    temp_labels=copy.deepcopy(labels)
    bestFeat=chooseBestFeatureToSplit(dataSet,labels)
    bestFeatLabel=labels[bestFeat]
    myTree={bestFeatLabel:{}}
    if type(dataSet[0][bestFeat]).__name__=='str':
        currentlabel=labels_full.index(labels[bestFeat])
        featValuesFull=[example[currentlabel] for example in data_full]
        uniqueValsFull=set(featValuesFull)
    featValues=[example[bestFeat] for example in dataSet]
    uniqueVals=set(featValues)
    del(labels[bestFeat])
    #针对bestFeat的每个取值,划分出一个子树。
    for value in uniqueVals:
        subLabels=labels[:]
        if type(dataSet[0][bestFeat]).__name__=='str':
            uniqueValsFull.remove(value)
        myTree[bestFeatLabel][value]=createTree(splitDataSet\
         (dataSet,bestFeat,value),subLabels,data_full,labels_full,\
         splitDataSet(data_test,bestFeat,value))
    if type(dataSet[0][bestFeat]).__name__=='str':
        for value in uniqueValsFull:
            myTree[bestFeatLabel][value]=majorityCnt(classList)
    if testing(myTree,data_test,temp_labels)<testingMajor(majorityCnt(classList),data_test):
        return myTree
    return majorityCnt(classList)
    def classify(inputTree,featLabels,testVec):
    firstStr=inputTree.keys()[0]
    if '<=' in firstStr:
        featvalue=float(re.compile("(<=.+)").search(firstStr).group()[2:])
        featkey=re.compile("(.+<=)").search(firstStr).group()[:-2]
        secondDict=inputTree[firstStr]
        featIndex=featLabels.index(featkey)
        if testVec[featIndex]<=featvalue:
            judge=1
        else:
            judge=0
        for key in secondDict.keys():
            if judge==int(key):
                if type(secondDict[key]).__name__=='dict':
                    classLabel=classify(secondDict[key],featLabels,testVec)
                else:
                    classLabel=secondDict[key]
    else:
        secondDict=inputTree[firstStr]
        featIndex=featLabels.index(firstStr)
        for key in secondDict.keys():
            if testVec[featIndex]==key:
                if type(secondDict[key]).__name__=='dict':
                    classLabel=classify(secondDict[key],featLabels,testVec)
                else:
                    classLabel=secondDict[key]
    return classLabel
 
 
def testing(myTree,data_test,labels):
    error=0.0
    for i in range(len(data_test)):
        if classify(myTree,labels,data_test[i])!=data_test[i][-1]:
            error+=1
    print 'myTree %d' %error
    return float(error)
    
def testingMajor(major,data_test):
    error=0.0
    for i in range(len(data_test)):
        if major!=data_test[i][-1]:
            error+=1
    print 'major %d' %error
    return float(error)

你可能感兴趣的:(决策树,python,算法)