机器学习实战第三章 决策树

内容简介:
  • 决策树的构造
  • 在Python中使用Matplotlib注解绘制树形图
  • 测试和存储分类器
  • 示例:使用决策树判断隐形眼镜类型

Part 1 决策树的构造
决策树优缺点:
优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据
缺点:可能会产生过度匹配问题
算法流程:
在构造决策树时,我们首先要解决的问题是,当前数据集上哪个特征在划分数据分类时起决定性作用。为了找到决定性的特征,划分出最好的效果,我们必须评估每个特征。完成测试之后,原始数据集就被划分成几个数据子集。这些数据子集会分布在第一个决策点的所有分支上。如果某分支下的数据属于同一类型,则无需继续划分。如果数据子集内的数据不属于同一类型,则需要重复划分数据子集。划分数据子集与划分原始数据集方法相同,直到所有具有相同类型的数据被划分到一个数据子集内。

创建分支的伪代码createBranch()如下:
检测数据集中的每个子项是否属于同一分类:
If so return 类标签
Else:
寻找划分数据集的最好特征
划分数据集
创建分支节点
for 每个划分的子集
调用函数createBranch()并增加返回的结果到分
支节点中
return 分支节点
1.1信息增益
信息增益:划分数据集的原则是使无序的数据变的有序。划分数据前后信息发生的变化叫做信息增益。计算每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择。
熵:定义为信息的期望值。如果待分类的事务可能划分在多个分类中,则符号 x i 的 信 息 定 义 为 : x_i的信息定义为: xi l ( x i ) = − l o g 2 p ( x i ) l(x_i)=-log_2^{p(x_i)} l(xi)=log2p(xi)
其中 p ( x i ) p(x_i) p(xi)$是选择该分类的概率。为了计算熵,我们需要计算所有类别所有可能值包含的信息期望值,公式如下:
H = − ∑ i = 1 n p ( x i ) l o g 2 p ( x i ) H=-\sum_{i=1}^{n} {p(x_i)log_2^{p(x_i)}} H=i=1np(xi)log2p(xi)

下面用Python计算信息熵

from math import  log
#计算给定数据集的香农熵
def  calcShannonEnt(dataSet):
    numEntries = len(dataSet)#计算数据实例总数
    labelCounts = {}#创建数据字典
    for featVec in dataSet:
        currentLabel = featVec[-1]#键值是最后一列的值
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0 
        labelCounts[currentLabel]+=1
    shannonEnt=0.0
    for key in labelCounts:
        prob=float(labelCounts[key])/numEntries
        shannonEnt-=prob*log(prob,2)
    return shannonEnt

创建一个数据集,前两列分别表示不浮出水面是否可以生存,是否有脚蹼,最后一列表示类别是否是鱼

def createDataSet():
    dataSet=[[1,1,'yes'],
            [1,1,'yes'],
            [1,0,'no'],
            [0,1,'no'],
            [0,1,'no']]
    labels=['no surfacing','flippers']
    return dataSet,labels 
myDat,labels=createDataSet()
calcShannonEnt(myDat) 
#0.9709505944546686

计算得到该数据集的熵为0.9709505944546686
1.2划分数据集
我们将对每个特征划分数据集的结果计算一次熵,然后判断按照哪个特征划分数据集是最好的划分方式。
可以将这个函数理解为当我们按照某个特征划分数据集时,将所有符合要求的元素抽取出来。

#按照给定特征划分数据集
def splitDataSet(dataSet,axis,value):
#axis:划分数据集的特征  value:需返回特征的值
    retDataSet=[]
    for featVec in dataSet:
        if featVec[axis]==value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

这里需要注意append和extend的区别。

a=[1,2,3]
b=[4,5,6]
a.append(b)
a 
#[1, 2, 3, [4, 5, 6]]

a=[1,2,3]
b=[4,5,6]
a.extend(b)
a 
#[1, 2, 3, 4, 5, 6]

在前面创建的样本上测试一下函数splitDataSet()

splitDataSet(myDat,0,1) 
#[[1, 'yes'], [1, 'yes'], [0, 'no']]

splitDataSet(myDat,0,0)  
#[[1, 'no'], [1, 'no']]

接下来我们遍历整个数据集,循环计算香农熵和splitDataSet()函数,找到最好的数据划分方式。

