《机器学习实战》(二)---决策树

《机器学习实战(二)》——决策树

    • 决策树介绍
    • 决策树思想
    • 信息增益
    • example_海洋生物数据
    • example_数据可视化
    • 其他决策树的计算方法
    • 后记

决策树介绍

决策树也是非常常见的算法,很多经常见到的例子中就有着决策树的身影,在你送给你女朋友礼物的时候,可能会有如下对话。

女朋友:是口红吗?
我:不是
女朋友:是香水吗?
我:不是
女朋友:是吃的,穿的,还是用的?
我:是穿的
女朋友:是鞋子吗?
我:对的

到此,猜测完毕,决策树就是这样通过一层层的决策,最终输出一个最为合理的判断。

决策树思想

面对一堆非常多的数据,决策树的想法是根据不同的特征划分成不同的分类,从宏观上来看,是通过分类将一堆数据从无序变有序的过程。
决策树的思想也是比较好理解的,就是通过不断的决策,划分类别,得到正确的结果,那么,面对很多数据以及很多特征的时候,那么我们会选择哪个特征作为第一个分类特征呢。怎么选择才会使分类效果最好,分类最快呢?这就是决策树最关键的地方,我们要通过量化的方法,计算每次划分所带来的信息增益,通过比较信息增益的大小选择用哪个特征作为分类。

信息增益

熵:表示随机变量的不确定性。熵越大说明变量越混乱

条件熵:在一个条件下,随机变量的不确定性。

信息增益:熵 - 条件熵

信息增益虽然叫增益,但其实是一个“减法”的过程。为了方便理解,先不讲数学计算,我们考虑一个不是很恰当的例子。从1-10十个数,让我们划分成两个类别,A方式是根据奇偶划分,B方式是根据个位数/十位数划分,我们大概率会选择A方式,而这背后就遵循着信息增益的原理。从感觉上来看,A方式划分后数据的混乱程度相比原来会减小许多,而B的混乱程度依然很大,条件熵也是如此。
信息增益=熵 - 条件熵,初始的熵都一样,A条件熵小,A的信息增益大,我们选择A方式划分也是基于A的信息增益大而选择的。

从数学的角度上来考虑:
我们首先要先计算原始数据的熵,其中 D 表示训练数据集,c 表示数据类别数,Pi 表示类别 i 样本数量占所有样本的比例。Info(D)是初始熵

对应数据集 D,选择特征 A 作为决策树判断节点时,在特征 A 作用后的信息熵的为 Info(D),其中k是被分为k个类别,Dj是划分后第j个类别的数据集。InfoA(D)是特征A作用后的条件熵,InfoA(D)越小说明分类后混乱程度越小,条件熵越小。计算公式如下:
在这里插入图片描述

信息增益即为:
在这里插入图片描述
我们每次通过某一特征划分后都会计算出本次得到的信息增益,最合适的特征就是条件熵最小、即信息增益最大的那个。

example_海洋生物数据

《机器学习实战》(二)---决策树_第1张图片
这是《机器学习实战》上的例子,下面我们首先对其进行决策树的构建,再将其可视化出来。

# version:python3.7.3
# author:hty
# date:2020.5.3
from math import log
import treePlotter

def calcShannonEnt(dataSet):
	'''
	function:计算熵
	input:数据集
	output:熵
	'''
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
    	# 获取每个数据对应的结果类别
        currentLabel = featVec[-1]
        # 建立一个字典,key是类别,value是类别出现的次数
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
    	# 每个类别出现的概率
        prob = float(labelCounts[key])/numEntries
        # 计算公式
        shannonEnt -= prob*log(prob, 2)
    return shannonEnt
    

def createDataSet():
	'''
	function:创建数据集
	input:None
	output:
		dataSet:数据集
		labels:标签
	'''
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

