决策树也是非常常见的算法,很多经常见到的例子中就有着决策树的身影,在你送给你女朋友礼物的时候,可能会有如下对话。
女朋友:是口红吗?
我:不是
女朋友:是香水吗?
我:不是
女朋友:是吃的,穿的,还是用的?
我:是穿的
女朋友:是鞋子吗?
我:对的
到此,猜测完毕,决策树就是这样通过一层层的决策,最终输出一个最为合理的判断。
面对一堆非常多的数据,决策树的想法是根据不同的特征划分成不同的分类,从宏观上来看,是通过分类将一堆数据从无序变有序的过程。
决策树的思想也是比较好理解的,就是通过不断的决策,划分类别,得到正确的结果,那么,面对很多数据以及很多特征的时候,那么我们会选择哪个特征作为第一个分类特征呢。怎么选择才会使分类效果最好,分类最快呢?这就是决策树最关键的地方,我们要通过量化的方法,计算每次划分所带来的信息增益,通过比较信息增益的大小选择用哪个特征作为分类。
熵:表示随机变量的不确定性。熵越大说明变量越混乱
条件熵:在一个条件下,随机变量的不确定性。
信息增益:熵 - 条件熵
信息增益虽然叫增益,但其实是一个“减法”的过程。为了方便理解,先不讲数学计算,我们考虑一个不是很恰当的例子。从1-10十个数,让我们划分成两个类别,A方式是根据奇偶划分,B方式是根据个位数/十位数划分,我们大概率会选择A方式,而这背后就遵循着信息增益的原理。从感觉上来看,A方式划分后数据的混乱程度相比原来会减小许多,而B的混乱程度依然很大,条件熵也是如此。
信息增益=熵 - 条件熵,初始的熵都一样,A条件熵小,A的信息增益大,我们选择A方式划分也是基于A的信息增益大而选择的。
从数学的角度上来考虑:
我们首先要先计算原始数据的熵,其中 D 表示训练数据集,c 表示数据类别数,Pi 表示类别 i 样本数量占所有样本的比例。Info(D)是初始熵
对应数据集 D,选择特征 A 作为决策树判断节点时,在特征 A 作用后的信息熵的为 Info(D),其中k是被分为k个类别,Dj是划分后第j个类别的数据集。InfoA(D)是特征A作用后的条件熵,InfoA(D)越小说明分类后混乱程度越小,条件熵越小。计算公式如下:
信息增益即为:
我们每次通过某一特征划分后都会计算出本次得到的信息增益,最合适的特征就是条件熵最小、即信息增益最大的那个。
这是《机器学习实战》上的例子,下面我们首先对其进行决策树的构建,再将其可视化出来。
# version:python3.7.3
# author:hty
# date:2020.5.3
from math import log
import treePlotter
def calcShannonEnt(dataSet):
'''
function:计算熵
input:数据集
output:熵
'''
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet:
# 获取每个数据对应的结果类别
currentLabel = featVec[-1]
# 建立一个字典,key是类别,value是类别出现的次数
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
# 每个类别出现的概率
prob = float(labelCounts[key])/numEntries
# 计算公式
shannonEnt -= prob*log(prob, 2)
return shannonEnt
def createDataSet():
'''
function:创建数据集
input:None
output:
dataSet:数据集
labels:标签
'''
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing', 'flippers']
return dataSet, labels
def splitDataSet(dataSet, axis, value):
'''
function:划分数据集
input:
dataSet:待划分数据集
axis:划分数据集的特征
value:特征值
output:划分后的数据集
'''
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
# 特征值之前的部分
reducedFeatVec = featVec[:axis]
# 再加上特征值后的部分,刚好把特征值规避掉
reducedFeatVec.extend(featVec[axis+1:])
'print(reducedFeatVec)'
retDataSet.append(reducedFeatVec)
# print('retDataSet:', retDataSet)
return retDataSet
def chooseBestFeatureToSplit(dataSet):
'''
function:选择划分最好的特征
input:数据集
output:最合适的特征
'''
# 特征数量
numFeatures = len(dataSet[0]) - 1
# 计算初始熵
baseEntropy = calcShannonEnt(dataSet)
# 初始最大增益为0,最合适特征为-1
bestInfoGain = 0.0; bestFeature = -1
# 开始计算每个特征划分后产生的信息增益
for i in range(numFeatures):
# 特征对应的特征值
featList = [example[i] for example in dataSet]
# 集合,删除重复特征值
uniqueVals = set(featList)
# 条件熵初始为0
newEntropy = 0.0
# 对第i个特征的每一个value
for value in uniqueVals:
# 根据i,value逐个划分数据集
subDataSet = splitDataSet(dataSet, i, value)
# 计算新的信息熵
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob*calcShannonEnt(subDataSet)
# 计算信息增益
infoGain = baseEntropy-newEntropy
# 如果本次信息增益大于之前的最大信息增益,那么更新最大信息增益和最合适分类特征
if (infoGain > bestInfoGain):
bestInfoGain =infoGain
bestFeature = i
return bestFeature
def majorityCnt(classList):
'''
function:如果每个点数据就只有一个(yes or no)的结果,那么就输出结果出现次数的结果
'''
classCount = {}
for vote in classList:
if vote not in classCount.keys():classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(),
key = operator.itemgetter(1), reverse = True)
return sortedClassCount[0][0]
def createTree(dataSet, labels):
'''
function:主函数,创建决策树
input:
dataSet:数据集
labels:标签
output:决策树
'''
classList = [example[-1] for example in dataSet]
# 如果类别只有一种,直接输出结果
if classList.count(classList[0]) == len(classList):
return classList[0]
# 如果数据只有一个数据信息
if len(dataSet[0]) == 1:
return majorityCnt(classList)
# 得到第一个最佳分类特征
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}
# 删除这个最佳分类特征,以便接下来继续调用这个函数
del(labels[bestFeat])
# 得到每个数据最佳特征对应的信息
featValues = [example[bestFeat] for example in dataSet]
# 建立集合
uniqueVals = set(featValues)
for value in uniqueVals:
# 剩下的标签是新的标签
subLabels = labels
# 继续调用createTree,先根据最佳分类特征划分数据集,并使用删除最佳分类特征的标签,
myTree[bestFeatLabel][value] = createTree(splitDataSet\
(dataSet, bestFeat, value),subLabels)
return myTree
def classify(inputTree, featLabels, testVec):
'''
function:输入一个数据判断它属于哪一类
input:
inputTree:前面已经形成的树
featLabels:树的标签
testVec:待分类的数据
'''
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
# 之前在createTree这个函数里我们对labels标签处理了,这里还要用之前的labels标签,得到最佳分类标签
featIndex = featLabels.index(firstStr)
# 对子一层节点判断
for key in secondDict.keys():
if testVec[featIndex] == key:
# 如果对应键值还是字典,继续调用这个函数
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
# 如果不是了,那就输出结果
else:
classLabel = secondDict[key]
return classLabel
'''
下面这两个函数是储存树结构和读取树结构,
其实直接open就能完成目的,但是pickle读取和写入速度更快!
'''
def storeTree(inputTree, filename):
import pickle
fw = open(filename, 'wb')
inputTree = str(inputTree)
pickle.dump(inputTree, fw)
fw.close()
def grabTree(filename):
import pickle
fr = open(filename, 'rb+')
return pickle.load(fr)
myDat, labels = createDataSet()
myTree = createTree(myDat, labels)
print(myTree)
有一说一,这个图画的是真的丑,不过它倒是介绍了一个非常重要的库,matplotlib,(matlab表示不服,有本事画3D图),这个库比较简单,而且基本上满足绝大部分需求了,如果还没有学过的话建议还是要仔细看看,这一部分可以参考这篇文章,关于图为什么这么画,讲的非常详细了。https://blog.csdn.net/liyuefeilong/article/details/48244529
下面就直接放代码了。
import matplotlib.pyplot as plt
# 定义叶子节点和箭头样式
decisionNode = dict(boxstyle = 'sawtooth',fc= '0.8')
leafNode = dict(boxstyle = 'round4', fc = '0.8')
arrow_args = dict(arrowstyle = '<-')
def getNumLeafs(myTree):
'''
function:获得叶子数,便于分配横向空间
input:树(字典)
output:叶子数
'''
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
numLeafs += getNumLeafs(secondDict[key])
else: numLeafs +=1
return numLeafs
def getTreeDepth(myTree):
'''
function:获得树的高度(层数),便于分配纵向空间
input:树(字典)
output:高度
'''
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else: thisDepth = 1
if thisDepth > maxDepth : maxDepth = thisDepth
return maxDepth
#下面这里比较难理解,可以用手画画图找找感觉,但是感觉没什么太大价值,就是找位置。
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
'''
function:绘制带箭头的注解
input:
nodeTxt:文本信息
centerPt:子节点坐标信息
parentPt:父节点坐标信息
nodeType:箭头类型
'''
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
def plotMidText(cntrPt, parentPt, txtString):
'''
function:在父&子节点中的箭头上填充文本信息
input:
cntrPt:子节点位置信息
parentPt:父节点位置信息
txtString:文本信息
'''
xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
'''
function:绘制树,调用其他函数
'''
# 得到叶子数
numLeafs = getNumLeafs(myTree)
# 得到纵向高度
depth = getTreeDepth(myTree)
# 第一个节点
firstStr = list(myTree.keys())[0]
# 子节点的位置
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,\
plotTree.yOff)
# 在父&子节点中的箭头上填充文本信息
plotMidText(cntrPt, parentPt, nodeTxt)
# 绘制箭头
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
# yOff是为下一层子节点y轴高度做准备
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
# 对每层节点开始循环
for key in secondDict.keys():
# 如果对应键值是字典,继续调用这个函数
if type(secondDict[key]).__name__=='dict':
plotTree(secondDict[key],cntrPt,str(key))
# 如果是叶节点,可以画出来了,内容同上
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def createPlot(inTree):
'''
function:主函数,调用其他函数进行绘图
'''
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
plotTree(inTree, (0.5,1.0), '')
plt.show()
# 方便调试
def retrieveTree(i):
listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return listOfTrees[i]
createPlot(retrieveTree(0))
以上采用信息增益的算法叫做ID3算法,除此之外,还有C4.5,CART算法等等。但他们都被成为决策树,大体的思路是相同的,差别在于计算方式。
C4.5的生成算法采用信息增益比,即信息增益除训练数据集的熵。
CART算法更复杂一些,后面的章节还会再提到,这里就不再说了(其实是忘了-_-!)
决策树在构建过程中,常常会产生过拟合的现象,需要对其剪枝,以简化决策树,决策树的剪枝,就是从生成的树上减去一些叶节点,并将其父节点作为新的叶节点,从而简化决策树。
最后写了半天,感叹于自己语言表达能力的薄弱,说来说去也就是照着那两本书的思路来,也感叹于人家怎么讲的这么明白,比我内容多,比我语言精炼,本来想写出来,万一有“后来者”学习到这里可以提供一点帮助,还是自己自作多情了,还是直接看书比较好,如果代码有看不懂的倒是可以看看我的注释。。。
下面是参考的文章/书:
《机器学习实战》
《统计学习方法》
https://www.zhihu.com/question/22104055
https://www.ibm.com/developerworks/cn/analytics/library/ba-1507-decisiontree-algorithm/index.html