treePlotter模块

之前以为treePlotter是一个待安装的库,后来总是安装不成功,在学习过程中发现它其实就是一系列函数组成的自定义模块,下面介绍该模块的代码以及怎么使用该模块。
第一步:新建一个python包,在__init__文件中键入以下代码:
# _*_ coding: UTF-8 _*_

import matplotlib.pyplot as plt


"""绘决策树的函数"""
decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # 定义分支点的样式
leafNode = dict(boxstyle="round4", fc="0.8")  # 定义叶节点的样式
arrow_args = dict(arrowstyle="<-")  # 定义箭头标识样式


# 计算树的叶子节点数量
def getNumLeafs(myTree):
   numLeafs = 0
   firstStr = list(myTree.keys())[0]
   secondDict = myTree[firstStr]
   for key in secondDict.keys():
      if type(secondDict[key]).__name__ == 'dict':
         numLeafs += getNumLeafs(secondDict[key])
      else:
         numLeafs += 1
   return numLeafs


# 计算树的最大深度
def getTreeDepth(myTree):
   maxDepth = 0
   firstStr = list(myTree.keys())[0]
   secondDict = myTree[firstStr]
   for key in secondDict.keys():
      if type(secondDict[key]).__name__ == 'dict':
         thisDepth = 1 + getTreeDepth(secondDict[key])
      else:
         thisDepth = 1
      if thisDepth > maxDepth:
         maxDepth = thisDepth
   return maxDepth


# 画出节点
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
   createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', \
                           xytext=centerPt, textcoords='axes fraction', va="center", ha="center", \
                           bbox=nodeType, arrowprops=arrow_args)


# 标箭头上的文字
def plotMidText(cntrPt, parentPt, txtString):
   lens = len(txtString)
   xMid = (parentPt[0] + cntrPt[0]) / 2.0 - lens * 0.002
   yMid = (parentPt[1] + cntrPt[1]) / 2.0
   createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
   numLeafs = getNumLeafs(myTree)
   depth = getTreeDepth(myTree)
   firstStr = list(myTree.keys())[0]
   cntrPt = (plotTree.x0ff + \
             (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.y0ff)
   plotMidText(cntrPt, parentPt, nodeTxt)
   plotNode(firstStr, cntrPt, parentPt, decisionNode)
   secondDict = myTree[firstStr]
   plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD
   for key in secondDict.keys():
      if type(secondDict[key]).__name__ == 'dict':
         plotTree(secondDict[key], cntrPt, str(key))
      else:
         plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW
         plotNode(secondDict[key], \
                  (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode)
         plotMidText((plotTree.x0ff, plotTree.y0ff) \
                     , cntrPt, str(key))
   plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD


def createPlot(inTree):
   fig = plt.figure(1, facecolor='white')
   fig.clf()
   axprops = dict(xticks=[], yticks=[])
   createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
   plotTree.totalW = float(getNumLeafs(inTree))
   plotTree.totalD = float(getTreeDepth(inTree))
   plotTree.x0ff = -0.5 / plotTree.totalW
   plotTree.y0ff = 1.0
   plotTree(inTree, (0.5, 1.0), '')
   plt.show()

if __name__=='__main__':
    createPlot()
参考廖雪峰教程,最后两行代码的意思是:当我们在命令行运行模块文件时,Python解释器把一个特殊变量 __name__ 置为 __main__ ,而如果在其他地方导入该模块时, if 判断将失败,因此,这种 if 测试可以让一个模块通过命令行运行时执行一些额外的代码,最常见的就是运行测试。
第二步:关于如何导入treePlotter
参考博客: Python3导入自定义模块的3种方式_pwc1996的博客-CSDN博客
这里我是直接将模块所在的文件夹放在运行程序文件夹下,这样只需要在待运行程序中直接import即可。
注意事项:
1、要让某个文件成为模块的话,在其目录下必须有一个__init__.py的文件
2、创建自己的模块时,要注意:
(1)模块名要遵循Python变量命名规范,不要使用中文、特殊字符;
(2)模块名不要和系统模块名冲突,最好先查看系统是否已存在该模块,检查方法是在Python交互环境执行import abc,若成功则说明系统存在此模块。
第三步:演示结果
treePlotter模块_第1张图片

你可能感兴趣的:(ML,python,treePlotter,Decision,tree,1024程序员节)