决策树完整代码

文章目录

    • trees.py
    • treePlotter.py

trees.py

#!/usr/bin/python
# coding:utf-8
import operator
from math import log
import treePlotter as dtPlot
from collections import Counter


def createDataSet():
    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no']
    ]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def calcShannonEnt(dataSet):
    """
    计算给定数据集的香农熵
    :param dataSet:数据集
    :return:每一组feature下的某个分类下,香农熵的信息期望
    """
    # 第一种实现方法
    # 求list的长度,表示计算参与训练的数据量
    numEntries = len(dataSet)

    # 计算分类标签label出现的次数
    labelCounts = {}
    for featVec in dataSet:
        # 将当前实例的标签存储,即每一行数据的最后一个数据代表的是标签
        currentLabel1 = featVec[-1]
        # 为所有可能的分类创建字典,如果当前的键值不存在,则扩展字典并将当前键值加入字典。每个键值都记录了当前类别出现的次数。
        if currentLabel1 not in labelCounts.keys():
            labelCounts[currentLabel1] = 0
        labelCounts[currentLabel1] += 1

    # 对于label标签的占比,求出label标签的香农熵
    shannonEnt = 0.0
    for key in labelCounts:
        # 使用所有类标签的发生频率计算类别出现的概率。
        prob = float(labelCounts[key]) / numEntries
        # 计算香农熵,以 2 为底求对数
        shannonEnt -= prob * log(prob, 2)
    # 第一种实现方式end

    # 第二种实现方式start
    # # 统计标签出现的次数
    # label_count = Counter(data[-1] for data in dataSet)
    # # 计算概率
    # probs = [p[1] / len(dataSet) for p in label_count.items()]
    # # 计算香农熵
    # shannonEnt = sum([-p * log(p, 2) for p in probs])
    # 第二种实现方式end

    return shannonEnt


