机器学习实战--决策树算法实例之判断海洋生物(ID3)

1.实例描述:

下表中有5组数据,两个特征,根据着两组特征判断这个样本是不是鱼类。

海洋生物数据
  不付出水面是否可以生存(no surfacing) 是否有脚蹼(flippers) 属于鱼类
1 1 1 yes
2 1 1 yes
3 1 0 no
4 0 1 no
5 0 1 no

 

2.算法实现的功能:

1.构造决策树

2.用matplotlib画出构造的决策树

3.给定一组数据,判断其分类。

3.代码实现:

3.1构造决策树:

#计算给定数据集的香农熵
from math import log
def calcShannonEnt(dataSet):
    num=len(dataSet)                       #数据集的样本数量
    labelCount={}                          #创建一个数据字典,它的键是数据集最后一列的数据,集样本的类别;它的值是该分类中的样本数量
    #计算每种类别下的样本数量,并将其放在字典中对应的键下
    for featureVec in dataSet:
        label=featureVec[-1]               #取样本中的最后一个值
        if label not in labelCount.keys():
            labelCount[label]=1
        else:
            labelCount[label]+=1
    #计算数据集的熵
    shannonEnt=0.0
    for key in labelCount.keys():
        pro=float(labelCount[key])/num
        shannonEnt-=pro*log(pro,2)
    return shannonEnt

#按照给定的特征划分数据集
def splitDataSet(dataSet,feature,value):        #参数:带划分的数据集、划分数据集的特征、特征值
    reDataSet=[]
    for featureVector in dataSet:
        if featureVector[feature]==value:
            reduceFeature=featureVector[:feature]
            reduceFeature.extend(featureVector[feature+1:])
            reDataSet.append(reduceFeature)
    return reDataSet

#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
    numOfFeature=len(dataSet[0])-1
    baseShannon=calcShannonEnt(dataSet)                          #
    bestShannon=0.0
    bestFeature=-1
    for i in range(numOfFeature):
        featureList=[featureVector[i]for featureVector in dataSet]#用列表推导式将第i个特征的值提取出来
        featureSet=set(featureList)                               #利用集合的互异性找出特征的不同取值
        newShannon=0.0
        for value in featureSet:
            subDataSet=splitDataSet(dataSet,i,value)              #按照不同的特征划分数据集
            #求新划分的数据集的香农熵
            prob=float(len(subDataSet))/float(len(dataSet))
            newShannon+=prob*calcShannonEnt(subDataSet)
        shannon=baseShannon-newShannon
        if(shannon>bestShannon):
            bestShannon=shannon
            bestFeature=i
    return bestFeature

#多数表决法定义叶子节点的分类
import operator
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote]=1
        else:
            classCount[vote]+=1
    sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]
#递归构建决策树
def createTree(dataSet,labels):
    classList=[example[-1] for example in dataSet]
    #递归函数第一个停止的条件:所有类标签完全相同,直接返回该类标签
    if classList.count(classList[0])==len(classList):
        return classList[0]
    #递归函数的第二个停止条件:使用完所有特征,仍不能将数据集划分成仅包含唯一类别的分组。使用多数表决法决定叶子节点的分类
    if len(dataSet[0])==1:
        return majorityCnt(classList)
    #开始创建决策树
    bestFeature=chooseBestFeatureToSplit(dataSet)  #选择划分数据集最好的特征的索引
    bestFeatureLabel=labels[bestFeature]           #根据特征的索引提取索引的名称
    decisionTree={bestFeatureLabel:{}}             #将此特征作为树的根节点
    del labels[bestFeature]                        #将已放进树中的特征从特征标签中删除
    featrueValues=[example[bestFeature]for example in dataSet]  #提取所有样本关于这个特征的取值
    uniqueVals=set(featrueValues)                               #应用集合的互异性,提取这个特征的不同取值
    for value in uniqueVals:                       #根据特征的不同取值,创建这个特征所对应结点的分支
        subLabels=labels[:]
        decisionTree[bestFeatureLabel][value]=createTree(splitDataSet(dataSet,bestFeature,value),subLabels)
    return decisionTree

