本篇决策树算法是依据ID3算法来的,所以在看之前建议先了解ID3算法:https://blog.csdn.net/qq_27396861/article/details/88226296
案例,按照属性来分辨海洋生物:
实例:
# coding: utf-8
from math import log
import operator
def createDataSet():
    """Build the toy marine-animal dataset.

    Returns:
        (samples, featureNames): samples is a list of rows where the first
        two columns are binary features and the last column is the class
        label ('yes'/'no' for "is a fish"); featureNames describes the two
        feature columns.
    """
    samples = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    featureNames = ['no surfacing', 'flippers']
    return samples, featureNames
def calcShannonEnt(dataSet):
    """Return the Shannon entropy of the class labels in dataSet.

    The class label is assumed to be the last element of each row.
    Entropy is computed as -sum(p * log2(p)) over the label frequencies.
    """
    total = len(dataSet)
    # Tally how many rows carry each class label.
    counts = {}
    for row in dataSet:
        label = row[-1]
        counts[label] = counts.get(label, 0) + 1
    # Accumulate -p*log2(p) for every label's relative frequency.
    entropy = 0.0
    for count in counts.values():
        p = float(count) / total
        entropy -= p * log(p, 2)
    return entropy
def splitDataSet(dataSet, axis, value):
    """Select the rows of dataSet whose feature at index `axis` equals
    `value`, returning copies of those rows with that feature removed.

    E.g. with axis=0 each kept row is row[1:]; with axis=1 each kept row
    is row[0:1] + row[2:].
    """
    return [row[:axis] + row[axis + 1:] for row in dataSet if row[axis] == value]
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the highest information gain.

    Information gain = entropy(dataSet) - conditional entropy after
    splitting on the feature. Returns -1 when no feature yields a
    positive gain.
    """
    featureCount = len(dataSet[0]) - 1  # last column is the class label
    baseEntropy = calcShannonEnt(dataSet)
    total = float(len(dataSet))
    bestGain = 0.0
    bestFeature = -1
    for feat in range(featureCount):
        # Conditional entropy: weighted entropy over each distinct value.
        distinctValues = {row[feat] for row in dataSet}
        conditional = 0.0
        for value in distinctValues:
            subset = splitDataSet(dataSet, feat, value)
            weight = len(subset) / total
            conditional += weight * calcShannonEnt(subset)
        gain = baseEntropy - conditional
        if gain > bestGain:
            bestGain = gain
            bestFeature = feat
    return bestFeature
def majorityCnt(classList):
    """Return the most frequent class label in classList.

    Used when the features are exhausted but the remaining rows still
    disagree on the class: fall back to majority vote.

    Fix: the original used dict.iteritems(), which is Python-2-only and
    raises AttributeError on Python 3; dict.items() works on both.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # Sort (label, count) pairs by count, descending; take the top label.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree.

    Args:
        dataSet: rows of feature values with the class label last.
        labels: human-readable names for the feature columns.

    Returns:
        Either a class label string (leaf) or a nested dict of the form
        {featureName: {featureValue: subtree, ...}}.

    Fixes vs. the original:
      - Python 2 `print` statement replaced with the print() function.
      - No longer mutates the caller's `labels` list (the original did
        `del labels[bestFeat]` on the argument itself); we build the
        reduced label list as a new list instead.
    """
    classList = [example[-1] for example in dataSet]
    print("classList = ", classList)  # debug trace kept from the original
    # Base case 1: all remaining rows share one class -> leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Base case 2: only the label column is left -> majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Label list for the subtrees: every label except the chosen one.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    # One subtree per distinct value of the chosen feature.
    for value in set(example[bestFeat] for example in dataSet):
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
def main():
    """Build the sample dataset, grow the ID3 tree, and print it.

    Fix: Python 2 `print` statement replaced with the print() function.
    """
    dataSet, labels = createDataSet()
    myTree = createTree(dataSet, labels)
    print(myTree)


if __name__ == "__main__":
    main()
结果与上述图二一致:
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
由于代码的思想是将x轴和y轴平均分为几等分,且长度都是0到1,totalW记录叶子结点的个数,那么 1/totalW 正好是每个叶子结点的宽度。
以下代码是构建决策树视图的,难点在于两处地方:
(1)初始化x轴坐标的时候:
由于希望每次加 1/totalW 后,节点的位置都落在相邻叶子节点区间的中心,故要把 x 的初始位置设为减去 0.5 / totalW(即 -0.5/totalW),这样后面每加一个 1/totalW,位置刚好都落在每段区间的中间。
plotTree.xOff = -0.5 / plotTree.totalW
(2)计算根节点x轴坐标的时候
这个需要分步理解:
X坐标=节点的x偏移量 + 叶节点数距离
所有该节点下子叶子节点的距离:numLeafs / plotTree.totalW
但是坐标在叶子节点的中心:numLeafs / 2 / plotTree.totalW
又因为xOff初始坐标点在原点的左边:numLeafs / 2 / plotTree.totalW + 0.5 / plotTree.totalW
,这是偏移量
那么x = numLeafs / 2 / plotTree.totalW + 0.5 / plotTree.totalW + plotTree.xOff
plotTree.xOff + (1.0 + float(numLeafs) ) / 2.0 / plotTree.totalW
代码如下:
# coding: utf-8
import matplotlib.pyplot as plt
import sys
# 定义文本框和箭头格式
# Matplotlib styles used by plotNode: box shapes for the two node kinds
# and the arrow drawn from each child back toward its parent.
decisionNode = dict( boxstyle = "sawtooth", fc = "1.0" )  # internal (decision) nodes
leafNode = dict( boxstyle = "round4", fc = "0.8" )  # leaf (class label) nodes
arrow_args = dict( arrowstyle = "<-" )  # arrow points from child to parent
# 绘制带箭头的注释
def plotNode( nodeTxt, centerPt, parentPt, nodeType ):
    """Draw one annotated node at centerPt with an arrow from parentPt.

    Coordinates are in 'axes fraction' units (0..1 across the axes).
    nodeType selects the box style (decisionNode or leafNode).
    """
    createPlot.ax1.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va='center',
        ha='center',
        bbox=nodeType,
        arrowprops=arrow_args,
    )
'''
def createPlot():
fig = plt.figure( 1, facecolor = 'white' )
fig.clf()
createPlot.ax1 = plt.subplot( 111, frameon = False )
plotNode( "decision node", (0.5, 0.1), (0.1, 0.5), decisionNode )
plotNode( "leaf node", (0.8, 0.1), (0.3, 0.8), leafNode )
plt.show()
'''
def getNumLeafs( myTree ):
    """Recursively count the leaf nodes of a nested-dict decision tree.

    A child that is itself a dict is a subtree; anything else is a leaf.

    Fix: the original used myTree.keys()[0], which raises TypeError on
    Python 3 because dict_keys views are not subscriptable; next(iter(...))
    fetches the single root key on both Python 2 and 3.
    """
    numLeafs = 0
    firstStr = next(iter(myTree))  # root feature name (trees have one key)
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            numLeafs += getNumLeafs( secondDict[key] )
        else:
            numLeafs += 1
    return numLeafs
def getTreeDepth( myTree ):
    """Recursively compute the depth of a nested-dict decision tree.

    Depth counts decision levels: a tree whose children are all leaves
    has depth 1.

    Fix: the original used myTree.keys()[0], which raises TypeError on
    Python 3; next(iter(...)) fetches the single root key portably.
    """
    maxDepth = 0
    firstStr = next(iter(myTree))  # root feature name (trees have one key)
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth( secondDict[key] )
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
def retrieveTree(i):
    """Return one of two pre-built sample trees (index 0 or 1) used to
    exercise the plotting code without re-running createTree."""
    samples = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no',
                          1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}},
                                           1: 'no'}}}},
    ]
    return samples[i]
def plotMidText( cntrPt, parentPt, txtString ):
    """Write txtString at the midpoint of the edge between a parent node
    and its child (used for the feature-value labels on branches)."""
    midX = ( parentPt[0] - cntrPt[0] ) / 2.0 + cntrPt[0]
    midY = ( parentPt[1] - cntrPt[1] ) / 2.0 + cntrPt[1]
    createPlot.ax1.text( midX, midY, txtString )
def plotTree( myTree, parentPt, nodeTxt ):
    """Recursively draw myTree below parentPt.

    Uses function attributes set up by createPlot:
      plotTree.totalW / totalD - total leaf count / tree depth,
      plotTree.xOff / yOff     - running drawing cursor.

    x coordinate of this subtree's root:
      xOff + (1 + numLeafs) / 2 / totalW
    i.e. the cursor plus half this subtree's leaf span, so the node sits
    centered above its leaves (xOff starts at -0.5/totalW, see createPlot).

    Fixes vs. the original: Python 2 `print` statement replaced with
    print(); dict.keys()[0] (TypeError on Python 3) replaced with
    next(iter(...)); removed the unused `depth` local.
    """
    numLeafs = getNumLeafs( myTree )  # leaves under this subtree
    firstStr = next(iter(myTree))     # this subtree's root feature name
    x = plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW
    y = plotTree.yOff
    print("x = %f, y = %f" % (x, y))  # debug trace kept from the original
    cntrPt = ( x, y )
    # Branch label between this node and its parent, then the node itself.
    plotMidText( cntrPt, parentPt, nodeTxt )
    plotNode( firstStr, cntrPt, parentPt, decisionNode )
    secondDict = myTree[firstStr]
    # Descend one level: children are drawn 1/totalD lower.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            plotTree( secondDict[key], cntrPt, str(key) )
        else:
            # Leaf: advance the x cursor by one leaf width and draw it.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode( secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode )
            plotMidText( (plotTree.xOff, plotTree.yOff), cntrPt, str(key) )
    # Restore the y cursor when unwinding the recursion.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot( inTree ):
    """Create the figure, initialize plotTree's layout attributes, draw
    inTree starting from the top center (0.5, 1.0), and show the plot."""
    fig = plt.figure( 1, facecolor='white' )
    fig.clf()
    axisOptions = dict( xticks = [], yticks = [] )
    createPlot.ax1 = plt.subplot( 111, frameon = False, **axisOptions )
    # Layout state shared with plotTree via function attributes:
    # total width = leaf count, total height = tree depth.
    plotTree.totalW = float( getNumLeafs( inTree ) )
    plotTree.totalD = float( getTreeDepth( inTree ) )
    # Start half a leaf-width left of the origin so each += 1/totalW
    # lands the cursor at the center of a leaf slot.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree( inTree, (0.5, 1.0), '' )
    plt.show()
def main():
    """Fetch the second sample tree and render it."""
    # createPlot()
    sampleTree = retrieveTree(1)
    createPlot( sampleTree )


if __name__=="__main__":
    main()