基于pyhton3.6-机器学习实战-决策树tree代码解释



本人是一名数学系研究生,于2017年底第一次接触python和机器学习,作为一名新手,欢迎与大家交流。

我主要给大家讲解代码,理论部分给大家推荐3本书:

《机器学习实战中文版》

《机器学习》周志华

《统计学习方法》李航

以上3本书,第一本是基于python2的代码实现;剩余两本主要作为第一本书理论省略部分的补充,理论大部分都讲得很细。

博客上关于机器学习实战理论解释都很多,参差不齐,好作品也大都借鉴了以上3本书,网上有很多电子版的书。

与其看看一些没用的博客,真心不如以上3本书有收获。

说实话,学习一定要静下心来,切忌浮躁。不懂可以每天看一点,每天你懂一点,天天积累就多了。

操作系统:windows8.1

python版本:python3.6

运行环境:spyder(anaconda)

# -*- coding: utf-8 -*-
"""
Created on Thu Jan 11 20:27:57 2018

@author: Lelouch_C.C
"""

#创建createDataSet()函数
def createDataSet():
    dataSet = [[ 1, 1, 'yes'],         
            [1, 1, 'yes'],
            [1, 0, 'no'],
            [0, 1, 'no'],
            [0, 1, 'no']]
    labels=['no surfacing','flippers']           
    #这里的labels指的是储存特征名称的列表,不是储存类别标签,下不复述
    return dataSet,labels

from math import log

#计算给定数据集的shannon entropy
def calcShannonEnt(dataSet):
    numEntires = len(dataSet)                     #返回数据集的行数
    labelCounts = {}                              
    #初始化一个空字典,用来保存每个类别标签出现的次数
    for featVec in dataSet:                       #对每组特征向量进行统计
        currentLabel = featVec[-1]                #提取类别标签
        ####################类别标签计数#################
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
            #如果类别标签没有放入统计次数的字典,先计0
        labelCounts[currentLabel] += 1     
        #如果类别标签放入统计次数的字典,+1      
        #################################################
    shannonEnt = 0.0                              #经验熵(香农熵)
    for key in labelCounts:                       #计算香农熵
        prob = float(labelCounts[key]) / numEntires  #选择该类别标签的概率
        shannonEnt -= prob * log(prob, 2)            #shannon entropy计算公式
    return shannonEnt                                #返回经验熵(香农熵)
"""
if __name__ =='__main__':
    myDat,mylables=createDataSet()
    print('myDat=',myDat)
    #输出:myDat= [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    print('calcShannonEnt(myDat)=',calcShannonEnt(myDat))
    #输出:calcShannonEnt(myDat)= 0.9709505944546686
#"""
    
def splitDataSet(dataSet,axis,value):  
    """
    函数说明:按照给定的特征的特征值划分数据集
    参数:待划分数据集dataSet
          划分数据集的列/特征axis
          要划分的列/特征的特征值value
    返回值:按照给定的特征划分好的数据集retDataSet
    """  
    retDataSet = []  
    for featVec in dataSet:                    #遍历数据集中的每个特征
        if featVec[axis] == value: 
            #如果每个样本featVec的特征值为value,则:
            ######将每个样本划分特征的特征值value去掉后加入retDataSet#######
            reducedFeatVec = featVec[:axis] 
            # 取每个样本featVec的0-axis个数据,不包括axis,赋给reducedFeatVec
            reducedFeatVec.extend(featVec[axis+1:])  
            # 取featVec的axis+1到最后的数据,放到reducedFeatVec的后面  
            retDataSet.append(reducedFeatVec)  
            # 将reducedFeatVec添加到分割后的数据集retDataSet中,
            #同时reducedFeatVec,retDataSet中没有了axis列的特征值
            ##############################################################
    return retDataSet                          #返回分割后的数据集
"""
if __name__=='__main__':
    myDat,mylables=createDataSet()
    print('splitDataSet(myDat,0,1)=',splitDataSet(myDat,0,1))
    #输出:splitDataSet(myDat,0,1)= [[1, 'yes'], [1, 'yes'], [0, 'no']]
    print('splitDataSet(myDat,0,0)=',splitDataSet(myDat,0,0))
    #输出:splitDataSet(myDat,0,1)= [[1, 'no'], [1, 'no']]
#"""

