ID3—决策树算法

一 基本概念

ID3 算法以信息论为基础,其中以信息熵和信息增益度
度量标准,从而实现对数据的归纳分类。
熵的定义:假设有n个互不相容的事件a1,a2,a3,….,an, p(ai)表示事件ai发生的概率,则由该分布传递的信息量称为熵,记为式
这里写图片描述

举个列子
如英语有26个字母,假如每个字母在文章中出现次数平均的话,每个字母的信息量为:
I(e)=-log2(1/26)=4.7
汉字常用的有2500个,假如每个汉字在文章中出现次数平均的话,每个汉字的信息量为:
I(e)=-log2(1/2500)=11.3
如果两个系统具有同样大的信息量,如一篇用不同文字写的同一文章,由于是所有元素信息量的加和,使用的汉字就比使用英文字母要少。
ID3 –计算信息增益
在决策树分类中,假设S是训练样本集合,假定类别标号属性具有m个不同的值,定义m个不同类**C1,C2,….Cm(身高、体重、学历),设**Si是类Ci中样本的个数。对一个给定的样本分类所需的数学期望信息由下式给出
这里写图片描述
其中 pi 是任意样本属于Ci的概率,并用Si /S估计

属性A具有v个不同值{a1,a2,….,av},可以用属性A(身高)将S划分为v个子集{s1,s2,….,sv}(高中低);其中Si包含S中这样一些样本,他们在A上具有值aj。如果A选作测试属性,则这些子集对应于包含集合S的节点生长出来的分支。设Sij是子集Sj中类Ci的样本数,根据由A划分成子集的熵或平均期望信息由下式给出:
ID3—决策树算法_第1张图片
其中
ID3—决策树算法_第2张图片
定义信息增益为
这里写图片描述

Gain(A)越大,说明选择测试属性对分类提供的信息越多
依据贪婪算法,为了使下一步所需的信息量最小,要求每一次
都选择其信息增益最大的属性作为决策树的新节点。
信息增益(Gain)=信息期望H-平均信息期望E
决策树建立的关键:一个好的决策树取决于决策树跟和子树跟
的属性

二 决策树理论计算的例子

ID3—决策树算法_第3张图片
第1步计算决策属性的熵
决策属性“买计算机?”。该属性分
两类:买/不买
S1(买)=641 买的样本个数
S2(不买)= 383 不买的样本个数
S=S1+S2=1024

P1=641/1024=0.6260
P2=383/1024=0.3740
测试样本的信息期望:
H(S1,S2)=H(641,383)
=-P1Log2P1-P2Log2P2
=-(P1Log2P1+P2Log2P2)
=0.9537
第2步计算条件属性的熵
条件属性共有4个。分别是年龄、
收入、学生、信誉。
分别计算不同属性的信息增益。
第2-1步计算年龄的熵
年龄共分三个组:
青年、中年、老年
青年买与不买比例为128/256

S1(买)=128
S2(不买)= 256
S=S1+S2=384

P1=128/384
P2=256/384

H青年(S1,S2)=H(128,256)
=-P1Log2P1-P2Log2P2
=-(P1Log2P1+P2Log2P2)
=0.9183
第2-2步计算年龄的熵
年龄共分三个组:
青年、中年、老年
中年买与不买比例为256/0

S1(买)=256
S2(不买)= 0
S=S1+S2=256

P1=256/256
P2=0/256

H中年(S1,S2)=H(256,0)
=-P1Log2P1-P2Log2P2
=-(P1Log2P1+P2Log2P2)
=0
第2-3步计算年龄的熵
年龄共分三个组:
青年、中年、老年
老年买与不买比例为257/127

S1(买)=257
S2(不买)=127
S=S1+S2=384

P1=257/384
P2=127/384

H老年(S1,S2)=H(125,127)
=-P1Log2P1-P2Log2P2
=-(P1Log2P1+P2Log2P2)
=0.9157
第2-4步计算年龄的Gain
年龄共分三个组:
青年、中年、老年
所占比例
青年组 384/1024=0.375
中年组 256/1024=0.25
老年组 384/1024=0.375

计算年龄的平均信息期望
E(年龄)=0.375*0.9183+
0.25*0+
0.375*0.9157
=0.6877
G(年龄信息增益)
=0.9537-0.6877
=0.2660 (1)
第3步计算收入的熵
收入共分三个组:
高、中、低
E(收入)=0.9361
收入信息增益=0.9537-0.9361
=0.0176 (2)
第4步计算学生的熵
学生共分二个组:
学生、非学生
E(学生)=0.7811
学生信息增益=0.9537-0.7811
=0.1726 (3)
第5步计算信誉的熵
信誉分二个组:
良好,优秀
E(信誉)= 0.9048
信誉信息增益=0.9537-0.9048
=0.0453 (4)
第6步计算选择节点
年龄信息增益=0.9537-0.6877
=0.2660 (1)

收入信息增益=0.9537-0.9361
=0.0176 (2)

学生信息增益=0.9537-0.7811
=0.1726 (3)

信誉信息增益=0.9537-0.9048
=0.0453 (4)

ID3 决策树建立算法步骤
1 决定分类属性;
2 对目前的数据表,建立一个节点N
3 如果数据库中的数据都属于同一个类,N就是树叶,在树叶上
标出所属的类
4 如果数据表中没有其他属性可以考虑,则N也是树叶,按照少
数服从多数的原则在树叶上标出所属类别
5 否则,根据平均信息期望值E或GAIN值选出一个最佳属性作
为节点N的测试属性
6 节点属性选定后,对于该属性中的每个值:
从N生成一个分支,并将数据表中与该分支有关的数据收集形
成分支节点的数据表,在表中删除节点属性那一栏
如果分支数据表非空,则运用以上算法从该节点建立子树。

三 python源代码

数据集
ID3—决策树算法_第4张图片
决策树源代码1-1(主程序)

#coding=utf-8
from math import log
import operator
import copy
import tree_plot

def createDataSet():
    dataSet = [[0, 2,1, 0, 'no'],
               [0, 2,2, 1,'no'],
               [1, 2,0, 0, 'yes'],
               [2, 1,2, 0, 'yes'],
               [2, 0,1, 0, 'yes'],
               [2, 0,0, 1, 'no'],
               [1, 0,0, 1, 'yes'],
               [0, 1,2, 0, 'no'],
               [0, 0,0, 0, 'yes'],
               [2, 1,1, 0, 'yes'],
               [0, 1,0, 1, 'yes'],
               [1, 1,2, 1, 'yes'],
               [1, 2,0, 0, 'yes'],
               [2, 1,1, 1, 'no']]
    labels = ['weather', 'temperature', 'humidity', 'wind']
    return dataSet, labels

def calcShannonEnt(dataSet):#熵只和决策属性有关
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:  # the the number of unique elements and their occurance
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)  # log base 2
    return shannonEnt


def splitDataSet(dataSet, axis, value):#某一属性下某一特征的决策属性个数,axis某一属性,value某一特征
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]  # chop out axis used for splitting
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # 非决策属性个数
    baseEntropy = calcShannonEnt(dataSet)#决策属性的熵
    bestInfoGain = 0.0;
    bestFeature = -1
    for i in range(numFeatures):  # 某一非决策属性
        featList = [example[i] for example in dataSet]  # 该属性的值
        uniqueVals = set(featList)  # 该属性下特征的种类
        newEntropy = 0.0
        for value in uniqueVals:    #该属性下特征的某一种类
            subDataSet = splitDataSet(dataSet, i, value)#某一属性下某一特征的个数
            prob = len(subDataSet) / float(len(dataSet))#某一属性下某一特征的概率
            newEntropy += prob * calcShannonEnt(subDataSet)#某一属性下某一特征的熵,并求和
        infoGain = baseEntropy - newEntropy  # 计算某一属性的信息增益
        if (infoGain > bestInfoGain):  #求最大增益的属性
            bestInfoGain = infoGain  # if better than current best, set to best
            bestFeature = i
    return bestFeature  # 返回该属性

def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]#决策类
    if classList.count(classList[0]) == len(classList):#递归终止条件,决策类单一

        return classList[0]  # stop splitting when all of the classes are equal
    if len(dataSet[0]) == 1:   #递归终止条件,特征里只剩决策类
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del (labels[bestFeat])

    featValues = [example[bestFeat] for example in dataSet]

    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy all of labels, so trees don't mess up existing labels
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

def classify(inputTree, featLabels, testVec):
    firstStr = inputTree.keys()[0]     #树字典的第一个键值
    secondDict = inputTree[firstStr]   #树字典的第一个键值的值

    featIndex = featLabels.index(firstStr)#树字典第一个键值在标签列表里的引索值
    key = testVec[featIndex]            #树字典第一个键值在标签列表里的引索值
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):   #判断valueOfFeat是否为字典
        classLabel = classify(valueOfFeat, featLabels, testVec)#如果是字典迭代
    else:
        classLabel = valueOfFeat
    return classLabel


def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'w')
    pickle.dump(inputTree, fw)
    fw.close()


def grabTree(filename):
    import pickle
    fr = open(filename)
    return pickle.load(fr)

# 执行分类
mydat,labels=createDataSet() # mydat,lables相当于全局变量
myTree=createTree(mydat,labels) # 树字典
labels = ['weather', 'temperature','humidity', 'wind']
print classify(myTree,labels,[1,1,2,0])  # 输出预测类型
tree_plot.createPlot(myTree)

画图源代码1—2 来自篇博客

# -*- coding: utf-8 -*-
"""
绘制树节点
Created on Thu Aug 10 10:37:02 2017
@author: LiLong
"""
#import decision_tree.py
import matplotlib.pyplot as plt



# boxstyle为文本框的类型,sawtooth是锯齿形,fc是边框线粗细
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8") # 定义决策树的叶子结点的描述属性
arrow_args = dict(arrowstyle="<-") # 定义箭头属性,也可以是<->,效果就变成双箭头的了


# 绘制结点文本和指向
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    #nodeTxt为要显示的文本,xytext是文本的坐标,
    #xy是注释点的坐标 ,nodeType是注释边框的属性,arrowprops连接线的属性
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


# 获取叶节点的数目
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = myTree.keys()[0]  # 得到第一个键
    secondDict = myTree[firstStr] # 得到第一个键对应的值
    for key in secondDict.keys():
        # 测试节点的数据类型是否是字典
        if type(secondDict[key]).__name__=='dict':
            numLeafs += getNumLeafs(secondDict[key]) # 又是递归调用
        else:   numLeafs +=1
    return numLeafs # 返回叶节点数

# 获取树的层数(递归在此就像是一层一层的剥到最里面,然后再从里到外加起来)
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys(): #keys()函数得到的是key,是一个列表
        #print'key:',key
        # 测试节点的数据类型是否是字典,如果是字典说明是可以再分的,深度+1
        if type(secondDict[key]).__name__=='dict':
            thisDepth = 1 + getTreeDepth(secondDict[key]) # 递归调用,层层剥离字典
        else:   thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

# 绘制中间文本的坐标和显示内容,即父子之间的填充文本
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]  # 求中间点的横坐标
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    # 绘制出来此文本
    createPlot.ax1.text(xMid, yMid, txtString)


# 绘制树形图
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)  # 得到叶节点的数,宽
    print 'numLeafs:',numLeafs
    depth = getTreeDepth(myTree)  # 获得树的层数,高
    firstStr = myTree.keys()[0]    # 得到第一个划分的特征
    # 计算坐标
    print 'plotTree.xOff:',plotTree.xOff
    print 'plotTree.totalW:',plotTree.totalW
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, \
                plotTree.yOff)
    #print 'cntrPt:',cntrPt
    # cntrPt是刚计算的坐标,parentPt是父节点坐标,nodeTxt目前为空字符
    plotMidText(cntrPt, parentPt, nodeTxt) # 绘制连接线上的文本
    plotNode(firstStr, cntrPt, parentPt, decisionNode) # 绘制树节点
    secondDict = myTree[firstStr] # 下一级字典,即下一层
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD # 纵坐标降低

    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':   # 如果是树节点
            plotTree(secondDict[key],cntrPt,str(key))  #递归,绘制
        else:   #如果是一个叶节点,就绘制出来
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW # 定x坐标
            # secondDict[key]叶节点文本,(plotTree.xOff, plotTree.yOff)箭头指向的坐标
            # cntrPt注释(父节点)的坐标
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) #绘制文本及连线
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))  # 绘制父子填充文本
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

# 预留树信息
def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]

#  Axis为坐标轴,Label为坐标轴标注。Tick为刻度线,ax是坐标系区域
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    # 横纵坐标轴的刻度线,应该为空,加上范围后,父子间的节点连线的填充文本位置错乱
    axprops = dict(xticks=[], yticks=[]) # {'xticks': [], 'yticks': []}
    # createPlot.ax1创建绘图区,无边框,无刻度值
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    #createPlot.ax1 = plt.subplot(111, frameon=False)
    # 计算树形图的全局变量,用于计算树节点的摆放位置,将树绘制在中心位置
    plotTree.totalW = float(getNumLeafs(inTree)) # plotTree.totalW保存的是树的宽
    plotTree.totalD = float(getTreeDepth(inTree)) # plotTree.totalD保存的是树的高
    plotTree.xOff = -0.5/plotTree.totalW # 决策树起始横坐标
    plotTree.yOff = 1.0  # 决策树的起始纵坐标
    plotTree(inTree, (0.5,1.0), '') # 绘制树形图
    plt.show() # 显示

四 实验结果

ID3—决策树算法_第5张图片

你可能感兴趣的:(python学习,机器学习,决策树id3算法,python,机器学习)