trees.py
import operator
from math import log
import treePlotter as dtPlot
from collections import Counter
def createDataSet():
    """Build the small hard-coded toy data set used by the demos.

    Returns:
        A ``(dataSet, labels)`` tuple. Every row of ``dataSet`` holds two
        0/1 feature values followed by a ``'yes'``/``'no'`` class label;
        ``labels`` names the two feature columns in order. Fresh lists are
        created on every call.
    """
    feature_names = ['no surfacing', 'flippers']
    raw_rows = (
        (1, 1, 'yes'),
        (1, 1, 'yes'),
        (1, 0, 'no'),
        (0, 1, 'no'),
        (0, 1, 'no'),
    )
    # Rows are handed out as lists, matching what the tree code expects.
    samples = [list(row) for row in raw_rows]
    return samples, feature_names
def calcShannonEnt(dataSet):
    """Compute the Shannon entropy (base 2) of the class labels in a data set.

    Args:
        dataSet: Sequence of rows; the last element of each row is used as
            that row's class label (labels must be hashable).

    Returns:
        The entropy ``-sum(p * log2(p))`` over the relative frequency ``p``
        of each distinct label, as a float. An empty data set yields 0.0.
    """
    numEntries = len(dataSet)
    # Counter keeps labels in first-seen order, so the floating-point
    # accumulation below happens in a deterministic order.
    labelCounts = Counter(featVec[-1] for featVec in dataSet)
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
def splitDataSet(dataSet, index, value):
    """Select the rows whose ``index`` column equals ``value`` and drop that column.

    Args:
        dataSet: The data set to split (a sequence of list rows).
        index: Position of the feature column to split on.
        value: The feature value a row must have in that column to be kept.

    Returns:
        A new list holding, for every matching row, a copy of the row with
        column ``index`` removed. The input rows are not modified.
    """
    selected = []
    for row in (candidate for candidate in dataSet if candidate[index] == value):
        # Stitch together everything before and after the split column.
        trimmed = row[:index]
        trimmed.extend(row[index + 1:])
        selected.append(trimmed)
    return selected
def chooseBestFeatureToSplit(dataSet):
    """Pick the feature column whose split gives the largest information gain.

    Args:
        dataSet: Sequence of rows; the last element of each row is the class
            label and every other position is a feature column.

    Returns:
        The index of the best feature column, or -1 when no column yields a
        strictly positive information gain.
    """
    feature_count = len(dataSet[0]) - 1
    entropy_before = calcShannonEnt(dataSet)
    total_rows = float(len(dataSet))
    winner, winning_gain = -1, 0.0
    for column in range(feature_count):
        distinct_values = set([row[column] for row in dataSet])
        # Weighted average of the entropies of the subsets produced by
        # splitting on this column.
        entropy_after = 0.0
        for observed in distinct_values:
            subset = splitDataSet(dataSet, column, observed)
            weight = len(subset) / total_rows
            entropy_after += weight * calcShannonEnt(subset)
        gain = entropy_before - entropy_after
        if gain > winning_gain:
            winner, winning_gain = column, gain
    return winner
def majorityCnt(classList):
    """Return the class label that occurs most often in ``classList``.

    Args:
        classList: Sequence of (hashable) class labels.

    Returns:
        The most frequent label. When several labels tie for the highest
        count, the one that appears first in ``classList`` wins.

    Raises:
        IndexError: If ``classList`` is empty.
    """
    # most_common(1) avoids sorting every label just to read the top entry
    # and orders equal counts by first occurrence.
    return Counter(classList).most_common(1)[0][0]
def createTree(dataSet, labels):
    """Recursively build a decision tree as nested dictionaries.

    Args:
        dataSet: Sequence of list rows; the last element of each row is the
            class label, the others are feature values.
        labels: Names of the feature columns, in column order. The sequence
            is not modified.

    Returns:
        A class label when the data set is pure, has no feature columns left
        or cannot be split with positive information gain; otherwise a dict
        of the form ``{feature_name: {feature_value: subtree_or_label}}``.

    Raises:
        IndexError: If ``dataSet`` is empty.
    """
    classList = [example[-1] for example in dataSet]
    # Every row has the same class: this branch is a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column is left: fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    if bestFeat == -1:
        # No feature gives a positive gain. Indexing with -1 would split on
        # the class column itself, so stop with a majority-vote leaf instead.
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    # Work on a private copy so the caller's label list stays intact.
    subLabels = list(labels)
    del subLabels[bestFeat]
    myTree = {bestFeatLabel: {}}
    uniqueVals = set([example[bestFeat] for example in dataSet])
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels
        )
    return myTree