#选择最好的数据集划分方式,选择使分割后信息增益最大的特征,即对应的列 
def chooseBestFeatureToSplit(dataSet): 
    numFeatures = len(dataSet[0]) - 1
    #获取特征的数目,len(dataSet[0])是列数 ,但最后一列是类别标签,不是特征,所以-1
    baseEntropy = calcShannonEnt(dataSet)  #计算整个数据集原始的信息熵 
    bestInfoGain = 0.0                     #初始化最好的信息增益 
    bestFeature = -1                       #初始化分割后信息增益最好的特征
    for i in range(numFeatures):  
        featList = [example[i] for example in dataSet] 
        # 用一个列表推到式,取出每个样本的第i列特征值赋给featList 
        uniqueVals = set(featList)  
        # 将每列的特征对应的值放到一个集合中,使得特征值不重复  
        newEntropy = 0.0                   #初始化分割后的信息熵 
        for value in uniqueVals:  
            # 遍历特征列的所有值(值是唯一的,重复值已经合并),分割并计算信息增益
            subDataSet = splitDataSet(dataSet, i, value) 
            # 按照特征列的每个值进行数据集分割  
            prob = len(subDataSet) / float(len(dataSet)) 
            # 计算分割后的每个子集的概率(频率)
            newEntropy +=prob * calcShannonEnt(subDataSet) 
            # 计算分割后的子集的信息熵并相加,得到分割后的整个数据集的信息熵 
        infoGain = baseEntropy - newEntropy # 计算分割后的信息增益 
        if(infoGain > bestInfoGain):        # 如果分割后信息增益大于最好的信息增益
            bestInfoGain = infoGain     # 将当前的分割的信息增益赋值为最好信息增益
            bestFeature = i             # 分割的最好特征列赋为i
    return bestFeature                  # 返回分割后信息增益最大的特征列  
"""
if __name__=='__main__':
    myDat,mylables=createDataSet()
    print('chooseBestFeatureToSplit(myDat)=',chooseBestFeatureToSplit(myDat))
    #输出:chooseBestFeatureToSplit(myDat)= 0
#"""
    
import operator

#统计classList中出现此处最多的元素(类标签)
def majorityCnt(classList):            #classList是一个类别标签列表,见createTree()
    classCount = {}                    #初始化一个空字典,用来储存类别标签出现次数
    ##########统计classList中每个元素出现的次数##########
    for vote in classList:             
        if vote not in classCount.keys():
            classCount[vote] = 0   
        classCount[vote] += 1
    ###################################################
    sortedClassCount = sorted(classCount.items(), \
                              key =operator.itemgetter(1), reverse = True)  
    #根据字典的值降序排序
    return sortedClassCount[0][0]      #返回classList中出现次数最多的元素

#
def createTree(dataSet, labels):
    """
    函数说明:递归创建ID3算法决策树--只能处理离散数据
    参数:数据集dataSet
          储存特征名称的列表labels(注意:不是类别标签)
    返回值:创建好的决策树myTree
    """
    classList = [example[-1] for example in dataSet]   #classList是一个类别标签列表
    #######################定义叶节点#######################
    if classList.count(classList[0]) == len(classList):
        #如果类别完全相同则停止继续划分
        return classList[0]
    if len(dataSet[0]) == 1:             #遍历完所有特征时返回出现次数最多的类别标签
        return majorityCnt(classList)
    ########################################################
    bestFeat = chooseBestFeatureToSplit(dataSet)       #选择最优特征
    bestFeatLabel = labels[bestFeat]                   #获取最优特征的名称
    myTree = {bestFeatLabel:{}}                        #根据最优特征的标签生成树
    del(labels[bestFeat])                              #删除已经使用特征名称
    featValues = [example[bestFeat] for example in dataSet]
    #利用列表推导式得到训练集中所有最优特征的特征值列表
    uniqueVals = set(featValues)                       #去掉重复的特征值
    for value in uniqueVals:             #遍历特征值,每一个特征值都会建立一个分支。
        sublabels= labels[:]                      
        myTree[bestFeatLabel][value] = createTree(\
              splitDataSet(dataSet, bestFeat, value), sublabels)
    return myTree
"""
if __name__=='__main__':
    myDat,labels=createDataSet()
    myTree=createTree(myDat,labels)
    print ('myTree=',myTree)
    #输出:myTree= {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
#"""    

import matplotlib.pyplot as plt

#使用文本绘制树节点

decisionNode = dict(boxstyle="sawtooth",fc="0.8")  # 定义决策树决策点的属性
# 也可写作 decisionNode={boxstyle:'sawtooth',fc:'0.8'}  
# boxstyle为文本框的类型,sawtooth是锯齿形,fc控制的注解框内的颜色深度   
leafNode = dict(boxstyle="round4",fc="0.8")  # 定义决策树的叶子结点的描述属性 
arrow_args = dict(arrowstyle="<-")  # 定义决策树的箭头属性,默认"->"

# 绘制结点  
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    """
    函数说明:绘制带箭头的注解
    参数:nodeTxt为要显示的文本,
          centerPt为文本的中心点,
          parentPt为父节点/指向文本的点 
    """
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\
                            xytext=centerPt,textcoords='axes fraction',\
                            va="center",ha="center",bbox=nodeType,\
                            arrowprops=arrow_args) 
    
    #plt.annotate()文本注释,使用annotate()方法可以很方便地添加此类注释
    #nodeTxt为要显示的文本,被注释的地方xy=(x, y)和插入文本的地方xytext=(x, y)
    #'axes fraction'     fraction of axes from lower left
    #va=vertical axis,ha=horizontal axis,都指定中心
    #bbox定义文本框格式,arrowprops定义箭头格式
"""
# 创建绘图 
def createPlot(): 
    fig = plt.figure(1,facecolor='white') 
    # 类似于Matlab的figure,定义一个画布,背景为白色 
    fig.clf() # 把画布清空 
    createPlot.ax1 = plt.subplot(111,frameon=False) # frameon表示是否绘制坐标轴矩形 
    # createPlot.ax1为全局变量,绘制图像的句柄,
    #`subplot为定义了一个绘图,111表示figure中的图有1行1列,即1个,最后的1代表第一个图 
    plotNode('a decision node',(0.5,0.1),(0.1,0.5),decisionNode) # 绘制结点
    plotNode('a leaf node',(0.8,0.1),(0.3,0.8),leafNode) # 绘制结点
    plt.show() 

if __name__=='__main__':
    createPlot()
#"""  
# 获得决策树的叶子结点数目  
def getNumLeafs(myTree):  
    numLeafs = 0                     #定义叶子结点数目 
    firstStr =list(myTree.keys())[0] #获得myTree的第一个键,即第一个特征名称  
    #python3中myTree.keys()返回的是dict_keys,不在是list,
    #所以不能使用myTree.keys()[0]的方法获取结点属性,可以使用list(myTree.keys())[0]
    secondDict = myTree[firstStr] #根据键得到对应的值,即根据第一个特征分类的结果
    for key in secondDict.keys():   #遍历得到的secondDict 的键
        if type(secondDict[key]).__name__ == 'dict':  
            # 如果secondDict[key]为一个字典,即决策树结点 
            numLeafs += getNumLeafs(secondDict[key])  
            # 则递归的计算secondDict中的叶子结点数,并加到numLeafs上 
        else:                      # 如果secondDict[key]为叶子结点
            numLeafs += 1          # 则将叶子结点数加1 
    return numLeafs                # 返回求的叶子结点数目
      
# 获得决策树的深度  
def getTreeDepth(myTree):  
    maxDepth = 0                     #初始化树的最大深度
    firstStr =list(myTree.keys())[0] #获得myTree的第一个键,即第一个特征名称
    secondDict = myTree[firstStr]  #根据键值得到对应的值,即根据第一个特征分类的结果
    for key in secondDict.keys():  
        if type(secondDict[key]).__name__ == 'dict':# 如果secondDict[key]为一个字典
            thisDepth = 1 + getTreeDepth(secondDict[key])  
            # 则当前树的深度等于1加上secondDict的深度,只有当前点为决策树点深度才会加1
        else:                          # 如果secondDict[key]为叶子结点
            thisDepth = 1              # 则将当前树的深度设为1 
        if thisDepth > maxDepth:       # 如果当前树的深度比最大数的深度
            maxDepth = thisDepth       
    return maxDepth                    # 返回树的深度 

