from math import log
import pickle
from operator import itemgetter

import numpy as np


def calcShannonEnt(dataSet):
    """Return the Shannon entropy (base 2) of the class labels in *dataSet*.

    Args:
        dataSet: sequence of records; the last element of each record is
            its class label.

    Returns:
        The entropy as a float. An empty data set yields 0.0.
    """
    numEntries = len(dataSet)
    # Number of records observed for each class label.
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def splitDataSet(dataSet, axis, value):
    """Return the records whose feature *axis* equals *value*, minus that column.

    Args:
        dataSet: sequence of records (lists or tuples).
        axis: index of the feature to split on.
        value: feature value a record must have to be kept.

    Returns:
        A new list of new lists; the input records are not modified.
    """
    return [
        list(featVec[:axis]) + list(featVec[axis + 1:])
        for featVec in dataSet
        if featVec[axis] == value
    ]


def createDataSet():
    """Return the small demo data set and the names of its two features."""
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the largest information gain.

    Args:
        dataSet: non-empty sequence of records; the last element of each
            record is the class label, the others are feature values.

    Returns:
        The index of the best feature, or -1 when no feature yields a
        strictly positive information gain.
    """
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    totalCount = float(len(dataSet))
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        # Distinct values taken by feature i.
        uniqueVals = {example[i] for example in dataSet}
        # Entropy after the split, weighted by the size of each subset.
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / totalCount
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """Return the most frequent class label in *classList*.

    Ties are resolved in favour of the label that appears first.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # max() keeps the first maximal item, i.e. the earliest-seen label on ties.
    return max(classCount.items(), key=itemgetter(1))[0]


def createTree(dataSet, labels):
    """Recursively build a decision tree as nested dictionaries.

    Args:
        dataSet: non-empty sequence of records; the last element of each
            record is the class label.
        labels: feature names, one per feature column. This list is NOT
            modified, so it can be reused afterwards (e.g. for classify()).

    Returns:
        Either a class label (leaf) or a dict of the form
        ``{feature_name: {feature_value: subtree, ...}}``.
    """
    classList = [example[-1] for example in dataSet]
    # All records share one class: nothing left to split.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column remains: fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # No feature separates the classes any further. Without this guard the
    # -1 sentinel would index the class column and "split" on the label.
    if bestFeat == -1:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    # Build the reduced label list as a copy instead of deleting from the
    # caller's list.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    myTree = {bestFeatLabel: {}}
    uniqueVals = {example[bestFeat] for example in dataSet}
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


def classify(inputTree, featLabels, testVec):
    """Classify *testVec* by walking the decision tree.

    Args:
        inputTree: tree as produced by createTree(), e.g.
            ``{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}``.
        featLabels: full list of feature names, in the column order used
            by *testVec*.
        testVec: feature values of the sample to classify.

    Returns:
        The class label stored at the leaf that is reached.

    Raises:
        ValueError: if a feature name of the tree is missing from
            *featLabels*, or if the sample's value for a feature has no
            branch in the tree.
    """
    # The single key of a tree node is the name of the feature it tests.
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    # Column of that feature inside testVec.
    featIndex = featLabels.index(firstStr)
    featValue = testVec[featIndex]
    for key, subTree in secondDict.items():
        if featValue == key:
            # A dict is an inner node: keep descending. Anything else is a leaf.
            if isinstance(subTree, dict):
                return classify(subTree, featLabels, testVec)
            return subTree
    raise ValueError(
        "no branch for value %r of feature %r" % (featValue, firstStr))


def storeTree(inputTree, filename):
    """Serialize *inputTree* to *filename* with pickle."""
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)


def grabTree(filename):
    """Load and return a tree previously written by storeTree().

    NOTE: pickle can execute arbitrary code while loading; only use this on
    files from a trusted source.
    """
    with open(filename, 'rb') as fr:
        return pickle.load(fr)


myDat, lables = createDataSet()
updatlabels = lables[:]
print(lables)
myTree = createTree(myDat, lables)
print(myTree)
storeTree(myTree, "classfier.txt")
ff = grabTree("classfier.txt")
print("ff%s" % ff)
# Example usage:
# print(classify(myTree, updatlabels, [1, 0]))
# print(classify(myTree, updatlabels, [1, 1]))
# print(splitDataSet(myDat, 0, 1))
# print(splitDataSet(myDat, 0, 0))
# Expected tree for the demo data set:
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
# Decision tree plotting algorithm (决策树图像算法)
import matplotlib.pyplot as plt

# import decisions.createTree

# Box style used for decision (inner) nodes.
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
# Box style used for leaf nodes.
leafNode = dict(boxstyle="round4", fc="0.8")
# Style of the arrow drawn between a node and its parent.
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one boxed node with an arrow linking it to its parent.

    Args:
        nodeTxt: text shown inside the box.
        centerPt: (x, y) centre of the node, in axes-fraction coordinates.
        parentPt: (x, y) of the parent node, in axes-fraction coordinates.
        nodeType: box style dict (decisionNode or leafNode).

    Draws on ``createPlot.ax1``, which createPlot() must have set first.
    """
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)


def createPlot(inTree):
    """Render the decision tree *inTree* in figure 1 and show it.

    The axes and the layout state are stored as function attributes
    (``createPlot.ax1``, ``plotTree.totalW/totalD/xOff/yOff``) that the
    other plotting helpers read and update.
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # Total width (leaf count) and depth scale the layout into [0, 1].
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a leaf slot left of the axes so the first leaf, placed one
    # slot further right, lands in the middle of its slot.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


def getNumLeafs(myTree):
    """Return the number of leaves of the nested-dict tree *myTree*."""
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for child in secondDict.values():
        # A dict child is a subtree; anything else is a leaf.
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """Return the number of decision levels on the longest path of *myTree*."""
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for child in secondDict.values():
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def plotMidText(cntrPt, parentPt, txtString):
    """Write *txtString* halfway between *cntrPt* and *parentPt*."""
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw *myTree* below *parentPt*.

    Args:
        myTree: nested-dict tree (or subtree) to draw.
        parentPt: (x, y) of the parent node in axes-fraction coordinates.
        nodeTxt: label written on the edge coming from the parent.

    Reads and updates the layout state ``plotTree.xOff`` / ``plotTree.yOff``
    that createPlot() initialises.
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = list(myTree.keys())[0]
    # Centre this node horizontally over the leaf slots of its subtree.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Move one level down for the children ...
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))
        else:
            # Each leaf takes the next horizontal slot.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # ... and restore the level before returning to the caller.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


myTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}, 3: 'maybe'}}
print('the tree depth is %s' % getTreeDepth(myTree))
print('the number of leaf are %s' % getNumLeafs(myTree))
createPlot(myTree)
# createPlot()