《机器学习实战》代码记录--决策树

数据格式:


| 编号 | no surfacing | flippers | labels |
| ---- | ------------ | -------- | ------ |
| 1    | 1            | 1        | yes    |
| 2    | 1            | 1        | yes    |
| 3    | 1            | 0        | no     |
| 4    | 0            | 1        | no     |
| 5    | 0            | 1        | no     |

输出样例:

打印决策树并对[1,1]分类

《机器学习实战》代码记录--决策树_第1张图片

代码:

decision_tree.py

# -*- coding:utf-8 -*-
from math import log
import operator
import sys
#按照给定特征划分数据集
def splitDataSet(dataSet,axis,value):
        """Filter dataSet down to the records whose feature at position
        `axis` equals `value`, removing that feature column from each
        returned record (so it is not reused deeper in the tree)."""
        return [row[:axis]+row[axis+1:] for row in dataSet if row[axis]==value]


#计算数据集的熵
def calcShannonEnt(dataSet):
        """Compute the Shannon entropy of dataSet, where each record's
        last element is taken as its class label.

        Entropy = -sum(p_i * log2(p_i)) over the label frequencies p_i.
        """
        total=len(dataSet)
        counts={}
        for record in dataSet:
                label=record[-1]
                counts[label]=counts.get(label,0)+1
        entropy=0.0
        for count in counts.values():
                p=count/float(total)
                entropy-=p*log(p,2)
        return entropy

#选择最好的数据集划分方式(最大信息增益)
def chooseBestFeatureToSplit(dataSet):
        """Pick the feature index whose split yields the largest
        information gain over the whole dataSet; returns -1 when no
        split improves on zero gain."""
        featureCount=len(dataSet[0])-1
        baseEntropy=calcShannonEnt(dataSet)
        bestGain=0.0
        bestFeature=-1
        for col in range(featureCount):
                # Weighted entropy after splitting on each distinct value of this column.
                distinctValues={row[col] for row in dataSet}
                splitEntropy=0.0
                for v in distinctValues:
                        subset=splitDataSet(dataSet,col,v)
                        weight=len(subset)/float(len(dataSet))
                        splitEntropy+=weight*calcShannonEnt(subset)
                gain=baseEntropy-splitEntropy
                if gain>bestGain:
                        bestGain=gain
                        bestFeature=col
        return bestFeature


#投票表决部分
#投票表决部分
def majorityCnt(classList):
        """Return the class label occurring most often in classList.

        Used as the fallback decision when all features have been
        consumed but the records still carry mixed labels.

        Fix: dict.iteritems() was removed in Python 3; dict.items()
        works on both Python 2 and 3.
        """
        classCount={}
        for vote in classList:
                classCount[vote]=classCount.get(vote,0)+1
        sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
        return sortedClassCount[0][0]

#创建决策树
#创建决策树
def createTree(dataSet,labels):
        """Recursively build an ID3 decision tree as nested dicts:
        {featureName: {featureValue: subtree-or-leafLabel, ...}}.

        NOTE: `labels` is mutated in place (the chosen feature's name is
        deleted at each level) — callers who reuse the list must pass a copy.

        Fix: the exhausted-features branch called the misspelled name
        'majoryCnt', which raised NameError whenever it was reached.
        """
        classList=[example[-1] for example in dataSet]
        if classList.count(classList[0])==len(classList):
                # All records share one class: stop splitting, return it as a leaf.
                return classList[0]
        if len(dataSet[0])==1:
                # Only the label column remains: fall back to majority vote.
                return majorityCnt(classList)
        bestFeat=chooseBestFeatureToSplit(dataSet)
        bestFeatLabel=labels[bestFeat]
        myTree={bestFeatLabel:{}}
        del(labels[bestFeat])
        featValues=[example[bestFeat] for example in dataSet]
        for value in set(featValues):
                # Copy so sibling branches each see the same remaining labels.
                subLabels=labels[:]
                myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
        return myTree

def classify(inputTree,featLabels,testVec):
        """Walk the nested-dict decision tree and return the predicted
        class label for feature vector testVec.

        featLabels maps feature names (the tree's dict keys) to positions
        in testVec. Returns None when testVec's value has no branch in
        the tree (previously this left classLabel unbound and raised
        UnboundLocalError).

        Fix: dict.keys()[0] is not indexable on Python 3; use
        next(iter(...)) instead.
        """
        firstStr=next(iter(inputTree))
        secondDict=inputTree[firstStr]
        featIndex=featLabels.index(firstStr)
        classLabel=None
        for key in secondDict:
                if testVec[featIndex]==key:
                        if isinstance(secondDict[key],dict):
                                # Internal node: recurse into the subtree.
                                classLabel=classify(secondDict[key],featLabels,testVec)
                        else:
                                # Leaf node: the stored value is the label.
                                classLabel=secondDict[key]
        return classLabel


def storeTree(inputTree,filename):
        """Serialize the decision tree to `filename` with pickle.

        Fixes: open in binary mode ('wb') — pickle.dump to a text-mode
        file raises TypeError on Python 3 — and use a `with` block so the
        file is closed even if dump fails.
        """
        import pickle
        with open(filename,'wb') as fw:
                pickle.dump(inputTree,fw)

def grabTree(filename):
        """Load and return a pickled decision tree from `filename`.

        Fixes: open in binary mode ('rb') — pickle requires bytes on
        Python 3 — and close the file via `with` (the original leaked
        the handle).
        """
        import pickle
        with open(filename,'rb') as fr:
                return pickle.load(fr)