# 预先储存树信息,避免每次测试时都创建树的麻烦
def retrieveTree(i):  
    listOfTree = [{'no surfacing':{ 0:'no',1:{'flippers':  
                       {0:'no',1:'yes'}}}},  
                   {'no surfacing':{ 0:'no',1:{'flippers':  
                    {0:{'head':{0:'no',1:'yes'}},1:'no'}}}}  
                  ]  
    return listOfTree[i]
"""
if __name__=='__main__':
     myTree=retrieveTree(0)
     print(getNumLeafs(myTree))
     #输出:3
     print(getTreeDepth(myTree))
     #输出:2
#"""
# 绘制中间文本  
def plotMidText(cntrPt,parentPt,txtString):  
    """
    函数说明:在父子节点间填充文本信息
    """
    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]  # 求中间点的横坐标
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]  # 求中间点的纵坐标 
    createPlot.ax1.text(xMid,yMid,txtString)          # 绘制树结点

# 绘制决策树  
def plotTree(myTree,parentPt,nodeTxt):  
    # 定义并获得决策树的叶子结点数  
    #===========================计算非叶节点的情况==============================
    numLeafs = getNumLeafs(myTree)  #获取当前树的叶节点数,因为递归时向下计算子树
    depth = getTreeDepth(myTree)    #获取当前树的深度,因为递归时向下计算子树
    firstStr =list(myTree.keys())[0]#得到当前树的第一个特征 ,递归...
    cntrPt = (plotTree.xOff + (1.0 +float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff) 
    #计算当前文本中心坐标,x坐标计算如下,y坐标为 plotTree.yOff,起始位置是1
    """
    首先声明,绘制图形的x轴和y轴的有效范围都是[0,1]
    plotTree.xOff追踪已经绘制节点的的位置,准确为最近绘制的一个节点的x坐标,
    在确定当前节点位置时每次都需确定当前节点总共有几个叶子节点,
    因此其叶子节点所占的总距离就确定了即为float(numLeafs)*1/plotTree.totalW=1(因为总长度为1),
    因此当前节点的位置即为其所有叶子节点所占距离的中间即一半为
    float(numLeafs)/2.0/plotTree.totalW,
    但是由于开始plotTree.xOff赋值并非从0开始,而是左移了半个表格,
    因此还需加上半个表格距离即为1/2/plotTree.totalW,
    则加起来便为(1.0 + float(numLeafs))/2.0/plotTree.totalW*1,
    因此偏移量确定,则x位置变为plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW
    """
    plotMidText(cntrPt, parentPt, nodeTxt)  
    #在父子节点间(当前文本与上一个文本之间)填充文本信息
    #这里说明一点:最一开始cntrPt=parentPt=(1/2,1),nodeTxt='',
    #所以刚开始父子节点间没有填文本信息
    plotNode(firstStr,cntrPt,parentPt,decisionNode)   
    #绘制决策结点带箭头的注解,同理最一开始只能绘制文本框,因为两点重合
    secondDict = myTree[firstStr]                     #进入下一层分支
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    # 因为进入了下一层,所以y的坐标要变 ,图像坐标是从左上角为原点 
    # 正是由于这样的更新,plotTree.yOff有了追踪最近绘制节点的的位置的功能
    #==========================================================================
    for key in secondDict.keys():  
         
        if type(secondDict[key]).__name__ == 'dict':  
        # 如果secondDict[key]为一棵子决策树,即字典     
            plotTree(secondDict[key],cntrPt,str(key))  #递归的绘制决策树 
        else:                                          #否则,也就是叶节点的时候
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW  
            #计算叶子结点的横坐标
            #同理,正是由于这样的更新,plotTree.xOff有了追踪最近绘制节点的的位置的功能  
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt, leafNode)  
            # 绘制叶结点带箭头的注解
            plotMidText((plotTree.xOff,plotTree.yOff), cntrPt, str(key))  
            #在当前父子节点间(当前叶节点与上一个文本之间)填充文本信息
    plotTree.yOff = plotTree.yOff +1.0/plotTree.totalD  # 计算纵坐标  
    