def splitDataSet(dataSet, axis, value):
	'''
	function:划分数据集
	input:
		dataSet:待划分数据集
		axis:划分数据集的特征
		value:特征值
	output:划分后的数据集
	'''
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
        	# 特征值之前的部分
            reducedFeatVec = featVec[:axis]
            # 再加上特征值后的部分,刚好把特征值规避掉
            reducedFeatVec.extend(featVec[axis+1:])
            'print(reducedFeatVec)'
            retDataSet.append(reducedFeatVec)
    # print('retDataSet:', retDataSet)
    return retDataSet

def chooseBestFeatureToSplit(dataSet):
	'''
	function:选择划分最好的特征
	input:数据集
	output:最合适的特征
	'''
	# 特征数量
    numFeatures = len(dataSet[0]) - 1
    # 计算初始熵
    baseEntropy = calcShannonEnt(dataSet)
    # 初始最大增益为0,最合适特征为-1
    bestInfoGain = 0.0; bestFeature = -1
    # 开始计算每个特征划分后产生的信息增益
    for i in range(numFeatures):
    	# 特征对应的特征值
        featList = [example[i] for example in dataSet]
        # 集合,删除重复特征值
        uniqueVals = set(featList)
        # 条件熵初始为0
        newEntropy = 0.0
        # 对第i个特征的每一个value
        for value in uniqueVals:
        	# 根据i,value逐个划分数据集
            subDataSet = splitDataSet(dataSet, i, value)
            # 计算新的信息熵
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob*calcShannonEnt(subDataSet)
        # 计算信息增益
        infoGain = baseEntropy-newEntropy
        # 如果本次信息增益大于之前的最大信息增益,那么更新最大信息增益和最合适分类特征
        if (infoGain > bestInfoGain):
            bestInfoGain =infoGain
            bestFeature = i
    return bestFeature

def majorityCnt(classList):
	'''
	function:如果每个点数据就只有一个(yes or no)的结果,那么就输出结果出现次数的结果
	'''
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(),
    key = operator.itemgetter(1), reverse = True)
    return sortedClassCount[0][0]

def createTree(dataSet, labels):
	'''
	function:主函数,创建决策树
	input:
		dataSet:数据集
		labels:标签
	output:决策树
	'''
    classList = [example[-1] for example in dataSet]
    # 如果类别只有一种,直接输出结果
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # 如果数据只有一个数据信息
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # 得到第一个最佳分类特征
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
	# 删除这个最佳分类特征,以便接下来继续调用这个函数
    del(labels[bestFeat])
    # 得到每个数据最佳特征对应的信息
    featValues = [example[bestFeat] for example in dataSet]
    # 建立集合
    uniqueVals = set(featValues)
    for value in uniqueVals:
    	# 剩下的标签是新的标签
        subLabels = labels
        # 继续调用createTree,先根据最佳分类特征划分数据集,并使用删除最佳分类特征的标签,
        myTree[bestFeatLabel][value] = createTree(splitDataSet\
                            (dataSet, bestFeat, value),subLabels)
    return myTree

def classify(inputTree, featLabels, testVec):
	'''
	function:输入一个数据判断它属于哪一类
	input:
		inputTree:前面已经形成的树
		featLabels:树的标签
		testVec:待分类的数据
	'''
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    # 之前在createTree这个函数里我们对labels标签处理了,这里还要用之前的labels标签,得到最佳分类标签
    featIndex = featLabels.index(firstStr)
    # 对子一层节点判断
    for key in secondDict.keys():
        if testVec[featIndex] == key:
        	# 如果对应键值还是字典,继续调用这个函数
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            # 如果不是了,那就输出结果
            else:
                classLabel = secondDict[key]
    return classLabel