#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
    numFeatures=len(dataSet[0])-1
    baseEntropy = calcShannonEnt(dataSet)
    #计算整个数据集的香农熵
    bestInfoGain=0.0
    bestFeature=-1
    for i in range(numFeatures):#遍历数据集所有特征
        featList=[example[i] for example in dataSet]
        #取出数据集中所有第i个特征值的值
        uniqueVals = set(featList)
        newEntropy=0.0
        for value in uniqueVals:
            subDataset=splitDataSet(dataSet,i,value)
            prob=len(subDataset)/float(len(dataSet))
            newEntropy+=prob*calcShannonEnt(subDataset)
        infoGain=baseEntropy-newEntropy
        if (infoGain>bestInfoGain):
            bestInfoGain=infoGain
            bestFeature=i 
    return bestFeature        

在调用函数chooseBestFeatureToSplit()时,需要满足两个数据要求:1.数据必须是一种由列表元素组成的列表,且列表元素长度相同,2.数据的最后一列或者每个实例的最后一个元素是当前实例的类别标签。

chooseBestFeatureToSplit(myDat)
#0

测试结果告诉我们,第0个特征是划分数据集最好的特征。
1.3递归构建决策树
数据集构造决策树的工作原理如下:得到原始数据集,然后基于最好的属性值划分数据集。第一次划分结束后,数据将被向下传递到树分支的下一个节点,在这个节点上,我们可以再次划分数据。因此我们可以采用递归的原则处理数据集。
递归结束的条件是:程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类。如果所有的实例具有相同的分类,则得到一个叶子节点或者终止块。任何到达叶子节点的数据必然属于叶子节点的分类。
如果数据已经处理了所有属性,但是类标签依然不是唯一的,此时我们需要决定如何定义该叶子节点,在这种情况下,我们采用多数表决的方式。

#多数表决决定叶子节点的分类
import operator
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys():classCount[vote]=0 
        classCount[vote]+=1
    sorted_classCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
    return sorted_classCount[0][0]  
#创建树的函数代码
def createTree(dataSet,labels):
    classList=[example[-1] for example in dataSet]
    #所有数据集类标签
    if classList.count(classList[0])==len(classList):
        return classList[0] #递归终止条件1:类别完全相同则停止继续划分
    if len(dataSet[0])==1:
        return majorityCnt(classList)
        #递归终止条件2:用完所有特征分类仍不唯一
    bestFeat=chooseBestFeatureToSplit(dataSet)
    #选取划分效果最好的特征
    bestFeatLabel=labels[bestFeat]
    myTree={bestFeatLabel:{}}#存储树的信息
    del(labels[bestFeat])
    featVals=[example[bestFeat] for example in dataSet]
    uniqueVals=set(featVals)
    for value in uniqueVals:
        subLabels=labels[:]
        myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
    
    return myTree
myTree = createTree(myDat,labels) 
myTree 
#{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

Part 2 在Python中使用Matplotlib注解绘制树形图
2.1 Matplotlib注解

#使用文本注解绘制树节点
#定义文本框和箭头格式
import matplotlib.pyplot as plt 
decisionNode = dict(boxstyle="sawtooth",fc="0.8")
#定义决策节点文本框格式
leafNode = dict(boxstyle="round4",fc="0.8")
#定义叶节点文本框格式
arrow_args = dict(arrowstyle="<-")
#定义箭头格式

plotNode()函数执行了实际的绘图功能,该函数需要一个绘图区,该区域由全局变量createPlot.ax1定义。

#定义绘图区
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',xytext=centerPt,textcoords='axes fraction',
                           va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)
#nodeTxt:注释文本的内容
#xy:被注释的坐标点位置,二维元组形如(x,y)
#xycoords:被注释点的坐标系属性,'axes fraction':以子绘图区左下角为参考,单位是百分比
#xytext:注释文本的坐标点位置,也是二维元组,默认与xy相同
#textcoords :注释文本的坐标系属性,默认与xycoords属性值相同,也可设为不同的值。
# va="center",  ha="center"表示注释的坐标以注释框的正中心为准,而不是注释框的左下角(v代表垂直方向,h代表水平方向)
# bbox=是注释框的风格和颜色深度,fc越小,注释框的颜色越深,支持输入一个字典
#arrowprops:箭头的样式,dict(字典)型数据,如果该属性非空,则会在注释文本和被注释点之间画一个箭头