def classify(inputTree, featLabels, testVec):
    """Classify one feature vector by walking a decision tree.

    Args:
        inputTree: Decision tree in nested-dict form,
            ``{feature_name: {feature_value: subtree_or_label}}``.
        featLabels: Feature names; the position of a name gives the position
            of that feature's value in ``testVec``.
        testVec: Feature values of the sample to classify.

    Returns:
        The leaf value reached by following ``testVec`` through the tree.
    """
    node = inputTree
    while True:
        feature_name = list(node.keys())[0]
        branches = node[feature_name]
        observed = testVec[featLabels.index(feature_name)]
        outcome = branches[observed]
        # Trace line emitted once per tree level visited.
        print('+++', feature_name, 'xxx', branches, '---', observed, '>>>', outcome)
        if not isinstance(outcome, dict):
            return outcome
        # The branch leads to another decision node: descend into it.
        node = outcome
def storeTree(inputTree, filename):
    """Serialize a decision tree to ``filename`` using pickle.

    Args:
        inputTree: The object to store (typically the nested-dict tree).
        filename: Path of the file to write; it is opened in binary write
            mode, so an existing file is overwritten.
    """
    import pickle

    # The context manager closes the file even if pickling fails.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
def grapTree(filename):
    """Load a pickled decision tree from ``filename``.

    Args:
        filename: Path of a file containing a pickled object.

    Returns:
        The unpickled object.
    """
    import pickle

    # NOTE: unpickling can execute arbitrary code; only load files that
    # come from a trusted source.
    # The context manager ensures the file handle is closed after loading.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
def fishTest():
    """Demo: build, print, query, measure and plot a tree for the toy data set."""
    import copy

    samples, feature_names = createDataSet()
    # Give createTree its own copy of the names so the original list is
    # still intact when it is passed to classify below.
    fitted_tree = createTree(samples, copy.deepcopy(feature_names))
    print(fitted_tree)
    prediction = classify(fitted_tree, feature_names, [1, 1])
    print(prediction)
    height = get_tree_height(fitted_tree)
    print('树高:', height)
    dtPlot.createPlot(fitted_tree)
def get_tree_height(tree):
    """Recursively compute the height of a decision tree.

    Args:
        tree: Either a nested-dict decision node of the form
            ``{feature_name: {feature_value: subtree_or_label}}`` or a leaf
            value (anything that is not a dict).

    Returns:
        1 for a leaf; for a decision node, 1 plus the height of its tallest
        child (a node without children also has height 1).
    """
    if not isinstance(tree, dict):
        return 1
    branches = list(tree.values())[0]
    tallest_child = max(
        (get_tree_height(subtree) for subtree in branches.values()),
        default=0,
    )
    return tallest_child + 1
def ContactLensesTest():
    """Demo: build, print and plot a decision tree for the contact-lens data.

    Reads ``data/lenses.txt`` (one tab-separated instance per line, relative
    to the current working directory), builds a tree over the four named
    features, prints it and plots it.

    Returns:
        None
    """
    # The context manager closes the data file once all lines are parsed.
    with open('data/lenses.txt') as fr:
        lenses = [inst.strip().split('\t') for inst in fr]
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lensesTree = createTree(lenses, lensesLabels)
    print(lensesTree)
    dtPlot.createPlot(lensesTree)
# Script entry point: run the contact-lens demo when executed directly.
if __name__ == "__main__":
    ContactLensesTest()
