python ID3 决策树 代码

 代码参考了python实现ID3决策树分类算法_aoanng的博客-CSDN博客_id3算法python实现

有所精简 


'''
function:ID3决策树生成算法
author:baomi
date: 2021/11/01
reference: https://blog.csdn.net/colourful_sky/article/details/82056125
'''

import math

def splitDataSet(dataSet, i, value):
    '''
    返回数据集dataSet中,去掉第i列属性值为value的实例后形成的新的数据集
    '''
    retDataSet = []
    for x in dataSet:
        if x[i] == value:
            temp = x[:]
            temp.pop(i)
            retDataSet.append(temp)
    return retDataSet


def calcEntropy(dataSet):
    '''
    计算一个数据集的熵
    '''
    labelDict = {}  # 数据集的标签-该标签总个数
    for x in dataSet:
        label = x[-1]
        if label not in labelDict.keys():
            labelDict[label] = 0
        labelDict[label] += 1

    n = len(dataSet)
    retEntropy = 0.0
    for key in labelDict:
        p = float(labelDict[key]) / n  # 计算标签概率
        retEntropy -= p * math.log(p, 2)

    return retEntropy


def calcInfoGain(dataSet, i):
    '''
    计算对数据集dataSet,选定第i列特征时所获得的信息增益
    '''
    preEntropy = calcEntropy(dataSet)

    postEntropy = 0.0
    featureSet = set([x[i] for x in dataSet])  # 得到i列特征所有特征值的集合
    for feature in featureSet:  # 以feature为筛选条件,计算筛选后的数据集熵
        subDataSet = splitDataSet(dataSet, i, feature)
        subDataSetEntropy = calcEntropy(subDataSet)
        p = len(subDataSet)/len(dataSet)
        postEntropy += p*subDataSetEntropy

    return preEntropy-postEntropy

def getMaxInfoGainNode(dataSet, featureNameList):
    '''
    featureNameList是dataSet中各特征名称
    该函数返回两种结果:
    熵为0时,返回标签,类型为str
    熵不为0时,返回具有最大信息增益的特征的索引号,特征名,以及最大信息增益
    '''
    dataSetEntropy = calcEntropy(dataSet)
    if dataSetEntropy == 0:
        return dataSet[0][-1]  # 数据集熵为0,说明标签都相同,直接将该标签返回

    featureNum = len(featureNameList)

    maxInfoGain = 0
    maxInfoGainIndex = 0
    for i in range(0, featureNum): #遍历所有特征,获得具有最大信息增益的特征索引号
        infoGain = calcInfoGain(dataSet, i)
        if infoGain > maxInfoGain:
            maxInfoGain = infoGain
            maxInfoGainIndex = i

    return maxInfoGainIndex, featureNameList[maxInfoGainIndex], maxInfoGain


def createID3Tree(dataSet, featureNameList):
    '''
    该函数返回一个结点
    如果dataSet熵为0,那么返回dataSet中类标签,此标签唯一
    否则,返回一个字典,该字典的key为dataSet选出的最优特征名,value又为一个字典,
    value字典的key为最优特征的特征值名,value字典的value又为一个字典.....
    '''
    maxInfoGainNode = getMaxInfoGainNode(dataSet, featureNameList)
    if type(maxInfoGainNode) == str:
        return maxInfoGainNode

    nodeIndex, nodeName = maxInfoGainNode[0], maxInfoGainNode[1]
    ret = {}
    ret[nodeName] = {}

    featureSet = set([x[nodeIndex] for x in dataSet])
    for feature in featureSet:
        subDataSet = splitDataSet(dataSet, nodeIndex, feature)
        newFeatNameList = featureNameList[:]
        newFeatNameList.pop(nodeIndex)
        childTree = createID3Tree(subDataSet, newFeatNameList) #对以最大信息增益作为特征筛选后的子数据集进行递归调用
        ret[nodeName][feature] = childTree

    return ret