'''
下面这两个函数是储存树结构和读取树结构,
其实直接open就能完成目的,但是pickle读取和写入速度更快!
'''
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')
    inputTree = str(inputTree)
    pickle.dump(inputTree, fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr = open(filename, 'rb+')
    return pickle.load(fr)
    
    
myDat, labels = createDataSet()
myTree = createTree(myDat, labels)
print(myTree)


example_数据可视化

《机器学习实战》(二)---决策树_第2张图片

有一说一,这个图画的是真的丑,不过它倒是介绍了一个非常重要的库,matplotlib,(matlab表示不服,有本事画3D图),这个库比较简单,而且基本上满足绝大部分需求了,如果还没有学过的话建议还是要仔细看看,这一部分可以参考这篇文章,关于图为什么这么画,讲的非常详细了。https://blog.csdn.net/liyuefeilong/article/details/48244529
下面就直接放代码了。

import matplotlib.pyplot as plt

# 定义叶子节点和箭头样式
decisionNode = dict(boxstyle = 'sawtooth',fc= '0.8')
leafNode = dict(boxstyle = 'round4', fc = '0.8')
arrow_args = dict(arrowstyle = '<-')


def getNumLeafs(myTree):
	'''
	function:获得叶子数,便于分配横向空间
	input:树(字典)
	output:叶子数
	'''
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:   numLeafs +=1
    return numLeafs

def getTreeDepth(myTree):
	'''
	function:获得树的高度(层数),便于分配纵向空间
	input:树(字典)
	output:高度
	'''
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   thisDepth = 1
        if thisDepth > maxDepth : maxDepth = thisDepth
    return maxDepth
    
#下面这里比较难理解,可以用手画画图找找感觉,但是感觉没什么太大价值,就是找位置。

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
	'''
	function:绘制带箭头的注解
	input:
		nodeTxt:文本信息
		centerPt:子节点坐标信息
		parentPt:父节点坐标信息
		nodeType:箭头类型
	'''
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )

def plotMidText(cntrPt, parentPt, txtString):
	'''
	function:在父&子节点中的箭头上填充文本信息
	input:
		cntrPt:子节点位置信息
		parentPt:父节点位置信息
		txtString:文本信息
	'''
    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
	'''
	function:绘制树,调用其他函数
	'''
	# 得到叶子数
    numLeafs = getNumLeafs(myTree)
    # 得到纵向高度
    depth = getTreeDepth(myTree)
    # 第一个节点
    firstStr = list(myTree.keys())[0]
    # 子节点的位置
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,\
              plotTree.yOff)
	# 在父&子节点中的箭头上填充文本信息
    plotMidText(cntrPt, parentPt, nodeTxt)
    # 绘制箭头
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
	# yOff是为下一层子节点y轴高度做准备
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    # 对每层节点开始循环
    for key in secondDict.keys():
    	# 如果对应键值是字典,继续调用这个函数
        if type(secondDict[key]).__name__=='dict': 
            plotTree(secondDict[key],cntrPt,str(key))        
        # 如果是叶节点,可以画出来了,内容同上
        else:   
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

def createPlot(inTree):
	'''
	function:主函数,调用其他函数进行绘图
	'''
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) 
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()
    
# 方便调试
def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]


createPlot(retrieveTree(0))


其他决策树的计算方法

以上采用信息增益的算法叫做ID3算法,除此之外,还有C4.5,CART算法等等。但他们都被成为决策树,大体的思路是相同的,差别在于计算方式。
C4.5的生成算法采用信息增益比,即信息增益除训练数据集的熵。
CART算法更复杂一些,后面的章节还会再提到,这里就不再说了(其实是忘了-_-!)

后记

决策树在构建过程中,常常会产生过拟合的现象,需要对其剪枝,以简化决策树,决策树的剪枝,就是从生成的树上减去一些叶节点,并将其父节点作为新的叶节点,从而简化决策树。
最后写了半天,感叹于自己语言表达能力的薄弱,说来说去也就是照着那两本书的思路来,也感叹于人家怎么讲的这么明白,比我内容多,比我语言精炼,本来想写出来,万一有“后来者”学习到这里可以提供一点帮助,还是自己自作多情了,还是直接看书比较好,如果代码有看不懂的倒是可以看看我的注释。。。
下面是参考的文章/书:
《机器学习实战》
《统计学习方法》
https://www.zhihu.com/question/22104055
https://www.ibm.com/developerworks/cn/analytics/library/ba-1507-decisiontree-algorithm/index.html

你可能感兴趣的:(机器学习实战,决策树,python,机器学习)