treePlotter.py
import matplotlib.pyplot as plt
# Box style for decision nodes and leaf nodes (passed as ``bbox=`` by
# plotNode) and the arrow style for the connecting edge (``arrowprops=``).
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
leafNode = {"boxstyle": "round4", "fc": "0.8"}
arrow_args = {"arrowstyle": "<-"}
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one annotated tree node on the axes stored in ``createPlot.ax1``.

    Args:
        nodeTxt: Text shown inside the node box.
        centerPt: (x, y) position of the text box, in axes-fraction coordinates.
        parentPt: (x, y) position of the annotated point (the parent node),
            in axes-fraction coordinates.
        nodeType: Box style dict passed as ``bbox`` (``decisionNode`` or
            ``leafNode``).
    """
    axes = createPlot.ax1
    annotation_options = {
        'xy': parentPt,
        'xycoords': 'axes fraction',
        'xytext': centerPt,
        'textcoords': 'axes fraction',
        'va': "center",
        'ha': "center",
        'bbox': nodeType,
        'arrowprops': arrow_args,
    }
    axes.annotate(nodeTxt, **annotation_options)
def createPlot(inTree):
    """Render a nested-dict decision tree with matplotlib and show the figure.

    The layout state used by ``plotTree`` is stored as attributes on the
    functions themselves: ``createPlot.ax1`` holds the axes, and
    ``plotTree.totalW`` / ``totalD`` / ``xOff`` / ``yOff`` hold the leaf
    count, the depth and the running drawing cursor.

    Args:
        inTree: Decision tree of the form
            ``{feature_name: {feature_value: subtree_or_label}}``.
    """
    figure = plt.figure(1, facecolor='green')
    figure.clf()
    # A frameless, tick-free subplot serves as the drawing canvas.
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    leaf_total = plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a leaf slot left of the axes; plotTree advances the cursor
    # by one slot before placing each leaf.
    plotTree.xOff = -0.5 / leaf_total
    plotTree.yOff = 1.0
    root_anchor = (0.5, 1.0)
    plotTree(inTree, root_anchor, '')
    plt.show()
def getNumLeafs(myTree):
    """Count the leaf nodes of a nested-dict decision tree.

    Args:
        myTree: Decision node of the form
            ``{feature_name: {feature_value: subtree_or_label}}``.

    Returns:
        The number of leaves below ``myTree``. Any child that is a dict
        (including dict subclasses) is treated as a subtree; every other
        child counts as one leaf.

    Raises:
        IndexError: If ``myTree`` is an empty dict.
    """
    firstStr = list(myTree)[0]
    # isinstance (rather than an exact type check) also recognises subtrees
    # stored in dict subclasses such as OrderedDict.
    return sum(
        getNumLeafs(child) if isinstance(child, dict) else 1
        for child in myTree[firstStr].values()
    )
def getTreeDepth(myTree):
    """Compute the depth of a nested-dict decision tree.

    Args:
        myTree: Decision node of the form
            ``{feature_name: {feature_value: subtree_or_label}}``.

    Returns:
        The number of decision levels on the longest path from ``myTree``
        down to a leaf: 1 when all children are leaves, 0 when the node has
        no children. Any child that is a dict (including dict subclasses) is
        treated as a subtree.

    Raises:
        IndexError: If ``myTree`` is an empty dict.
    """
    firstStr = list(myTree)[0]
    maxDepth = 0
    for child in myTree[firstStr].values():
        # isinstance also recognises subtrees stored in dict subclasses.
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth
def retrieveTree(i):
    """Return one of two hard-coded sample decision trees.

    Args:
        i: Index into the list of sample trees (0 or 1; list indexing rules
            apply).

    Returns:
        A freshly built nested-dict tree.
    """
    flippers_only = {'flippers': {0: 'no', 1: 'yes'}}
    flippers_then_head = {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}
    sample_trees = [
        {'no surfacing': {0: 'no', 1: flippers_only}},
        {'no surfacing': {0: 'no', 1: flippers_then_head}},
    ]
    return sample_trees[i]
def plotMidText(cntrPt, parentPt, txtString):
    """Write ``txtString`` halfway between a node and its parent.

    Args:
        cntrPt: (x, y) position of the child node.
        parentPt: (x, y) position of the parent node.
        txtString: Text to draw (rotated by 30 degrees) at the midpoint, on
            the axes stored in ``createPlot.ax1``.
    """
    def halfway(parent_coord, child_coord):
        # Midpoint of one coordinate, measured from the child's position.
        return (parent_coord - child_coord) / 2.0 + child_coord

    xMid = halfway(parentPt[0], cntrPt[0])
    yMid = halfway(parentPt[1], cntrPt[1])
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw a decision (sub)tree below ``parentPt``.

    The drawing cursor and tree dimensions live in function attributes that
    ``createPlot`` initialises before the first call: ``plotTree.totalW``
    (leaf count), ``plotTree.totalD`` (depth), ``plotTree.xOff`` and
    ``plotTree.yOff`` (current x / y position in axes-fraction units).

    Args:
        myTree: Decision node of the form
            ``{feature_name: {feature_value: subtree_or_label}}``.
        parentPt: (x, y) position of the parent node.
        nodeTxt: Text written on the edge from the parent to this node.
    """
    numLeafs = getNumLeafs(myTree)
    # Centre this node horizontally above the leaves it spans.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    firstStr = list(myTree.keys())[0]
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Step one level down before drawing the children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        # isinstance also recognises subtrees stored in dict subclasses.
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))
        else:
            # Each leaf takes the next horizontal slot.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            leafPt = (plotTree.xOff, plotTree.yOff)
            plotNode(child, leafPt, cntrPt, leafNode)
            plotMidText(leafPt, cntrPt, str(key))
    # Restore the level so siblings of this node are drawn at the right height.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD