机器学习之构造决策树

前言:

        本节使用数据依旧是之前生成的三种球类数据,刚进入这篇文章的小伙伴可以回头看下。链接如下:

         机器学习入门之k近邻算法_俺从头开始的博客-CSDN博客

信息营地:

决策树:

        百度百科讲决策树:“决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种图解法。由于这种决策分支画成图形很像一棵树的枝干,故称决策树。在机器学习中,决策树是一个预测模型,他代表的是对象属性与对象值之间的一种映射关系。

        决策树是一种树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。”

        本质上来讲,决策树还是一个分类模型,所以它的中心工作是利用一种度量手段来区分各个类别的数据。

信息熵:

        一位二十世纪的天才——克劳德·香农,提出了一种名叫“熵”的度量标准,至今依旧被广泛应用于信息领域。

熵公式:

                                        H=-\sum_{i=1}^{n}p(x_{i})log_{2}p(x_{i})

其中,p(x_{i})为分类的概率。

信息增益:

        与信息熵一起诞生的产物。

计算公式:

                                Gain(D, a)=H(D)-\sum_{v=1}^{V} \frac{|D^{v}|}{|D|}H(D^{v})

        一个系统越是有序,信息熵就越低,一个系统越是混乱,信息熵就越高,所以信息熵被认为是一个系统有序程度的度量。

动手实践:

主要就是敲代码实现的过程了,展示如下:

from math import log
import operator
import matplotlib.pyplot as plt
# 计算给定数据集的香农熵
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)  # 计算实例总数
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]  # 键值是最后一列数值
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1 # 将不存在的类别加入字典
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob, 2)  # 香农公式
    return shannonEnt


# 划分数据集
def splitDataSet(dataSet, axis, value):  # dataSet:待划分的数据集,axis:划分数据集特征,value:需要返回的特征值
    retDataSet = []  # 防止破坏原始数据
    for featVec in dataSet:
        if featVec[axis] == value:  # 符合要求
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)  # 抽取
    return retDataSet


# 选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)  # 创建唯一的分类标签列表
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy  # 计算每种方式的信息熵
        if(infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i  # 计算最好的信息增益
    return bestFeature


# 取最大值标签
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), \
                                key=operator.itemgetter(1), reverse=True)  # 根据字典的值降序排列
    return sortedClassCount[0][0]


# 创建树
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # classList只剩下一种值
    if len(dataSet[0]) == 1:  # dataSet中属性使用完毕,但没有分配完毕
        return majorityCnt(classList)  # 取数量最多作为分类
    bestFeat = chooseBestFeatureToSplit(dataSet)
    labels2 = labels.copy()
    bestFeatLabel = labels2[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels2[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels2[:]  # 剩余属性列表
        myTree[bestFeatLabel][value] = createTree(splitDataSet \
                                                      (dataSet, bestFeat, value), subLabels)
    return myTree


#导入数据
def sentDataSet(filename):
    with open(filename, 'r', encoding='utf-8') as file:
        arrayOLines = file.readlines() #列表型
    numberOfLines = len(arrayOLines)
    dataSet = numpy.zeros((numberOfLines, 4))
    index = 0
    for line in arrayOLines:
        line = line.strip()
        listFromLine = line.split('\t\t')
        dataSet[index, :] = listFromLine[0: 5]
    labels = ['圆周长', '重量', '材料', '花纹']
    return dataSet, labels

myDat, labels = sentDataSet('./data1')
myTree = createTree(myDat, labels)
print(myTree)

结果如下:

机器学习之构造决策树_第1张图片数据集采用了上一节中已生成的数据 ,所以还是分类球类的问题。

主要遇到的问题是,连续的数据没有进行处理,以至于每一组数据都成了一个单独的类别。连续数组的处理将在下一节提到。

导入数据主要就是将之前生成的数据集导入进来,方便后续操作。这一块代码可能有些简陋,毕竟博主水平还有限,不像之前的代码可以照着书敲【手动狗头】。不过嘛,代码毕竟是调通了,皆大欢喜,放心食用。

当然,这个问题也不能一直放着。于是乎,博主决定去掉数据集里的连续型变量,只用后两列数据进行分类,结果如下:

至于代码嘛,倒也不用大改,切片取出数据时修改一下参数即可。

如下:

# 导入数据
def sentDataSet(filename):
    with open(filename, 'r', encoding='utf-8') as file:
        arrayOLines = file.readlines()  # 列表型
    dataSet = []
    for line in arrayOLines:
        line = line.strip()
        listFromLine = line.split('\t\t')
        dataSet.append(listFromLine[2:5])  # 修改切片
    labels = ['材料', '花纹']  # 修改这里
    return dataSet, labels

修改部分已注释。

可视化树:

这一块是希望能过够将分类好的树可视化的输出,也就是直观的看到这棵树。

代码如下:

# 用Matplotlib注解绘制树形图

# 定义文本框和箭头格式
decisionNode = dict(boxstyle="square", fc="0.8")  # boxstyle文本框样式、fc=”0.8” 是颜色深度
leafNode = dict(boxstyle="round4", fc="0.8")  # 叶子节点
arrow_args = dict(arrowstyle="<-")  # 定义箭头


# 绘制带箭头的注解
def plotNode(nodeTxt, centerPt, parentPt, nodeType):  # 此函数执行绘制功能
    # createPlot.ax1是表示: ax1是函数createPlot的一个属性
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt,
                            textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


# 获取叶节点的数目和树的层数
def getNumLeafs(myTree):
    numLeafs = 0  # 初始化
    firstStr = list(myTree.keys())[0]  # 获得第一个key值(根节点)
    secondDict = myTree[firstStr]  # 获得value值
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # 测试节点的数据类型是否为字典
            numLeafs += getNumLeafs(secondDict[key])  # 递归调用函数
        else:
            numLeafs += 1
    return numLeafs


# 获取树的深度
def getTreeDepth(myTree):
    maxDepth = 0  # 初始化
    firstStr = list(myTree.keys())[0]  # 获得第一个key值(根节点)
    secondDict = myTree[firstStr]  # 获得value值
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # 测试节点的数据类型是否为字典
            thisDepth = 1 + getTreeDepth(secondDict[key])  # 递归调用
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


# 在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


# 画树
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)  # 获取树高
    depth = getTreeDepth(myTree)  # 获取树深度
    firstStr = list(myTree.keys())[0]  # 这个节点的文本标签
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)  # plotTree.totalW, plotTree.yOff全局变量,追踪已经绘制的节点,以及放置下一个节点的恰当位置
    plotMidText(cntrPt, parentPt, nodeTxt)  # 标记子节点属性
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # 减少y偏移
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


# 绘制决策树
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')  # 创建一个新图形
    fig.clf()  # 清空绘图区
    font = {'family': 'MicroSoft YaHei'}
    plt.rc("font", **font)
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW;
    plotTree.yOff = 1.0;
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


Createplot = createPlot(myTree)

运行结果:

机器学习之构造决策树_第2张图片

         材料这个分类分的稀烂,但这主要是因为这个类别本身就不适合分类。但由于去掉了前两列的数据,防止生成树过于单调,我最后还是决定将它加上。

        代码问题上注意下面这一列:机器学习之构造决策树_第3张图片

 本身就是字体的选择,但需要你的电脑上有这个可以编译的字体,这里推荐使用

font = {'family': 'SimHei'},常规电脑应该还是没有问题的。

下一节将会对树进行优化,也有连续性变量的处理方式,欢迎追订哦!

你可能感兴趣的:(决策树,人工智能)