3.2用matplotlib画出构造的决策树

#获取叶节点的数目,在绘制决策树时确定x轴的长度
def getNumLeafs(tree):
    numOfLeaf=0
    firstNode,=tree.keys()
    second=tree[firstNode]
    #测试节点的数据类型,若不是字典类型,则表示此节点为叶子节点
    for key in second.keys():
        if type(second[key]).__name__=='dict':
            numOfLeaf+=getNumLeafs(second[key])
        else:
            numOfLeaf+=1
    return numOfLeaf

#计算树的深度,在绘制决策树时确定y轴的高度
def getTreeDepth(tree):
    depthOfTree=0
    firstNode,=tree.keys()
    second=tree[firstNode]
    for key in second.keys():
        if type(second[key]).__name__=='dict':
            thisNodeDepth=getTreeDepth(second[key])+1
        else:
            thisNodeDepth=1
        if thisNodeDepth>depthOfTree:
            depthOfTree=thisNodeDepth
    return depthOfTree

#用matplotlib绘制决策树
import matplotlib.pyplot as plt
decisionNode=dict(boxstyle='sawtooth',fc='0.8')       #决策节点;设置文本框的类型和文本框背景灰度,范围为0-1,0为黑,1为白,不设置默认为蓝色
leafNode=dict(boxstyle='round4',fc='1')               #设置叶子节点文本框的属性
arrow_args=dict(arrowstyle='<-')

#绘制节点
#annotate(text,xy,xycoords,xytext,textcoords,va,ha,bbox,arrowprops)
#xy表示进行标注的点的坐标
#xytext表示标注的文本信息的位置
#xycoords与textcoords分别为xy和xytext的说明,默认为data
#va,ha设置文本框中文字的位置,va表示竖直方向,ha表示水平方向
def plotNode(nodeTxt,nodeIndex,parentNodeIndex,nodeType):     #形参:文本内容,文本的中心点,箭头指向文本的点,点的类型
    plt.annotate(nodeTxt,xy=parentNodeIndex,xycoords='axes fraction',
                            xytext=nodeIndex,textcoords='axes fraction',
                            va='center',ha='center',bbox=nodeType,
                            arrowprops=arrow_args)
#在父子节点之间添加注释
def plotMidText(thisNodeIndex,parentNodeIndex,text):
    xmid=(parentNodeIndex[0]-thisNodeIndex[0])/2.0+thisNodeIndex[0]
    ymid=(parentNodeIndex[1]-thisNodeIndex[1])/2.0+thisNodeIndex[1]
    plt.text(xmid,ymid,text)                            #在指定位置添加注释

def plotTree(tree,parentNodeIndex,midTxt):
    global xOff
    global yOff
    numOfLeafs=getNumLeafs(tree)
    nodeTxt,=tree.keys()
    nodeIndex=(xOff+(1.0+float(numOfLeafs))/2.0/treeWidth,yOff)  #计算节点的位置
    plotNode(nodeTxt, nodeIndex, parentNodeIndex, decisionNode)
    plotMidText(nodeIndex,parentNodeIndex,midTxt)
    secondDict=tree[nodeTxt]
    yOff=yOff-1.0/treeDepth
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],nodeIndex,str(key))
        else:
            xOff=xOff+1.0/treeWidth
            plotNode(secondDict[key],(xOff,yOff),nodeIndex,leafNode)
            plotMidText((xOff,yOff),nodeIndex,str(key))
    yOff=yOff+1.0/treeDepth

def createPlot(tree):               #绘制决策树的主函数
    fig=plt.figure('DecisionTree',facecolor='white')   #创建一个画布,命名为'decisionTree',画布颜色为白色
    fig.clf()                                          #清空画布
    createPlot.ax1=plt.subplot(111,frameon=False)      #111:将画布分成1行1列,去第一块画布;frameon:是否绘制矩形坐标框
    #设置两个全局变量xOff和yOff,追踪已绘制节点的位置,计算放置下一个节点的恰当位置。
    global xOff
    xOff=-0.5/treeWidth
    global yOff
    yOff=1.0
    plotTree(tree,(0.5,1.0),'')
    plt.xticks([])
    plt.yticks([])
    plt.show()

