收集数据:可以使用任何方法。
准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。
分析数据:可以使用任何方法,构造树完成后,我们应该检查图形是否符合预期。
训练算法:构造树的数据结构。
测试算法:使用经验树计算错误率。
使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义。
'''
函数功能:生成一个简易数据集
'''
def createDataSet():
dataSet = [[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
labels = ['no surfacing','flippers']
return dataSet,labels
计算给定数据集的香农熵:
'''
函数功能:计算给定数据集的香农熵
'''
def calcShannonEnt(dataSet):
numEntries = len(dataSet) # 样本数
labelCounts = {} # 创建一个数据字典:key是最后一列的数值(即标签,也就是目标分类的类别),value是属于该类别的样本个数
for featVec in dataSet: # 遍历整个数据集,每次取一行
currentLabel = featVec[-1] # 取该行最后一列的值
if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0 # 初始化信息熵
for key in labelCounts:
prob = float(labelCounts[key]) / numEntries
shannonEnt -= prob * math.log(prob, 2) # log base 2 计算信息熵
return shannonEnt
对每个特征划分数据集的结果计算一次信息熵,判断按照哪个特征划分数据集是最好的划分方式。
'''
函数功能:划分数据集
'''
def splitDataSet(dataSet, axis, value): #axis是dataSet数据集下要进行特征划分的列号例如outlook是0列,value是该列下某个特征值,0列中的sunny
retDataSet = []
for featVec in dataSet: #遍历数据集,并抽取按axis的当前value特征进划分的数据集(不包括axis列的值)
if featVec[axis] == value: #
reducedFeatVec = featVec[:axis] #减去axis用来进行划分
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
'''
函数功能:选取当前数据集下,用于划分数据集的最优特征
'''
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 #获取当前数据集的特征个数,最后一列是分类标签
baseEntropy = calcShannonEnt(dataSet) #计算当前数据集的信息熵
bestInfoGain = 0.0; bestFeature = -1 #初始化最优信息增益和最优的特征
for i in range(numFeatures): #遍历每个特征
featList = [example[i] for example in dataSet]#获取数据集中当前特征下的所有值
uniqueVals = set(featList) # 获取当前特征值,例如outlook下有sunny、overcast、rainy
newEntropy = 0.0
for value in uniqueVals: #计算每种划分方式的信息熵
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy #计算信息增益
if infoGain > bestInfoGain: #比较每个特征的信息增益,只要最好的信息增益
bestInfoGain = infoGain #如果比当前最好的更好,选择最好的
bestFeature = i
return bestFeature #返回一个整数
定义一个多数表决函数majorityCnt()
'''
函数功能:当遍历完所有的特征属性后,类标签仍然不唯一(分支下仍有不同分类的实例)
采用多数表决的方法完成分类
该函数使用分类名称的列表,然后创建键值为classList中唯一值的数据字典。
字典对象存储了classList中每个类标签出现的频率。最后利用operator操作键值排序字典,并返回出现次数最多的分类名称
'''
def majorityCnt(classList):
classCount={}#创建一个类标签的字典
#遍历类标签列表中的每一个元素
for vote in classList:
#如果元素不在字典中
if vote not in classCount.keys():
#在字典中添加新的键值对
classCount[vote] = 0
#否则当前键对应的值加1
classCount[vote] += 1
#对字典中的键对应的值所在的列,按照由大到小进行排序
#classCount.items 列表对象
#key = opreator.itemgetter(1) 获取列表对象的第一个域的值
#reverse = true 降序排序,默认是升序排序
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
#返回出现次数最多的列标签
return sortedClassCount[0][0]
通过递归的方式写出决策树的构建代码
'''
函数功能:生成决策树主方法
'''
def createTree(dataSet,labels):
classList = [example[-1] for example in dataSet] # 获取数据集中的最后一列的类标签,存入classList列表
#通过count()函数获取类标签列表中的第一个类标签的数目
#判断数目是否等于列表长度,相同表面所有类标签相同,属于同一类
if classList.count(classList[0]) == len(classList):#当类别完全相同时则停止继续划分,直接返回该类的标签
return classList[0]
#遍历完所有的特征属性,此时数据集的列为1,即只有类标签列
if len(dataSet[0]) == 1:
#多数表决原则,确定类标签
return majorityCnt(classList)
#确定出当前最优的分类特征
bestFeat = chooseBestFeatureToSplit(dataSet)
#在特征标签列表中获取该特征对应的值
bestFeatLabel = labels[bestFeat]
#采用字典嵌套字典的方式,存储分类树信息
# 这里直接使用字典变量来存储树信息,这对于绘制树形图很重要。
myTree = {bestFeatLabel:{}}
del(labels[bestFeat]) #删除已经在选取的特征
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
#递归调用createTree()函数,并且将返回的tree插入到myTree字典中
#利用最好的特征划分的子集作为新的dataSet传入到createTree()函数中
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
return myTree
测试决策树模型
'''
函数功能:测试决策树模型
'''
def classify(inputTree,featLabels,testVec):
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
key = testVec[featIndex]
valueOfFeat = secondDict[key]
if isinstance(valueOfFeat, dict):
classLabel = classify(valueOfFeat, featLabels, testVec)
else:
classLabel = valueOfFeat
return classLabel
决策树的存储与调用
'''
函数功能:存储决策树模型
'''
def storeTree(inputTree,filename):
import pickle
fw = open(filename,'w')
pickle.dump(inputTree,fw)
fw.close()
'''
函数功能:调用决策树模型
'''
def grabTree(filename):
import pickle
fr = open(filename)
return pickle.load(fr)
决策树的可视化:
'''
函数功能:输出预先存储的树信息
'''
def retrieveTree(i):
listOfTrees = [{'no surfacing':{0:'no',1:{'flippers':\
{0:'no',1:'yes'}}}},
{'no surfacing':{0:'no',1:{'filppers':\
{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]
return listOfTrees[i]
decisionNode = dict(boxstyle = "sawtooth", fc = "0.8")
leafNode = dict(boxstyle = "round4", fc = "0.8")
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
#annotate是一个数据点的文本
#nodeTxt为要显示的文本,centerPt为文本的中心点,箭头所在的点,parentPt为指向文本的点
createPlot.ax1.annotate(nodeTxt,xy = parentPt, xycoords = 'axes fraction', xytext = centerPt, textcoords = 'axes fraction',
va = "center", ha = "center", bbox = nodeType, arrowprops = arrow_args)
'''
函数功能:获取叶子节点的数目
'''
def getNumLeafs(myTree):
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
'''
函数功能:获取树的层数
'''
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:maxDepth = thisDepth
return maxDepth
'''
函数功能:在父子节点中填充文本信息
'''
def plotMidText(cntrPt,parentPt,txtString):
xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid,yMid,txtString)
'''
函数功能:生成决策树图形
'''
def createPlot(inTree):
fig = plt.figure(1,facecolor = "white")
fig.clf()
axprops = dict(xticks = [],yticks = [])
#createPlot.ax1为全局变量,绘制图像的句柄,subplot为定义了一个绘图
#111表示figure中的图有1行1列,即1个,最后的1代表第一个图
#frameon表示是否绘制坐标轴矩形
createPlot.ax1 = plt.subplot(111,frameon = False,**axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree,(0.5,0.1),'')
plt.show()
'''
函数功能:绘制决策树
'''
def plotTree(myTree,parentPt,nodeTxt):
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
plotMidText(cntrPt,parentPt,nodeTxt)
plotNode(firstStr, cntrPt,parentPt,decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
plotMidText((plotTree.xOff,plotTree.xOff),cntrPt,str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD