决策树(decision tree)是一类常见的机器学习方法.以二分类任务为例,我们希望从给定训练数据集学得一个模型用以对新示例进行分类,这个把样本分类的任务,可看作对“当前样本属于正类吗?”这个问题的“决策”或“判定”过程.顾名思义,决策树是基于树结构来进行决策的,这恰是人类在面临决策问题时一种很自然的处理机制.
一般的,一棵决策树包含一个根结点、若干个内部结点和若干个叶结点;叶结点对应于决策结果,其他每个结点则对应于一个属性测试;每个结点包含的样本集合根据属性测试的结果被划分到子结点中;根结点包含样本全集.从根结点到每个叶结点的路径对应了一个判定测试序列.决策树学习的目的是为了产生一棵泛化能力强,即处理未见示例能力强的决策树,其基本流程遵循简单且直观的“分而治之”(divide-and-conquer)策略,如图:
划分数据集的大原则是:将无需的数据变得更加有序;
“信息嫡”(information entropy)是度量样本集合纯度最常用的一种指标.假定当前样本集合D中第k类样本所占的比例为pk (k = 1,2,…,|y|),则D的信息嫡定义为
Ent(D)的值越小,则D的纯度越高.值越高信息越混乱。
假定离散属性a有V个可能的取值{a1 , a2,… . ,aV},若使用a来对样本集D进行划分,则会产生V个分支结点,其中第v个分支结点包含了D中所有在属性a上取值为av的样本,记为D”.我们可根据上式计算出Dv的信息嫡,再考虑到不同的分支结点所包含的样本数不同,给分支结点赋予权重|Dv|/|D|,即样本数越多的分支结点的影响越大,于是可计算出用属性α对样本集D进行划分所获得的“信息增益”(information gain)
一般而言,信息增益越大,则意味着使用属性α来进行划分所获得的“纯度提升”越大.因此,我们可用信息增益来进行决策树的划分属性选择,即在上式算法选择属性a*= argmaxGain(D,a)、著名的ID3决策树学习算就是以信息增益为准则来选择划分属性.
优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感;
缺点:可能产生过度匹配问题
使用上次实验的数据,每个列分别表示:[‘卷面成绩’,‘班级排名’,‘综测分’,‘奖学金等级’]
其中 奖学金等级 为我们将要预测的标签。
可以看到属性的数据类型都为连续型,由于连续属性的可取值数目不再有限,因此不能直接对连续型属性的可取值来对节点进行划分,接下来要将数据离散化 。
最简单暴力的方法,排序后按比例划分n类(暂时不考虑信息增益):
#离散化数据
#data数据 axis待划分属性 num划分几类
def discretedata(data,axis,num):
len_data=len(data)
#求该属性的最大值和最小值
max=0;min=100
for i in range(len_data):
if data[i][axis]>max:
max=data[i][axis]
if data[i][axis]<min:
min=data[i][axis]
#划分区间
degree=(max-min)*1.0/num
limit=[min]
for i in range(num):
limit.append(min+degree*(i+1))
print(limit)
#修改对应的值,1代表最低级
for one_data in data:
for i in range(len(limit)-1):
if one_data[axis]>=limit[i] and one_data[axis]<=limit[i+1] :
one_data[axis]=i+1
print(data)
可以看到第一个属性被分为5个等级(根据数据集实际情况调整),接下来对剩余属性进行相同操作
mydata=createData('./ch1/grade.txt')
for i in range(3):
discretedata(mydata,i,6)#这里分为6个等级
print(mydata)
#计算给定数据的熵
def calcEnt(data):
num=len(data)#获取数据的行数
labels={}
#遍历数据的所有行
for featVec in data:
#为所有可能分类创建字典
currentLabel=featVec[-1]
if currentLabel not in labels.keys():#如果字典中不存在该键则创建并初始化为0
labels[currentLabel]=0
labels[currentLabel]+=1#统计每个类别
Ent=0.0
for key in labels:
prob=float(labels[key])/num
Ent-=prob*log2(prob)
return Ent
print("计算当前数据的熵",calcEnt(mydata))
#按照给定特征划分数据集
def splitDataSet(dataSet,axis,value):# 待划分的数据集 划分数据集的特征 需要返回的特征的值
retDataSet=[]
for featVec in dataSet:
if featVec[axis] ==value:
reducedFeatVec=featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])#放列表中的元素
retDataSet.append(reducedFeatVec)#把整个列表放入
return retDataSet
if __name__=='__main__':
labels=['卷面成绩','班级排名','综测分','奖学金等级']
mydata=createData('./ch1/grade.txt')
for i in range(3):
discretedata(mydata,i,6)
print("计算当前数据的熵",calcEnt(mydata))
print("按奖学金为1划分测试:",splitDataSet(mydata,3,1))
def chooseBest(dataset):
num=len(dataset[0])-1 #除去标签剩余数量
baseEnt=calcEnt(dataset)#计算整个数据的香农熵
bestGain=0.0
bestFeature=-1
for i in range(num):#0 1 2
#创建唯一的分类标签列表
featlist=[example[i] for example in dataset]#对每行数据取i列放入featlist
uniqueVals=set(featlist)#集合去重复
newEnt=0.0
for value in uniqueVals:
#计算每种划分方式的信息熵
subDataSet=splitDataSet(dataset,i,value)
prob=len(subDataSet)/float(len(dataset))
newEnt+=prob*calcEnt(subDataSet)
inforGain=baseEnt-newEnt
#计算最好的信息熵
if (inforGain>bestGain):
bestGain=inforGain
bestFeature=i
return bestFeature
采用递归的原则处理数据集,但程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类时递归结束。
如果数据集已经处理了所有的属性,但是类标签依然不是唯一的,此时我们需要决定如何定义该叶子节点,在这种情况下,通常会采用多数表决的方法决定叶子节点的分类
def majorityCnt(classList):
classCount={}
for vote in classList:
if vote not in classCount.keys():classCount[vote]=0
classCount[vote]+=1
sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
print(sortedClassCount)
return sortedClassCount[0][0]
创建树的代码:
使用python语句中的字典类型存储树的信息
def createTree(dataSet,labels):
#类别完全相同则停止划分
classList=[example[-1] for example in dataSet]
if classList.count(classList[0])==len(classList):
return classList[0]
#遍历完所有特征时返回出现次数最多的类别
if len(dataSet[0])==1:
return majorityCnt(classList)
bestFeat=chooseBest(dataSet)
bestFeatLabel=labels[bestFeat]
myTree={bestFeatLabel:{}}
del(labels[bestFeat])#删除
featValues=[example[bestFeat] for example in dataSet ]
uniqueVals=set(featValues)
for value in uniqueVals:
subLabels=labels[:] #复制一份
myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
return myTree
if __name__=='__main__':
labels=['卷面成绩','班级排名','综测分','奖学金等级']
mydata=createData('./ch1/grade.txt')
for i in range(3):
discretedata(mydata,i,6)
mytree=createTree(mydata,labels)
print(mytree)
此处参考《机器学习实战》代码
import matplotlib.pyplot as plt
#定义文本框和箭头格式
decisionNode=dict(boxstyle="sawtooth",fc='0.8')
leafNode=dict(boxstyle="round4",fc='0.8')
arrow_args=dict(arrowstyle="<-")
#绘制带箭头的注释
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
createPlot.axl.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',
xytext=centerPt,
textcoords='axes fraction',
va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)
#获取叶节点的数目和树的层数
def getNumLeafs(myTree):
numLeafs=0
# firstStr=myTree.keys()[0]
firstStr=list(myTree.keys())[0]
secondDict=myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
numLeafs+=getNumLeafs(secondDict[key])
else:
numLeafs+=1
return numLeafs
def getTreeDepth(myTree):
maxDepth=0
firstStr=list(myTree.keys())[0]
secondDict=myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
thisDepth=1+getTreeDepth(secondDict[key])
else:
thisDepth=1
if thisDepth>maxDepth:maxDepth=thisDepth
return maxDepth
def plotMidText(cntrPt,parentPt,txtString):
xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
createPlot.axl.text(xMid,yMid,txtString)
def plotTree(mytree,parentPt,nodeTxt):
numLeafs=getNumLeafs(mytree)
depth=getTreeDepth(mytree)
firstStr=list(mytree.keys())[0]
cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
plotMidText(cntrPt,parentPt,nodeTxt)
plotNode(firstStr,cntrPt,parentPt,decisionNode)
secondDict=mytree[firstStr]
plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
def createPlot(inTree):
fig=plt.figure(1,facecolor='white')
fig.clf()
axprops=dict(xticks=[],yticks=[])
createPlot.axl=plt.subplot(111,frameon=False,**axprops)
plotTree.totalW=float(getNumLeafs(inTree))
plotTree.totalD=float(getTreeDepth(inTree))
plotTree.xOff=-0.5/plotTree.totalW;plotTree.yOff=1.0;plotTree(inTree,(0.5,1.0),'')
plt.show()
从数据划分上看:班级排名越低越好;综测分越高越好;卷面成绩越高越好
1.在绘制的决策树中可以看出,班级排名为首选分类属性(越低表示班级排名越高,越容易拿到奖学金),这也符合实际。
2.第二个属性为综测分,最右边的节点可以看出,当综测分等级为3时(最高),综测分低于2等级的都不可能拿到奖学金(分支下所有实例都具有相同类别),等于3时有可能拿到3等奖学金
3.后期可以根据熵来处理连续型属性,当前采用3等级划分已经符合实际了,下次实验有待改进。
附上本次完整代码和数据集:link