注:

1.treeWidth和treeDepth是我们在函数外声明的变量,用于存储树的宽度和深度。我们使用这两个变量计算树节点的摆放位置,这样可以讲述绘制在水平方向和竖直方向的中心位置。

2.代码中声明了xOff和yOff两个全局变量来追踪已绘制的节点位置,以及放置下一个节点的恰当位置。

 

3.3使用决策树执行分类操作:给定一组数据,判断其分类

#使用决策树执行分类
def classify(inputTree,featureLabels,testVector):
    firstNode,=inputTree.keys()
    secondDict=inputTree[firstNode]
    featureIndex=featureLabels.index(firstNode)
    for key in secondDict.keys():
        if testVector[featureIndex]==key:
            if type(secondDict[key]).__name__=='dict':
                classLabel=classify(secondDict[key],featureLabels,testVector)
            else:
                classLabel=secondDict[key]
    return classLabel

3.4使用pickle模块存储决策树

构造决策树是很耗时的任务,即使处理很小的数据集,也要花费好几秒的时间。为了节省构造数据集的时间,最好在每次执行分类时调用已经构造好的决策树。为了解决这个问题,需要使用python模块的pickle序列化对象。序列化对象可以在磁盘上保存对象,并在需要的时候读出来。

实现如下:

#使用pickle模块储存决策树
def storeTree(inputTree,filename):
    import pickle
    file=open(filename,'wb')
    pickle.dump(inputTree,file)
    file.close()
def loadTree(filename):
    import pickle
    file=open(filename,'rb')
    Tree=pickle.load(file)
    file.close()
    return Tree

3.5代码测试:

dataSet=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']];
labels=['no surfacing','flippers']
decisionTree=createTree(dataSet,labels)
storeTree(decisionTree,'decisionTree')

myTree=loadTree('decisionTree')
featureLabels=['no surfacing','flippers']

treeWidth=float(getNumLeafs(myTree))
treeDepth=float(getTreeDepth(myTree))
createPlot(myTree)
print(classify(myTree,featureLabels,[1,0]))

构造决策树图的结果如下:

机器学习实战--决策树算法实例之判断海洋生物(ID3)_第1张图片

输入数据判断分类的输出结果:

no

 

4.部分函数详细解释:

1.calcShannonEnt(dataSet):计算香农熵

2.splitDataSet(dataSet,feature,value):划分数据集

3.chooseBestFeatureToSplit(dataSet):选择最好的数据集划分方式

4.majorityCnt(classList);createTree(dataSet,labels):多数表决法+递归构建决策树

5.python模块pickle序列化对象的应用:pickle模块

5.总结:

决策树分类器就像带有终止块的流程图,终止块表示分类结果。开始处理数据时,首先需要测量集合种数据的不一致性,也就是熵,然后寻找最优方案划分数据集,知道找到数据集中的所有数据属于同一类。本篇文章中用于构造决策树的算法为ID3算法,这个算法无法直接处理数值型的数据,尽管我们可以通过量化的方法将数值型数据转化为标称型数值,但是存在太多的特征划分时,ID3算法仍然会存在一些其他的问题。

构建决策树时,一般不构造新的数据结构,而是使用python语言内嵌的数据结构字典存储树节点信息。

使用matplotlib的注解功能,可以将存储的树结构转化为容易理解的图形。python语言的pickle模块可用于存储决策树的结构。决策树可能会产生过多的数据集划分,从而产生过度匹配数据集的问题。可以通过剪裁决策树,合并相邻的无法产生大量信息增益的叶节点,消除过度匹配的问题(这种方法,上述例子中没有涉及)。

关于构造决策树的算法,还有C4.5和CART。

你可能感兴趣的:(python,Machine,Learing)