def splitDataSet(dataSet, index, value):
    """
    就是依据index列进行分类,如果index列的数据等于 value,就要将 index 划分到我们创建的新的数据集中
    说白了就是通过index特征分类,并将特征从数据中消除
    :param dataSet:数据集                  待划分的数据集
    :param index:表示每一行的index列        划分数据集的特征
    :param value:表示index列对应的value值   需要返回的特征的值。
    :return:index列为value的数据集【该数据集需要排除index列】
    """

    retDataSet = []
    for featVec in dataSet:
        if featVec[index] == value:
            # [:index]表示前index行,即若 index 为2,就是取 featVec 的前 index 行
            reducedFeatVec = featVec[:index]
            # 1、使用append的时候,是将object看作一个对象,整体打包添加到music_media对象中。
            # 2、使用extend的时候,是将sequence看作一个序列,将这个序列和music_media序列合并,并放在其后面。
            reducedFeatVec.extend(featVec[index + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    """
    选择最好的特征
    :param dataSet:数据集
    :return:最优的特征列
    """
    # 求第一行有多少列的 Feature, 最后一列是label列
    numFeatures = len(dataSet[0]) - 1
    # 数据集的原始信息熵
    baseEntropy = calcShannonEnt(dataSet)
    # 最优的信息增益值, 和最优的Featurn编号
    bestInfoGain, bestFeature = 0.0, -1
    for i in range(numFeatures):
        # 获取对应的feature下的所有数据
        featList = [example[i] for example in dataSet]
        # 获取剔重后的集合,使用set对list数据进行去重
        uniqueVals = set(featList)
        # 创建一个临时的信息熵
        newEntropy = 0.0
        # 遍历某一列的value集合,计算该列的信息熵
        # 遍历当前特征中的所有唯一属性值,对每个唯一属性值划分一次数据集,计算数据集的新熵值,并对所有唯一特征值得到的熵求和。
        for value in uniqueVals:
            subDateSet = splitDataSet(dataSet, i, value)
            # 计算概率
            prob = len(subDateSet) / float(len(dataSet))
            # 计算信息熵
            newEntropy += prob * calcShannonEnt(subDateSet)
        # gain[信息增益]: 划分数据集前后的信息变化, 获取信息熵最大的值
        # 信息增益是熵的减少或者是数据无序度的减少。最后,比较所有特征中的信息增益,返回最好特征划分的索引值。
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

    # # -----------选择最优特征的第二种方式 start------------------------------------
    # # 计算初始香农熵
    # base_entropy = calcShannonEnt(dataSet)
    # best_info_gain = 0
    # best_feature = -1
    # # 遍历每一个特征
    # for i in range(len(dataSet[0]) - 1):
    #     # 对当前特征进行统计
    #     feature_count = Counter([data[i] for data in dataSet])
    #     # 计算分割后的香农熵
    #     new_entropy = sum(feature[1] / float(len(dataSet)) * calcShannonEnt(splitDataSet(dataSet, i, feature[0])) \
    #                    for feature in feature_count.items())
    #     # 更新值
    #     info_gain = base_entropy - new_entropy
    #     print('No. {0} feature info gain is {1:.3f}'.format(i, info_gain))
    #     if info_gain > best_info_gain:
    #         best_info_gain = info_gain
    #         best_feature = i
    # return best_feature
    # # -----------选择最优特征的第二种方式 end------------------------------------


def majorityCnt(classList):
    """
    选择出现次数最多的一个结果
    :param classList:列的集合
    :return:bestFeature 最优的特征列
    """
    # -----------majorityCnt的第一种方式 start------------------------------------
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    # 倒叙排列classCount得到一个字典集合,然后取出第一个就是结果(yes/no),即出现次数最多的结果
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
    # -----------majorityCnt的第一种方式 end------------------------------------

    # # -----------majorityCnt的第二种方式 start------------------------------------
    # major_label = Counter(classList).most_common(1)[0]
    # return major_label
    # # -----------majorityCnt的第二种方式 end------------------------------------


def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    # 如果数据集的最后一列的第一个值出现的次数=整个集合的数量,也就说只有一个类别,就只直接返回结果就行
    # 第一个停止条件:所有的类标签完全相同,则直接返回该类标签。
    # count() 函数是统计括号中的值在list中出现的次数
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # 如果数据集只有1列,那么最初出现label次数最多的一类,作为结果
    # 第二个停止条件:使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组。
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    # 选择最优的特征,得到最优特征对应的label含义
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # 获取label的名称
    bestFeatLabel = labels[bestFeat]
    # 初始化myTree
    myTree = {bestFeatLabel: {}}
    # 注:labels列表是可变对象,在PYTHON函数中作为参数时传址引用,能够被全局修改
    # 所以这行代码导致函数外的同名变量被删除了元素,造成例句无法执行,提示'no surfacing' is not in list
    del labels[bestFeat]
    # 取出最优列,然后它的branch做分类
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        # 求出剩余的标签label
        subLabels = labels[:]
        # 遍历当前选择特征包含的所有属性值,在每个数据集划分上递归调用函数createTree()
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


# 测试算法:使用决策树执行分类
def classify(inputTree, featLabels, testVec):
    """
    给输入的节点,进行分类
    :param inputTree:决策树模型
    :param featLabels:Feature标签对应的名称
    :param testVec:测试输入的数据
    :return:classLabel 分类的结果值,需要映射label才能知道名称
    """
    # 获取tree的根节点对于的key值
    firstStr = list(inputTree.keys())[0]
    # 通过key得到根节点对应的value
    secondDict = inputTree[firstStr]
    # 判断根节点名称获取根节点在label中的先后顺序,这样就知道输入的testVec怎么开始对照树来做分类
    featIndex = featLabels.index(firstStr)
    # 测试数据,找到根节点对应的label位置,也就知道从输入的数据的第几位来开始分类
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    print('+++', firstStr, 'xxx', secondDict, '---', key, '>>>', valueOfFeat)
    # 判断分枝是否结束: 判断valueOfFeat是否是dict类型
    if isinstance(valueOfFeat, dict):
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else:
        classLabel = valueOfFeat
    return classLabel


# 使用算法:决策树的储存
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()
    # -------------- 第二种方法 start --------------
    # with open(filename, 'wb') as fw:
    #     pickle.dump(inputTree, fw)
    # -------------- 第二种方法 start --------------


def grapTree(filename):
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)


def fishTest():
    # 1.创建数据和结果标签
    myDat, labels = createDataSet()
    # print myDat, labels

    # 计算label分类标签的香农熵
    # calcShannonEnt(myDat)

    # # 求第0列 为 1/0的列的数据集【排除第0列】
    # print '1---', splitDataSet(myDat, 0, 1)
    # print '0---', splitDataSet(myDat, 0, 0)

    # # 计算最好的信息增益的列
    # print chooseBestFeatureToSplit(myDat)

    import copy
    myTree = createTree(myDat, copy.deepcopy(labels))
    print(myTree)

    # [1, 1]表示要取的分支上的节点位置,对应的结果值
    print(classify(myTree, labels, [1, 1]))

    # 获得树的高度
    print('树高:', get_tree_height(myTree))

    # 画图可视化展现
    dtPlot.createPlot(myTree)


def get_tree_height(tree):
    """
     Desc:
        递归获得决策树的高度
    Args:
        tree
    Returns:
        树高
    """

    if not isinstance(tree, dict):
        return 1

    child_trees = list(tree.values())[0].values()

    # 遍历子树, 获得子树的最大高度
    max_height = 0
    for child_tree in child_trees:
        child_tree_height = get_tree_height(child_tree)

        if child_tree_height > max_height:
            max_height = child_tree_height

    return max_height + 1


def ContactLensesTest():
    """
    Desc:
        预测隐形眼镜的测试代码
    Returns:
        none
    """

    # 加载隐形眼镜相关的 文本文件 数据
    fr = open('data/lenses.txt')
    # 解析数据,获得 features 数据
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    # 得到数据的对应的 Labels
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    # 使用上面的创建决策树的代码,构造预测隐形眼镜的决策树
    lensesTree = createTree(lenses, lensesLabels)
    print(lensesTree)
    # 画图可视化展现
    dtPlot.createPlot(lensesTree)


if __name__ == "__main__":
    # fishTest()

    ContactLensesTest()

treePlotter.py

#!/usr/bin/env python
# coding: utf-8
import matplotlib.pyplot as plt

# 定义文本框 和 箭头格式 【 sawtooth 波浪方框, round4 矩形方框 , fc表示字体颜色的深浅 0.1~0.9 依次变浅,没错是变浅】
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,
                            xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def createPlot(inTree):
    # 创建一个figure的模版
    fig = plt.figure(1, facecolor='green')
    fig.clf()

    axprops = dict(xticks=[], yticks=[])
    # 表示创建一个1行,1列的图,createPlot.ax1 为第 1 个子图,
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)

    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # 半个节点的长度;xOff表示当前plotTree未遍历到的最左的叶节点的左边一个叶节点的x坐标
    # 所有叶节点中,最左的叶节点的x坐标是0.5/plotTree.totalW(因为totalW个叶节点在x轴方向是平均分布在[0, 1]区间上的)
    # 因此,xOff的初始值应该是 0.5/plotTree.totalW-相邻两个叶节点的x轴方向距离
    plotTree.xOff = -0.5 / plotTree.totalW
    # 根节点的y坐标为1.0,树的最低点y坐标为0
    plotTree.yOff = 1.0
    # 第二个参数是根节点的坐标
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]  # 这里原来的代码稍有问题,变为list就好
    secondDict = myTree[firstStr]
    # 根节点开始遍历
    for key in secondDict.keys():
        # 判断子节点是否为dict, 不是+1
        if type(secondDict[key]) is dict:
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    # 根节点开始遍历
    for key in secondDict.keys():
        # 判断子节点是不是dict, 求分枝的深度
        if type(secondDict[key]) is dict:
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        # 记录最大的分支深度
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


# 函数retrieveTree()主要用于测试,返回预定义的树结构。
def retrieveTree(i):
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
    ]
    return listOfTrees[i]


# 在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    # 获取叶子节点的数量
    numLeafs = getNumLeafs(myTree)
    # 获取树的深度
    # depth = getTreeDepth(myTree)

    # 找出第1个中心点的位置,然后与 parentPt定点进行划线
    # x坐标为 (numLeafs-1.)/plotTree.totalW/2+1./plotTree.totalW,化简如下
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    # print cntrPt
    # 并打印输入对应的文字
    plotMidText(cntrPt, parentPt, nodeTxt)

    firstStr = list(myTree.keys())[0]
    # 可视化Node分支点;第一次调用plotTree时,cntrPt与parentPt相同
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    # 根节点的值
    secondDict = myTree[firstStr]
    # y值 = 最高点-层数的高度[第二个节点位置];1.0相当于树的高度
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        # 判断该节点是否是Node节点
        if type(secondDict[key]) is dict:
            # 如果是就递归调用[recursion]
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # 如果不是,就在原来节点一半的地方找到节点的坐标
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            # 可视化该节点位置
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            # 并打印输入对应的文字
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

你可能感兴趣的:(机器学习代码,决策树完整代码)