import numpy as np
import matplotlib.pyplot as plt
from pylab import *
# 特征字典,后面用到了好多次,干脆当全局变量了
featureDic = {
'色泽': ['浅白', '青绿', '乌黑'],
'根蒂': ['硬挺', '蜷缩', '稍蜷'],
'敲声': ['沉闷', '浊响', '清脆'],
'纹理': ['清晰', '模糊', '稍糊'],
'脐部': ['凹陷', '平坦', '稍凹'],
'触感': ['硬滑', '软粘']}
def getDataSet():
"""
get watermelon data set 3.0 alpha.
:return: 编码好的数据集以及特征的字典。
"""
dataSet = [
['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.697, 0.460, 1],
['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', 0.774, 0.376, 1],
['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.634, 0.264, 1],
['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', 0.608, 0.318, 1],
['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.556, 0.215, 1],
['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', 0.403, 0.237, 1],
['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', 0.481, 0.149, 1],
['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', 0.437, 0.211, 1],
['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', 0.666, 0.091, 0],
['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', 0.243, 0.267, 0],
['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', 0.245, 0.057, 0],
['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', 0.343, 0.099, 0],
['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', 0.639, 0.161, 0],
['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', 0.657, 0.198, 0],
['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', 0.360, 0.370, 0],
['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', 0.593, 0.042, 0],
['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', 0.719, 0.103, 0]
]
features = ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感', '密度', '含糖量']
# #得到特征值字典,本来用这个生成的特征字典,还是直接当全局变量方便
# featureDic = {}
# for i in range(len(features)):
# featureList = [example[i] for example in dataSet]
# uniqueFeature = list(set(featureList))
# featureDic[features[i]] = uniqueFeature
# 每种特征的属性个数
numList = [] # [3, 3, 3, 3, 3, 2]
for i in range(len(features) - 2):
numList.append(len(featureDic[features[i]]))
# 编码,把文字替换成数字。用1、2、3表示同种特征的不同类型
newDataSet = []
for dataVec in dataSet: # 第一每一个数据
dataNum = dataVec[-3:] # 保存数据中的数值部分
newData = []
for i in range(len(dataVec) - 3): # 值为字符的每一列
for j in range(numList[i]): # 对应列的特征的每一类
if dataVec[i] == featureDic[features[i]][j]:
newData.append(j+1)
newData.extend(dataNum) # 编码好的部分和原来的数值部分合并
newDataSet.append(newData)
return np.array(newDataSet), features
# # test getDataSet()
# newData, features = getDataSet()
# print(newData)
# print(features)
def calEntropy(dataArr, classArr):
"""
calculate information entropy.
:param dataArr:
:param classArr:
:return: entropy
"""
n = dataArr.size
data0 = dataArr[classArr == 0]
data1 = dataArr[classArr == 1]
p0 = data0.size / float(n)
p1 = data1.size / float(n)
# 约定:p=0, p*log_2(p) = 0
if p0 == 0:
ent = -(p1 * np.log(p1))
elif p1 == 0:
ent = -(p0 * np.log(p0))
else:
ent = -(p0 * np.log2(p0) + p1 * np.log2(p1))
return ent
# # test calEntropy()
# dataSet, _ = getDataSet()
# print(calEntropy(dataSet[:, :-1], dataSet[:, -1]))
def splitDataSet(dataSet, ax, value):
"""
按照给点的属性ax和其中一种取值value来划分数据。
当属性类型为标称数据时,返回一个属性值都为value的数据集。
当属性类型为数值型数据事,以与value的大小关系为基准返回两个数据集。
input:
dataSet: 输入数据集,形状为(m,n)表示m个数据,前n-1列个属性,最后一列为类型。
ax:属性类型
value: 标称型时为1、2、3等。数值型为形如0.123的数。
return:
1.标称型dataSet返回第ax个属性中值为value组成的集合
2.数值型dataSet返回两个集合。其一中数据都小于等于value,另一都大于。
"""
# 2个连续属性密度、含糖量+类型为后3列,其余为标称型
if ax < dataSet.shape[1] - 3:
dataS = np.delete(dataSet[dataSet[:, ax] == value], ax, axis=1)
return dataS
else:
dataL = dataSet[dataSet[:, ax] <= value]
dataR = dataSet[dataSet[:, ax] > value]
return dataL, dataR
# # test splitDataSet()
# dataSet, _ = getDataSet()
# test1 = splitDataSet(dataSet, 3, 1)
# test2L, test2R = splitDataSet(dataSet, 6, 0.5)
# print("test1 = ", test1)
# print("test2L = ", test2L)
# print("test2R = ", test2R)
def calInfoGain(dataSet, labelList, ax, value=-1):
"""
计算给定数据dataSet在属性ax上的香农熵增益。
input:
dataSet:输入数据集,形状为(m,n)表示m个数据,前n-1列个属性,最后一列为类型。
labelList:属性列表,如['色泽', '根蒂', '敲声', '纹理', '脐部', '触感', '密度', '含糖量']
ax: 选择用来计算信息增益的属性。0表示第一个属性,1表示第二个属性等。
前六个特征是标称型,后两个特征是数值型。
value: 用来划分数据的值。当标称型时默认为-1, 即不使用这个参数。
return:
gain:信息增益
"""
baseEnt = calEntropy(dataSet[:, :-1], dataSet[:, -1]) # 计算D的原始信息熵
newEnt = 0.0 # 划分完数据后的香农熵
if ax < dataSet.shape[1] - 3: # 计算标称型的香农熵
num = len(featureDic[labelList[ax]]) # 每一个特征的类别数
for j in range(num):
subDataSet = splitDataSet(dataSet, ax, j+1)
prob = len(subDataSet) / float(len(dataSet))
if prob != 0:
newEnt += prob * calEntropy(subDataSet[:, :-1], subDataSet[:, -1])
else:
# 数据集划分为两份
dataL, dataR = splitDataSet(dataSet, ax, value)
# 计算两数据集的信息熵
entL = calEntropy(dataL[:, :-1], dataL[:, -1])
entR = calEntropy(dataR[:, :-1], dataR[:, -1])
# 计算划分完总数据集的信息熵
newEnt = (dataL.size * entL + dataR.size * entR) / float(dataSet.size)
# 计算信息增益
gain = baseEnt - newEnt
return gain
# # test calInfoGain(dataSet, featureDic, axis, value=-1):
# data, feat = getDataSet()
# print(calInfoGain(data, feat, 2))
def chooseBestSplit(dataSet, labelList):
"""
计算信息增益增大的划分数据集的方式. 当返回的不是数值型特征时, 划分值bestThresh = -1
input:
dataSet
labelList
return:
bestFeature: 使得到最大增益划分的属性。
bestThresh: 使得到最大增益划分的数值。标称型时无意义令其为-1。
maxGain: 最大增益划分时的增益值。
"""
maxGain = 0.0
bestFeature = -1
bestThresh = -1
m, n = dataSet.shape
# 对每一个特征
for i in range(n - 1):
if i < (n - 3): # 标称型
gain = calInfoGain(dataSet, labelList, i)
if gain > maxGain:
bestFeature = i
maxGain = gain
else: # 数值型
featVals = dataSet[:, i] # 得到第i个特征的所有值
sortedFeat = np.sort(featVals) # 按照从小到大的顺序排列第i个特征的所有值
T = []
# 计算划分点
for j in range(m - 1):
t = (sortedFeat[j] + sortedFeat[j + 1]) / 2.0
T.append(t)
# 对每一个划分值,计算增益熵
for t in T:
gain = calInfoGain(dataSet, featureDic, i, t)
if gain > maxGain:
bestFeature = i
bestThresh = t
maxGain = gain
return bestFeature, bestThresh, maxGain
# # test chooseBestSplit
# data, feat = getDataSet()
# f, tv, g = chooseBestSplit(data, feat)
# print(f"best feature is {list(featureDic.keys())[f]}\n"
# f"best thresh value is {tv}\n"
# f"max information gain is {g}")
def majorityCnt(classList):
"""
投票,0多返回"坏瓜",否则返回"坏瓜"。
"""
cnt0 = len(classList[classList == 0])
cnt1 = len(classList[classList == 1])
if cnt0 > cnt1:
return '坏瓜'
else:
return '好瓜'
def createTree(dataSet, labels):
"""
通过信息增益递归创造一颗决策树。
input:
labels
dataSet
return:
myTree: 返回一个存有树的字典
"""
classList = dataSet[:, -1]
# 如果剩余的类别全相同,则返回
if len(classList[classList == classList[0]]) == len(classList):
if classList[0] == 0:
return '坏瓜'
else:
return '好瓜'
# 如果只剩下类标签,投票返回
if len(dataSet[0]) == 1:
return majorityCnt(classList)
# 得到增益最大划分的属性、值
bestFeat, bestVal, entGain = chooseBestSplit(dataSet, labels)
bestFeatLabel = labels[bestFeat]
if bestVal != -1: # 如果是数值型
txt = bestFeatLabel + "<=" + str(bestVal) + "?"
else: # 如果是标称型
txt = bestFeatLabel + "=" + "?"
myTree = {txt: {}} # 创建字典,即树的节点。
if bestVal != -1: # 数值型的话就是左右两个子树。
subDataL, subDataR = splitDataSet(dataSet, bestFeat, bestVal)
myTree[txt]['是'] = createTree(subDataL, labels)
myTree[txt]['否'] = createTree(subDataR, labels)
else:
i = 0
# 生成子树的时候要将已遍历的属性删去。数值型不要删除。
del (labels[bestFeat])
uniqueVals = featureDic[bestFeatLabel] # 最好的特征的类别列表
for value in uniqueVals: # 标称型的属性值有几种,就要几个子树。
# Python中列表作为参数类型时,是按照引用传递的,要保证同一节点的子节点能有相同的参数。
subLabels = labels[:] # subLabels = 注意要用[:],不然还是引用
i += 1
subDataSet = splitDataSet(dataSet, bestFeat, i)
myTree[txt][value] = createTree(subDataSet, subLabels)
return myTree
# # test createTree()
# data, feat = getDataSet()
# Tree = createTree(data, feat)
# print(Tree)
# ***********************画图***********************
# **********************start***********************
# 详情参见机器学习实战决策树那一章
# 定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
mpl.rcParams['font.sans-serif'] = ['SimHei'] # 没有这句话汉字都是口口
# mpl.rcParams['axes.unicode_minus'] = False # 解决保存图像是负号'-'显示为方块的问题
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, fontsize=20)
def plotNode(nodeTxt, centerPt, parentPt, nodeType): # 绘制带箭头的注解
createPlot.ax1.annotate(nodeTxt,
xy=parentPt,
xycoords="axes fraction",
xytext=centerPt,
textcoords="axes fraction",
va="center",
ha="center",
bbox=nodeType,
arrowprops=arrow_args,
fontsize=20)
def getNumLeafs(myTree): # 获取叶节点的数目
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
def getTreeDepth(myTree): # 获取树的层数
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree)
getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
fig = plt.figure(1, figsize=(600, 30), facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), '')
plt.show()
# ***********************画图***********************
# ***********************end************************
def main():
dataSet, labelList = getDataSet()
myTree = createTree(dataSet, labelList)
createPlot(myTree)
if __name__ == '__main__':
main()
一开始做错了,用的只有密度和含糖率的那个数据,定义了个二叉树节点的类,结果改了半天机器学习实战那个画树的代码,美滋滋完发现不是那回事。。。节点并不是两个,就当写成CART了吧。又改成了用字典保存树的代码,之前写的就也保存一下。
import numpy as np
import matplotlib.pyplot as plt
from pylab import *
# define tree node
class TreeNode:
def __init__(self, feature, thresh):
self.feature = feature # 特征:密度 or 含糖率
self.thresh = thresh # 基于某个特征分类时的划分值
self.label = -1 # 类别:只在叶子节点上不为-1,为0或1代表好瓜的否和是。
self.data = [] # 用来保存该节点上的数据
self.left = None # 左右结点
self.right = None
def numOfGood(self):
return self.data[self.data == 1].sum()
def numOfBad(self):
return self.data[self.data == 0].sum()
def getDataSet():
"""
get watermelon data set 3.0 alpha.
:return: (feature array, label array)
"""
dataSet = np.array([
[0.697, 0.460, 1],
[0.774, 0.376, 1],
[0.634, 0.264, 1],
[0.608, 0.318, 1],
[0.556, 0.215, 1],
[0.403, 0.237, 1],
[0.481, 0.149, 1],
[0.437, 0.211, 1],
[0.666, 0.091, 0],
[0.243, 0.267, 0],
[0.245, 0.057, 0],
[0.343, 0.099, 0],
[0.639, 0.161, 0],
[0.657, 0.198, 0],
[0.360, 0.370, 0],
[0.593, 0.042, 0],
[0.719, 0.103, 0]
])
return dataSet
def calEntropy(dataArr, labelArr):
"""
calculate information entropy.
:param dataArr:
:return: entropy
"""
n = dataArr.size
data0 = dataArr[labelArr == 0]
data1 = dataArr[labelArr == 1]
p0 = data0.size / float(n)
p1 = data1.size / float(n)
# 约定:p=0, p*log_2(p) = 0
if p0 == 0:
entropy = -(p1 * np.log(p1))
elif p1 == 0:
entropy = -(p0 * np.log(p0))
else:
entropy = -(p0 * np.log(p0) + p1 * np.log(p1))
return entropy
# # test calEntropy()
# dataSet = getDataSet()
# print(calEntropy(dataSet[:, :-1], dataSet[:, -1],))
def calInfoGain(dataSet, feature, thresh):
"""
calculate information gain
:param dataSet: 数据集,最后一列为类别。
:param feature: 选择用来计算信息增益的特征。0表示第一个特征,1表示第二个特征
:param thresh: 用来划分数据的值
:return: 信息增益
"""
entD = calEntropy(dataSet[:, :-1], dataSet[:, -1]) # 计算D的原始信息熵
# 数据集划分为两份
dataL = dataSet[dataSet[:, feature] <= thresh]
dataR = dataSet[dataSet[:, feature] > thresh]
# 计算两数据集的信息熵
entL = calEntropy(dataL[:, :-1], dataL[:, -1])
entR = calEntropy(dataR[:, :-1], dataR[:, -1])
# 计算划分完总数据集的信息熵
entDS = (dataL.size * entL + dataR.size * entR) / float(dataSet.size)
# 计算信息增益
gain = entD - entDS
return gain
# # test calInfoGain(dataSet, feature, thresh)
# data = getDataSet()
# print(calInfoGain(data, 0, 0.6))
def chooseBestSplit(dataSet):
"""
计算信息增益增大的划分数据集的方式
:param dataSet:
:return: 信息增益最大的划分方式的 特征 和 划分值。
"""
maxGain = 0.0
bestFeature = -1
bestThresh = -1
m, n = dataSet.shape
# 对每一个特征
for i in range(n - 1):
feat = dataSet[:, i] # 得到第i个特征的所有值
sortedFeat = np.sort(feat) # 按照从小到大的顺序排列第i个特征的所有值
T = []
# 计算划分点
for j in range(m - 1):
t = (sortedFeat[j] + sortedFeat[j + 1]) / 2.0
T.append(t)
# 对每一个划分值,计算增益熵
for val in T:
gain = calInfoGain(dataSet, i, val)
if gain > maxGain:
bestFeature = i
bestThresh = val
maxGain = gain
return bestFeature, bestThresh, maxGain
# # test chooseBestSplit
# data = getDataSet()
# f, tv, g = chooseBestSplit(data)
# print(f"best feature is {f}\n"
# f"best thresh value is {tv}\n"
# f"max information gain is {g}")
def createTree(dataSet):
"""
通过信息增益创造一颗决策树
:param dataSet:
:return: 返回一颗树的根结点
"""
# 到叶子节点时返回。
# 若只剩k个相同类的数据,信息熵 = -(0*log_2(0) + k*log_2(k) = 0
# 即信息熵为0时返回叶子结点
if calEntropy(dataSet[:, :-1], dataSet[:, -1]) == 0:
leaf = TreeNode(-1, -1) # 构造叶子结点
leaf.label = dataSet[0][-1]
return leaf
feature, thresh, gain = chooseBestSplit(dataSet)
dataL = dataSet[dataSet[:, feature] <= thresh]
dataR = dataSet[dataSet[:, feature] > thresh]
Node = TreeNode(feature, thresh)
Node.data = dataSet
Node.left = createTree(dataL)
Node.right = createTree(dataR)
return Node
# # test createTree()
# data = getDataSet()
# Tree = createTree(data)
# ***********************画图***********************
# **********************start***********************
def getNumLeafs(myTree):
"""
得到叶子结点的数量
:param myTree:
:return:
"""
if myTree.feature == -1:
return 1
if myTree is None:
return 0
return getNumLeafs(myTree.left) + getNumLeafs(myTree.right)
# # test getNumLeafs()
# data = getDataSet()
# Tree = createTree(data)
# print(getNumLeafs(Tree)) # 5个叶子
def getTreeDepth(myTree):
"""
得到树的深度
:param myTree:
:return:
"""
if myTree is None:
return 0
# 1表示加上当前节点
depth = max(1 + getTreeDepth(myTree.left),
1 + getTreeDepth(myTree.right))
return depth
# # test getTreeDepth()
# data = getDataSet()
# Tree = createTree(data)
# print(getTreeDepth(Tree)) # 深度为5
# 没有这句的话画出的图上面汉字会显示成口口
mpl.rcParams['font.sans-serif'] = ['SimHei']
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
fList = ["密度", "含糖率"] # 后面画节点时用到
melon = ["坏瓜", "好瓜"]
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
"""
绘制一个结点以及指向这个减点的箭头
:param nodeTxt: 结点上的文字
:param centerPt: 箭头终止坐标
:param parentPt: 箭头起始坐标,即对应树中父结点坐标。即从parentPt指向centerPt
:param nodeType: 结点类型。实际上是一个字典,里面保存着绘制结点的参数,
decisionNode:表示非叶子结点。leafNode表示叶子结点、
:return:
"""
createPlot.ax1.annotate(nodeTxt, xy=parentPt,
xycoords='axes fraction',
xytext=centerPt,
textcoords='axes fraction',
va="center", ha="center",
bbox=nodeType,
arrowprops=arrow_args,
fontsize=15) # 结点字的大小
# def createPlot():
# fig = plt.figure(1, facecolor='white')
# fig.clf()
# createPlot.ax1 = plt.subplot(111, frameon=False)
# plotNode('决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode)
# plotNode('叶节点', (0.8, 0.1), (0.3, 0.8), leafNode)
# plt.show()
# createPlot()
def plotMidText(cntrPt, parentPt, txtString):
"""
计算父节点和子节点中间的位置,即箭头中间的位置上画上文本,比如"是"和"否"
:param cntrPt: 子节点的坐标
:param parentPt:父节点的坐标
:param txtString: 要画的字符
:return:
"""
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString,
va="center", ha="center", rotation=30,
fontsize=15)
def plotTree(myTree, parentPt, nodeTxt): # if the first key tells you what feat was split on
"""
递归画树
:param myTree: 树节点
:param parentPt: 父节点坐标
:param nodeTxt: 节点字符
:return:
"""
numLeafs = getNumLeafs(myTree) # this determines the x width of this tree
depth = getTreeDepth(myTree)
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
if myTree.thresh != -1:
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(str(fList[myTree.feature]) # fList = ["密度", "含糖率"]
+ "<=" + str(myTree.thresh) + "?",
cntrPt, parentPt, decisionNode)
plotTree.yOff = plotTree.yOff - 1 / plotTree.totalD
else:
plotTree.xOff = plotTree.xOff + 1 / plotTree.totalW
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(melon[int(myTree.label)], cntrPt, parentPt, decisionNode) # melon = ["坏瓜", "好瓜"]
plotTree.yOff = plotTree.yOff - 1 / plotTree.totalD
if myTree.left is not None:
plotTree(myTree.left, cntrPt, "是")
if myTree.right is not None:
plotTree(myTree.right, cntrPt, "否")
plotTree.yOff = plotTree.yOff + 1 / plotTree.totalD
def createPlot(inTree):
"""
设置画图的基本信息,如树的宽度和深度,初始坐标等。调用plotTree()画图
:param inTree:
:return:
"""
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) # no ticks
# createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), '')
plt.show()
# ***********************画图***********************
# ***********************end************************
def main():
data = getDataSet()
Tree = createTree(data)
createPlot(Tree)
if __name__ == '__main__':
main()