目录
一、什么是决策树
1、决策树的定义
2、决策树模型
3、决策树学习
二、信息增益
三、决策树的构造及代码
一、背景前提:
二、数据处理
三、训练算法
五、实验总结
百度百科决策树的定义:
决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种图解法。由于这种决策分支画成图形很像一棵树的枝干,故称决策树。在机器学习中,决策树是一个预测模型,他代表的是对象属性与对象值之间的一种映射关系。Entropy = 系统的凌乱程度,使用算法ID3, C4.5和C5.0生成树算法使用熵。这一度量是基于信息学理论中熵的概念。
决策树模型是一种描述对实例进行分类的树形结构。其由结点(node)和有向边(directed edge)组成。结点有两种类型:内部结点(internal node)和叶结点(leaf node)。内部结点表示一个特征或属性(features),叶结点表示一个类(labels)。
决策树学习通常包括 3 个步骤:特征选择、决策树的生成和决策树的修剪。
假设集美大学本部和诚毅学院都有开设智能科学与技术专业,且两个学校之间的快递驿站是同一个地方(两校离得近且互通)。 如果在快递驿站得知一名同学的以下信息,你能否知道他是否是集美大学本部的智能科学与技术专业的?
信息一:是否住在五社区
信息二:是否学习C语言
信息三:是否学习机器学习
(注:集美大学本部的智能科学与技术专业的同学都住在五社区,诚毅学院的智能科学与技术专业的同学不住五社区。只有智能科学与技术专业同时开设C语言和机器学习 ,有的其他专业开设C语言,但没开设机器学习;有的专业开设机器学习,但没开设C语言。)
根据上述信息,我们可以画出表格如下:
序号 | 是否住五社区 | 是否学C语言 | 是否学机器学习 | 是否是集大本部智能专业的学生 |
1 | 是 | 是 | 是 | 是 |
2 | 是 | 是 | 否 | 否 |
3 | 是 | 否 | 是 | 否 |
4 | 是 | 否 | 否 | 否 |
5 | 否 | 是 | 是 | 否 |
6 | 否 | 是 | 否 | 否 |
7 | 否 | 否 | 是 | 否 |
8 | 否 | 否 | 否 | 否 |
def createDataSet():
dataSet = [[1, 1, 1, 'yes'],
[1, 1, 0, 'no'],
[1, 0, 1, 'no'],
[1, 0, 0, 'no'],
[0, 1, 1, 'no'],
[0, 1, 0, 'no'],
[0, 0, 1, 'no'],
[0, 0, 0, 'no']]
labels = ['live in 5', 'learn c', "learn ML"]
return dataSet, labels
计算给定数据集的香农熵的函数如下:
from math import log
def calcShannonEnt(dataSet):
'''
概率计算就是计算数据出现的频率
第一步:统计总体数据数
第二步:求出每个标签出现次数,这里需要用到字典的映射关系
第三步:遍历求出香农熵
熵越高,则混合额数据也就越多,数据的混乱程度也就越高
:param dataSet:
:return: shannonEnt
'''
numEntries = len(dataSet) #计算数据集中实例总数
labelCounts = {} #创建一个字典 键值为数据标签
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0 #没有这个标签就拓展字典并将当前键值加入字典
labelCounts[currentLabel] += 1 #统计键值出现的次数
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries #计算对应标签出现的概率
shannonEnt -= prob * log(prob,2) #log后为底数
return shannonEnt
按照给定特征划分数据集,将指定特征的特征值等于 value 的行剩下列作为子数据集。
def splitDataSet(dataSet,axis,value):
'''
:param dataSet:数据集
:param axis: 列(特征所对应的列)
:param value: 需要返回的特征的值
:return: index列为value的数据集【数据分离后要删除index列】
'''
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
选择最好的数据集划分方式:
def chooseBestFeatureToSplit(dataSet):
"""
chooseBestFeatureToSplit(选择最好的特征)
Args:
dataSet 数据集
Returns:
bestFeature 最优的特征列
"""
# 求第一行有多少列的 Feature, 最后一列是label列
numFeatures = len(dataSet[0]) - 1
# 数据集的原始信息熵
baseEntropy = calcShannonEnt(dataSet)
# 最优的信息增益值, 和最优的Featurn编号
bestInfoGain, bestFeature = 0.0, -1
# iterate over all the features
for i in range(numFeatures):
# 获取对应的feature下的所有数据
featList = [example[i] for example in dataSet]
# 获取剔重后的集合,使用set对list数据进行去重
uniqueVals = set(featList)
# 创建一个临时的信息熵
newEntropy = 0.0
# 遍历某一列的value集合,计算该列的信息熵
# 遍历当前特征中的所有唯一属性值,对每个唯一属性值划分一次数据集,计算数据集的新熵值,并对所有唯一特征值得到的熵求和。
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
# 计算概率
prob = len(subDataSet) / float(len(dataSet))
# 计算条件熵
newEntropy += prob * calcShannonEnt(subDataSet)
# gain[信息增益]: 划分数据集前后的信息变化, 获取信息熵最大的值
# 信息增益是熵的减少或者是数据无序度的减少。最后,比较所有特征中的信息增益,返回最好特征划分的索引值。
infoGain = baseEntropy - newEntropy
print('infoGain=', infoGain, 'bestFeature=', i, baseEntropy, newEntropy)
if infoGain > bestInfoGain:
bestInfoGain = infoGain
bestFeature = i
return bestFeature
构造树的数据结构,创建树的函数代码如下:
import operator
def majorityCnt(classList):
"""
majorityCnt(选择出现次数最多的一个结果)
Args:
classList label列的集合
Returns:
bestFeature 最优的特征列
"""
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
# 倒叙排列classCount得到一个字典集合,然后取出第一个就是结果(yes/no),即出现次数最多的结果
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet]
# 如果数据集的最后一列的第一个值出现的次数=整个集合的数量,也就说只有一个类别,就只直接返回结果就行
# 第一个停止条件:所有的类标签完全相同,则直接返回该类标签。
# count() 函数是统计括号中的值在list中出现的次数
if classList.count(classList[0]) == len(classList):
return classList[0]
# 如果数据集只有1列,那么最初出现label次数最多的一类,作为结果
# 第二个停止条件:使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组。
if len(dataSet[0]) == 1:
return majorityCnt(classList)
# 选择最优的列,得到最优列对应的label含义
bestFeat = chooseBestFeatureToSplit(dataSet)
# 获取label的名称
bestFeatLabel = labels[bestFeat]
# 初始化myTree
myTree = {bestFeatLabel: {}}
# 注:labels列表是可变对象,在PYTHON函数中作为参数时传址引用,能够被全局修改
# 所以这行代码导致函数外的同名变量被删除了元素,造成例句无法执行,提示'no surfacing' is not in list
del (labels[bestFeat])
# 取出最优列,然后它的branch做分类
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
# 求出剩余的标签label
subLabels = labels[:]
# 遍历当前选择特征包含的所有属性值,在每个数据集划分上递归调用函数createTree()
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
储存决策树:
def storeTree(inputTree, filename):
import pickle
fw = open(filename, 'wb')
pickle.dump(inputTree, fw)
fw.close()
def loadTree(filename):
import pickle
fr = open(filename, 'rb')
return pickle.load(fr)
使用Matplotlib注解绘制树形图:
import matplotlib.pyplot as plt
#定义文本框和箭头格式
decisionNode=dict(boxstyle="sawtooth",fc='0.8')
leafNode=dict(boxstyle="round4",fc='0.8')
arrow_args=dict(arrowstyle="<-")
#绘制带箭头的注释
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
createPlot.axl.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',
xytext=centerPt,
textcoords='axes fraction',
va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)
#获取叶节点的数目和树的层数
def getNumLeafs(myTree):
numLeafs=0
# firstStr=myTree.keys()[0]
firstStr=list(myTree.keys())[0]
secondDict=myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
numLeafs+=getNumLeafs(secondDict[key])
else:
numLeafs+=1
return numLeafs
def getTreeDepth(myTree):
maxDepth=0
firstStr=list(myTree.keys())[0]
secondDict=myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
thisDepth=1+getTreeDepth(secondDict[key])
else:
thisDepth=1
if thisDepth>maxDepth:maxDepth=thisDepth
return maxDepth
def plotMidText(cntrPt,parentPt,txtString):
xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
createPlot.axl.text(xMid,yMid,txtString)
def plotTree(mytree,parentPt,nodeTxt):
numLeafs=getNumLeafs(mytree)
depth=getTreeDepth(mytree)
firstStr=list(mytree.keys())[0]
cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
plotMidText(cntrPt,parentPt,nodeTxt)
plotNode(firstStr,cntrPt,parentPt,decisionNode)
secondDict=mytree[firstStr]
plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
def createPlot(inTree):
fig=plt.figure(1,facecolor='white')
fig.clf()
axprops=dict(xticks=[],yticks=[])
createPlot.axl=plt.subplot(111,frameon=False,**axprops)
plotTree.totalW=float(getNumLeafs(inTree))
plotTree.totalD=float(getTreeDepth(inTree))
plotTree.xOff=-0.5/plotTree.totalW;plotTree.yOff=1.0;plotTree(inTree,(0.5,1.0),'')
plt.show()
代码测试:
if __name__ == '__main__':
dataSet, labels = createDataSet()
shannonEnt = calcShannonEnt(dataSet)
mytree = createTree(dataSet, labels)
storeTree(mytree, 'classifierStorage.pkl')
mytree = loadTree('classifierStorage.pkl')
print(mytree)
createPlot(mytree)
输出结果:
决策树有以下优缺点: