下表中有5组数据,两个特征,根据着两组特征判断这个样本是不是鱼类。
不付出水面是否可以生存(no surfacing) | 是否有脚蹼(flippers) | 属于鱼类 | |
1 | 1 | 1 | yes |
2 | 1 | 1 | yes |
3 | 1 | 0 | no |
4 | 0 | 1 | no |
5 | 0 | 1 | no |
1.构造决策树
2.用matplotlib画出构造的决策树
3.给定一组数据,判断其分类。
#计算给定数据集的香农熵
from math import log
def calcShannonEnt(dataSet):
num=len(dataSet) #数据集的样本数量
labelCount={} #创建一个数据字典,它的键是数据集最后一列的数据,集样本的类别;它的值是该分类中的样本数量
#计算每种类别下的样本数量,并将其放在字典中对应的键下
for featureVec in dataSet:
label=featureVec[-1] #取样本中的最后一个值
if label not in labelCount.keys():
labelCount[label]=1
else:
labelCount[label]+=1
#计算数据集的熵
shannonEnt=0.0
for key in labelCount.keys():
pro=float(labelCount[key])/num
shannonEnt-=pro*log(pro,2)
return shannonEnt
#按照给定的特征划分数据集
def splitDataSet(dataSet,feature,value): #参数:带划分的数据集、划分数据集的特征、特征值
reDataSet=[]
for featureVector in dataSet:
if featureVector[feature]==value:
reduceFeature=featureVector[:feature]
reduceFeature.extend(featureVector[feature+1:])
reDataSet.append(reduceFeature)
return reDataSet
#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
numOfFeature=len(dataSet[0])-1
baseShannon=calcShannonEnt(dataSet) #
bestShannon=0.0
bestFeature=-1
for i in range(numOfFeature):
featureList=[featureVector[i]for featureVector in dataSet]#用列表推导式将第i个特征的值提取出来
featureSet=set(featureList) #利用集合的互异性找出特征的不同取值
newShannon=0.0
for value in featureSet:
subDataSet=splitDataSet(dataSet,i,value) #按照不同的特征划分数据集
#求新划分的数据集的香农熵
prob=float(len(subDataSet))/float(len(dataSet))
newShannon+=prob*calcShannonEnt(subDataSet)
shannon=baseShannon-newShannon
if(shannon>bestShannon):
bestShannon=shannon
bestFeature=i
return bestFeature
#多数表决法定义叶子节点的分类
import operator
def majorityCnt(classList):
classCount={}
for vote in classList:
if vote not in classCount.keys():
classCount[vote]=1
else:
classCount[vote]+=1
sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
#递归构建决策树
def createTree(dataSet,labels):
classList=[example[-1] for example in dataSet]
#递归函数第一个停止的条件:所有类标签完全相同,直接返回该类标签
if classList.count(classList[0])==len(classList):
return classList[0]
#递归函数的第二个停止条件:使用完所有特征,仍不能将数据集划分成仅包含唯一类别的分组。使用多数表决法决定叶子节点的分类
if len(dataSet[0])==1:
return majorityCnt(classList)
#开始创建决策树
bestFeature=chooseBestFeatureToSplit(dataSet) #选择划分数据集最好的特征的索引
bestFeatureLabel=labels[bestFeature] #根据特征的索引提取索引的名称
decisionTree={bestFeatureLabel:{}} #将此特征作为树的根节点
del labels[bestFeature] #将已放进树中的特征从特征标签中删除
featrueValues=[example[bestFeature]for example in dataSet] #提取所有样本关于这个特征的取值
uniqueVals=set(featrueValues) #应用集合的互异性,提取这个特征的不同取值
for value in uniqueVals: #根据特征的不同取值,创建这个特征所对应结点的分支
subLabels=labels[:]
decisionTree[bestFeatureLabel][value]=createTree(splitDataSet(dataSet,bestFeature,value),subLabels)
return decisionTree
#获取叶节点的数目,在绘制决策树时确定x轴的长度
def getNumLeafs(tree):
numOfLeaf=0
firstNode,=tree.keys()
second=tree[firstNode]
#测试节点的数据类型,若不是字典类型,则表示此节点为叶子节点
for key in second.keys():
if type(second[key]).__name__=='dict':
numOfLeaf+=getNumLeafs(second[key])
else:
numOfLeaf+=1
return numOfLeaf
#计算树的深度,在绘制决策树时确定y轴的高度
def getTreeDepth(tree):
depthOfTree=0
firstNode,=tree.keys()
second=tree[firstNode]
for key in second.keys():
if type(second[key]).__name__=='dict':
thisNodeDepth=getTreeDepth(second[key])+1
else:
thisNodeDepth=1
if thisNodeDepth>depthOfTree:
depthOfTree=thisNodeDepth
return depthOfTree
#用matplotlib绘制决策树
import matplotlib.pyplot as plt
decisionNode=dict(boxstyle='sawtooth',fc='0.8') #决策节点;设置文本框的类型和文本框背景灰度,范围为0-1,0为黑,1为白,不设置默认为蓝色
leafNode=dict(boxstyle='round4',fc='1') #设置叶子节点文本框的属性
arrow_args=dict(arrowstyle='<-')
#绘制节点
#annotate(text,xy,xycoords,xytext,textcoords,va,ha,bbox,arrowprops)
#xy表示进行标注的点的坐标
#xytext表示标注的文本信息的位置
#xycoords与textcoords分别为xy和xytext的说明,默认为data
#va,ha设置文本框中文字的位置,va表示竖直方向,ha表示水平方向
def plotNode(nodeTxt,nodeIndex,parentNodeIndex,nodeType): #形参:文本内容,文本的中心点,箭头指向文本的点,点的类型
plt.annotate(nodeTxt,xy=parentNodeIndex,xycoords='axes fraction',
xytext=nodeIndex,textcoords='axes fraction',
va='center',ha='center',bbox=nodeType,
arrowprops=arrow_args)
#在父子节点之间添加注释
def plotMidText(thisNodeIndex,parentNodeIndex,text):
xmid=(parentNodeIndex[0]-thisNodeIndex[0])/2.0+thisNodeIndex[0]
ymid=(parentNodeIndex[1]-thisNodeIndex[1])/2.0+thisNodeIndex[1]
plt.text(xmid,ymid,text) #在指定位置添加注释
def plotTree(tree,parentNodeIndex,midTxt):
global xOff
global yOff
numOfLeafs=getNumLeafs(tree)
nodeTxt,=tree.keys()
nodeIndex=(xOff+(1.0+float(numOfLeafs))/2.0/treeWidth,yOff) #计算节点的位置
plotNode(nodeTxt, nodeIndex, parentNodeIndex, decisionNode)
plotMidText(nodeIndex,parentNodeIndex,midTxt)
secondDict=tree[nodeTxt]
yOff=yOff-1.0/treeDepth
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
plotTree(secondDict[key],nodeIndex,str(key))
else:
xOff=xOff+1.0/treeWidth
plotNode(secondDict[key],(xOff,yOff),nodeIndex,leafNode)
plotMidText((xOff,yOff),nodeIndex,str(key))
yOff=yOff+1.0/treeDepth
def createPlot(tree): #绘制决策树的主函数
fig=plt.figure('DecisionTree',facecolor='white') #创建一个画布,命名为'decisionTree',画布颜色为白色
fig.clf() #清空画布
createPlot.ax1=plt.subplot(111,frameon=False) #111:将画布分成1行1列,去第一块画布;frameon:是否绘制矩形坐标框
#设置两个全局变量xOff和yOff,追踪已绘制节点的位置,计算放置下一个节点的恰当位置。
global xOff
xOff=-0.5/treeWidth
global yOff
yOff=1.0
plotTree(tree,(0.5,1.0),'')
plt.xticks([])
plt.yticks([])
plt.show()
注:
1.treeWidth和treeDepth是我们在函数外声明的变量,用于存储树的宽度和深度。我们使用这两个变量计算树节点的摆放位置,这样可以讲述绘制在水平方向和竖直方向的中心位置。
2.代码中声明了xOff和yOff两个全局变量来追踪已绘制的节点位置,以及放置下一个节点的恰当位置。
#使用决策树执行分类
def classify(inputTree,featureLabels,testVector):
firstNode,=inputTree.keys()
secondDict=inputTree[firstNode]
featureIndex=featureLabels.index(firstNode)
for key in secondDict.keys():
if testVector[featureIndex]==key:
if type(secondDict[key]).__name__=='dict':
classLabel=classify(secondDict[key],featureLabels,testVector)
else:
classLabel=secondDict[key]
return classLabel
构造决策树是很耗时的任务,即使处理很小的数据集,也要花费好几秒的时间。为了节省构造数据集的时间,最好在每次执行分类时调用已经构造好的决策树。为了解决这个问题,需要使用python模块的pickle序列化对象。序列化对象可以在磁盘上保存对象,并在需要的时候读出来。
实现如下:
#使用pickle模块储存决策树
def storeTree(inputTree,filename):
import pickle
file=open(filename,'wb')
pickle.dump(inputTree,file)
file.close()
def loadTree(filename):
import pickle
file=open(filename,'rb')
Tree=pickle.load(file)
file.close()
return Tree
dataSet=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']];
labels=['no surfacing','flippers']
decisionTree=createTree(dataSet,labels)
storeTree(decisionTree,'decisionTree')
myTree=loadTree('decisionTree')
featureLabels=['no surfacing','flippers']
treeWidth=float(getNumLeafs(myTree))
treeDepth=float(getTreeDepth(myTree))
createPlot(myTree)
print(classify(myTree,featureLabels,[1,0]))
构造决策树图的结果如下:
输入数据判断分类的输出结果:
no
1.calcShannonEnt(dataSet):计算香农熵
2.splitDataSet(dataSet,feature,value):划分数据集
3.chooseBestFeatureToSplit(dataSet):选择最好的数据集划分方式
4.majorityCnt(classList);createTree(dataSet,labels):多数表决法+递归构建决策树
5.python模块pickle序列化对象的应用:pickle模块
决策树分类器就像带有终止块的流程图,终止块表示分类结果。开始处理数据时,首先需要测量集合种数据的不一致性,也就是熵,然后寻找最优方案划分数据集,知道找到数据集中的所有数据属于同一类。本篇文章中用于构造决策树的算法为ID3算法,这个算法无法直接处理数值型的数据,尽管我们可以通过量化的方法将数值型数据转化为标称型数值,但是存在太多的特征划分时,ID3算法仍然会存在一些其他的问题。
构建决策树时,一般不构造新的数据结构,而是使用python语言内嵌的数据结构字典存储树节点信息。
使用matplotlib的注解功能,可以将存储的树结构转化为容易理解的图形。python语言的pickle模块可用于存储决策树的结构。决策树可能会产生过多的数据集划分,从而产生过度匹配数据集的问题。可以通过剪裁决策树,合并相邻的无法产生大量信息增益的叶节点,消除过度匹配的问题(这种方法,上述例子中没有涉及)。
关于构造决策树的算法,还有C4.5和CART。