这里createPlot()函数与后面有些不同,随着内容深入,我们将逐步添加缺失的代码。

def createPlot():
    fig=plt.figure(1,facecolor='white')
    fig.clf()
    createPlot.ax1=plt.subplot(111,frameon=False)
    #111”表示“1×1网格,第一子图,frameon表示是否绘制坐标轴矩形边框,设成True就会带边框 
    plotNode('decision node',(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode('leaf node',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()

测试一下,结果如下:

createPlot() 

机器学习实战第三章 决策树_第1张图片
2.2构造注解树
要绘制一棵完整的树需要知道有多少个树节点,以便确定x轴的长度,还需要知道树有多少层,以便正确确定y轴的高度。

#获取叶节点的数目和树的层数
def getNumLeafs(myTree):
    numLeafs=0
    firstStr=myTree.keys()[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            numLeafs+=getNumLeafs(secondDict[key])
        else:
            numLeafs+=1
    return numLeafs    
def getTreeDepth(myTree):
    maxDepth=0
    firstStr=myTree.keys()[0] 
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth=1+getTreeDepth(secondDict[key])
        else:
            thisDepth=1
        if thisDepth>maxDepth:maxDepth=thisDepth 
    return maxDepth 
#储存树的信息
def retrieveTree(i):
    listOfTrees=[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                {'no surfacing': {0: 'no', 1: {'flippers': {0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]
    return listOfTrees[i] 
myTree=retrieveTree(0)
getNumLeafs(myTree)
#3
getTreeDepth(myTree)
#2

现在可以将前面的方法组合在一起,绘制一棵完整的树。

#plotMidText函数在父子节点之间填充文本信息
def plotMidText(cntrPt,parentPt,txtString):
    xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0] 
    yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]  
    createPlot.ax1.text(xMid,yMid,txtString)
def plotTree(myTree,parentPt,nodeTxt):
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)
    firstStr=myTree.keys()[0]
    cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict=myTree[firstStr]
    plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
def createPlot(inTree):
    fig=plt.figure(1,facecolor='white')
    fig.clf()
    axprops=dict(xticks=[],yticks=[])
    createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
    plotTree.totalW=float(getNumLeafs(inTree))
    plotTree.totalD=float(getTreeDepth(inTree))
    plotTree.xOff=-0.5/plotTree.totalW;plotTree.yOff=1.0;
    plotTree(inTree,(0.5,1.0),' ')
    plt.show()

机器学习实战第三章 决策树_第2张图片
Part3 测试和存储分类器
3.1测试算法:使用决策树执行分类

#使用决策树的分类函数
def classify(inputTree,featLabels,testVec):
    firstStr=inputTree.keys()[0] 
    secondDict=inputTree[firstStr]
    featIndex=featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex]==key:
            if type(secondDict[key]).__name__=='dict':
                classLabel=classify(secondDict[key],featLabels,testVec)
            else:
                classLabel=secondDict[key]
    return classLabel

测试结果如下:

classify(myTree,labels,[1,0]) 
#'no'
classify(myTree,labels,[1,1])  
#'yes'

3.2使用算法:决策树的存储

#存储决策树
def storeTree(inputTree,filename):
    import pickle
    fw=open(filename,'w')
    pickle.dump(inputTree,fw)
    fw.close()
def grabTree(filename):
    import pickle
    fr=open(filename)
    return pickle.load(fr)
storeTree(myTree,'classifierStorage.txt') 

grabTree('classifierStorage.txt')  

Part 4 示例:使用决策树预测隐形眼镜类型

#使用决策树预测隐形眼镜类型
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]

lensesLabels=['age','prescript','astigmatic','tearRate'] #prescript处方:myope=近视的人  hyper=远视眼 astigmatic:散光

lensesTree=createTree(lenses,lensesLabels) 
lensesTree 
#{'tearRate': {'normal': {'astigmatic': {'no': {'age': #{'pre': 'soft',
#      'presbyopic': {'prescript': {'hyper': 'soft', #'myope': 'no lenses'}},
#      'young': 'soft'}},
#    'yes': {'prescript': {'hyper': {'age': {'pre': 'no #lenses',
#        'presbyopic': 'no lenses',
#        'young': 'hard'}},
#      'myope': 'hard'}}}},
#  'reduced': 'no lenses'}}

机器学习实战第三章 决策树_第3张图片

你可能感兴趣的:(机器学习)