def createPlot(inTree):  
    fig = plt.figure(1,facecolor='white')  # 定义一块画
    fig.clf()                              # 清空画布 
    axprops = dict(xticks=[],yticks=[])    # 定义横纵坐标轴,无内容
    createPlot.ax1 = plt.subplot(111,frameon=False,**axprops)  
    # 绘制图像,无边框,无坐标轴 
    plotTree.totalW = float(getNumLeafs(inTree))  #plotTree.totalW保存的是树的宽
    plotTree.totalD = float(getTreeDepth(inTree)) #plotTree.totalD保存的是树的高 
    plotTree.xOff = - 0.5/plotTree.totalW     
    #plotTree.xOff最一开始时定义横坐标左移了半个表格,后面追踪最近绘制节点横坐标的的位置
    plotTree.yOff = 1.0                           
    #plotTree.yOff最一开始时定义纵坐标为1,后面追踪最近绘制节点纵坐标的位置
    plotTree(inTree,(0.5,1.0),'')                 # 绘制决策树 
    plt.show()                                    # 显示图像 
"""
if __name__=='__main__':
     myTree=retrieveTree(0)
     print(myTree)
     #输出:{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
     createPlot(myTree)
     myTree['no surfacing'][3]='maybe'
     print(myTree)
     #输出:{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}, 3: 'maybe'}}
     createPlot(myTree)
     myTree1=retrieveTree(1)
     print(myTree1)
     #输出:{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
     createPlot(myTree1)
#"""
     
#测试算法
     
def classify(inputTree, featLabels, testVec):
    """
    函数说明:ID3算法决策树分类函数,也就是分类器--只能处理离散数据
    参数:inputTree已经生成的决策树
          featLabels存储特征名称的list
          testVec测试数据list,
    """
    firstStr = list(inputTree.keys())[0]        #获取决策树判断节点,递归...
    #或者firstStr = next(iter(inputTree))
    secondDict = inputTree[firstStr]            #从判断节点进入下一个层,递归...
    featIndex = featLabels.index(firstStr)      #将标签字符串转换为索引      
    #L.index(value, [start, [stop]]) -> integer -- return first index of value.                                   
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else: classLabel = secondDict[key]
    return classLabel
"""
if __name__ == '__main__':
    myDat, labels = createDataSet()
    myTree =retrieveTree(0)
    result = classify(myTree, labels, [1,0])
    print('result=',result)
    #输出:result= no
    result = classify(myTree, labels, [1,1])
    print('result=',result)
    #输出:result= yes
    result = classify(myTree, labels, [2,0])
    print('result=',result)
    #UnboundLocalError: local variable 'classLabel' referenced before assignment
    #报错是因为第一个特征已有特征值没有2,每个特征下所有特征值相当于是该特征下所有的类别
    #所以不能超出现有类别
#"""
    
#使用算法:决策树的储存
import pickle
def storeTree(inputTree, filename):
    fw=open(filename, 'wb')         #创建一个名为filename的文件,写入
    pickle.dump(inputTree, fw)     #将inputTree导入到名为filename的文件中
    fw.close()
"""
if __name__ == '__main__':
    myDat, labels = createDataSet()
    myTree = createTree(myDat, labels)
    storeTree(myTree, 'classifierStorage.txt')
#"""    
def grabTree(filename):
    fr = open(filename,'rb')
    return pickle.load(fr)
"""
if __name__ == '__main__':
    print(grabTree('classifierStorage.txt'))
    #输出:{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
#"""

#示例:使用决策树预测隐形眼镜类型
#样本:一共有24组数据,
#数据的Labels依次是age、prescript、astigmatic、tearRate、class,
#也就是依次是年龄,症状,否散光,眼泪数量,最终的分类标签。   
    
"""
if __name__ == '__main__':
    fr = open('lenses.txt')
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lensesTree = createTree(lenses, lensesLabels)
    print('lensesTree=',lensesTree)
    createPlot(lensesTree)
#"""

你可能感兴趣的:(机器学习)