dataSet = [['青年', '否', '否', '一般', '拒绝'],
           ['青年', '否', '否', '好', '拒绝'],
           ['青年', '是', '否', '好', '同意'],
           ['青年', '是', '是', '一般', '同意'],
           ['青年', '否', '否', '一般', '拒绝'],
           ['中年', '否', '否', '一般', '拒绝'],
           ['中年', '否', '否', '好', '拒绝'],
           ['中年', '是', '是', '好', '同意'],
           ['中年', '否', '是', '非常好', '同意'],
           ['中年', '否', '是', '非常好', '同意'],
           ['老年', '否', '是', '非常好', '同意'],
           ['老年', '否', '是', '好', '同意'],
           ['老年', '是', '否', '好', '同意'],
           ['老年', '是', '否', '非常好', '同意'],
           ['老年', '否', '否', '一般', '拒绝'], ]

featureNameList = ['年龄', '有工作', '有房子', '信贷情况']


ID3Tree = createID3Tree(dataSet, featureNameList)
print(ID3Tree)

输出结果:

对结果用matplotlib绘图

代码来自Matplotlib绘制树形图_wancongconghao的博客-CSDN博客_matplotlib 树状图 

#绘制树形图

import matplotlib.pyplot as plt

decision_node = dict(boxstyle="sawtooth",fc="0.8")

leaf_node = dict(boxstyle="round4",fc="0.8")

arrow_args = dict(arrowstyle="<-")

#获取树的叶子结点个数(确定图的宽度)

def get_leaf_num(tree):

    leaf_num = 0

    first_key = list(tree.keys())[0]

    next_dict = tree[first_key]

    for key in next_dict.keys():

        if type(next_dict[key]).__name__=="dict":

            leaf_num +=get_leaf_num(next_dict[key])

        else:

            leaf_num +=1

    return leaf_num

#获取数的深度(确定图的高度)

def get_tree_depth(tree):

    depth = 0

    first_key = list(tree.keys())[0]

    next_dict = tree[first_key]

    for key in next_dict.keys():

        if type(next_dict[key]).__name__ == "dict":

            thisdepth = 1+ get_tree_depth(next_dict[key])

        else:

            thisdepth = 1

        if thisdepth>depth: depth = thisdepth

    return depth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):

    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',

                            xytext=centerPt, textcoords='axes fraction',

                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

#在父子节点间填充文本信息

def plotMidText(cntrPt, parentPt, txtString):

    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]

    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]

    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):

    numLeafs = get_leaf_num(myTree)

    depth = get_tree_depth(myTree)

    firstStr = list(myTree.keys())[0]

    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)

    plotMidText(cntrPt, parentPt, nodeTxt)

    plotNode(firstStr, cntrPt, parentPt, decision_node)

    secondDict = myTree[firstStr]

    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD

    for key in secondDict.keys():

        if type(secondDict[

                    key]).__name__ == 'dict':

            plotTree(secondDict[key], cntrPt, str(key))

        else:

            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW

            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leaf_node)

            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))

    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD



 

def createPlot(inTree):

    fig = plt.figure(1, facecolor='white')

    fig.clf()

    axprops = dict(xticks=[], yticks=[])

    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)

    plotTree.totalW = float(get_leaf_num(inTree))

    plotTree.totalD = float(get_tree_depth(inTree))

    plotTree.xOff = -0.5 / plotTree.totalW

    plotTree.yOff = 1.0

    plotTree(inTree, (0.5, 1.0), '')

    plt.show()

plt.rcParams['font.sans-serif']=['SimHei'] #用来正常显示中文标签

plt.rcParams['axes.unicode_minus']=False #用来正常显示负号

tree = {'有房子': {'否': {'有工作': {'否': '拒绝', '是': '同意'}}, '是': '同意'}}

createPlot(tree)

结果:

python ID3 决策树 代码_第1张图片

你可能感兴趣的:(决策树,python,机器学习)