if __name__=='__main__':
        # Toy data set from "Machine Learning in Action", chapter 3.
        dataset = [[1, 1, 'yes'],[1, 1, 'yes'],[1, 0, 'no'],[0, 1, 'no'],[0, 1, 'no']]
        labels = ['no surfacing','flippers']

        # NOTE: createTree mutates `labels`, so classify() below is given
        # a fresh literal list.
        tree=createTree(dataset,labels)
        print(tree)  # print() works on both Python 2 and 3 for one argument
        storeTree(tree,'firstTry.txt')
        tree2=grabTree('firstTry.txt')
        # Guard against a missing argument (the original raised IndexError
        # when run without one).
        if len(sys.argv)>1:
                # SECURITY: eval() on a command-line argument executes arbitrary
                # code; prefer ast.literal_eval for untrusted input.
                print(classify(tree2,['no surfacing','flippers'],eval(sys.argv[1])))

treePlotter.py

#-*-coding:utf-8 -*-
import matplotlib.pyplot as plt
import matplotlib

import decision_tree

# Box style for internal decision nodes (sawtooth border, light grey fill).
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
# Box style for leaf nodes (rounded border, light grey fill).
leafNode=dict(boxstyle="round4",fc="0.8")
# Arrow drawn from each child node back toward its parent.
arrow_args=dict(arrowstyle="<-")


def plotNode(nodeTxt,centerPt,parentPt,nodeType):
        """Draw one annotated tree node at centerPt, with an arrow from
        parentPt, on the axes stored in createPlot.ax1.  The font is
        loaded from a fixed path so CJK node labels render correctly."""
        cjkFont=matplotlib.font_manager.FontProperties(
                fname='/usr/share/fonts/truetype/arphic/ukai.ttc')
        createPlot.ax1.annotate(nodeTxt,
                xy=parentPt,xycoords='axes fraction',
                xytext=centerPt,textcoords='axes fraction',
                va="center",ha="center",
                bbox=nodeType,arrowprops=arrow_args,
                fontproperties=cjkFont)


def createPlot(inTree):
        """Open a figure, seed plotTree's layout bookkeeping (total width
        and depth plus the running x/y cursor stored as function
        attributes) and render the whole tree."""
        figure=plt.figure(1,facecolor='white')
        figure.clf()
        createPlot.ax1=plt.subplot(111,frameon=False,xticks=[],yticks=[])
        plotTree.totalW=float(getNumLeafs(inTree))
        plotTree.totalD=float(getTreeDepth(inTree))
        plotTree.xOff=-0.5/plotTree.totalW
        plotTree.yOff=1.0
        plotTree(inTree,(0.5,1.0),'')
        plt.show()

def getNumLeafs(myTree):
        """Recursively count the leaf nodes of a nested-dict decision tree.

        Fixes: dict.keys()[0] is not indexable on Python 3 (use
        next(iter(...))); removed the leftover Python-2-syntax debug
        print statements, which were both noisy and a SyntaxError on
        Python 3.
        """
        numLeafs=0
        firstStr=next(iter(myTree))
        secondDict=myTree[firstStr]
        for key in secondDict:
                if isinstance(secondDict[key],dict):
                        numLeafs+=getNumLeafs(secondDict[key])
                else:
                        numLeafs+=1
        return numLeafs

def getTreeDepth(myTree):
        """Return the depth (number of decision levels) of a nested-dict
        decision tree; a tree with one split and only leaf children has
        depth 1.

        Fix: dict.keys()[0] is not indexable on Python 3; use
        next(iter(...)) instead.
        """
        maxDepth=0
        firstStr=next(iter(myTree))
        secondDict=myTree[firstStr]
        for key in secondDict:
                if isinstance(secondDict[key],dict):
                        thisDepth=1+getTreeDepth(secondDict[key])
                else:
                        thisDepth=1
                if thisDepth>maxDepth:
                        maxDepth=thisDepth
        return maxDepth

def plotMidTex(cntrPt,parentPt,txtString):
        """Label the branch between a child (cntrPt) and its parent
        (parentPt) by writing txtString at the midpoint of the edge."""
        midX=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
        midY=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
        createPlot.ax1.text(midX,midY,txtString)

def plotTree(myTree,parentPt,nodeTxt):
        """Recursively draw the subtree rooted at myTree.

        Layout state lives in function attributes set by createPlot:
        totalW/totalD (overall leaf count / depth) and xOff/yOff (the
        running drawing cursor, restored on the way back up).

        Fixes: dict.keys()[0] is not indexable on Python 3 (use
        next(iter(...))); dropped the unused local `depth`
        (getTreeDepth's result was never read).
        """
        numLeafs=getNumLeafs(myTree)
        firstStr=next(iter(myTree))
        # Center this node above the span of its leaves.
        cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
        plotMidTex(cntrPt,parentPt,nodeTxt)
        plotNode(firstStr,cntrPt,parentPt,decisionNode)
        secondDict=myTree[firstStr]
        plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
        for key in secondDict:
                if isinstance(secondDict[key],dict):
                        plotTree(secondDict[key],cntrPt,str(key))
                else:
                        plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
                        plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
                        plotMidTex((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
        # Restore the vertical cursor for this node's siblings.
        plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD

if  __name__=='__main__':
        # Same toy data set as decision_tree.py's __main__ section:
        # build the tree there, then render it here.
        fishData = [[1, 1, 'yes'],[1, 1, 'yes'],[1, 0, 'no'],[0, 1, 'no'],[0, 1, 'no']]
        featureNames = ['no surfacing','flippers']
        fishTree=decision_tree.createTree(fishData,featureNames)
        createPlot(fishTree)



你可能感兴趣的:(机器学习实战)