决策树学习1)采用自顶向下的递归方法
2)基本思想是以信息熵为度量,向下构造一颗熵值下降最快的树,到叶子结点处熵值为0.
3)属于有监督学习
决策树算法历史1)Quinlan在1986年提出的ID3算法和1993年提出的C4.5算法
2)Breiman等人在1984年提出的CART算法。
决策树呈树形结构,在分类问题中表示基于特征对实例进行分类的过程。
决策树的分类:1)离散型决策树:目标变量为离散型(CLS,ID3,C45)
2)连续型决策树:目标变量是连续型(CART)
决策树的构造过程:
1)特征选择:从若干特征中选择一个特征作为当前节点分裂的标准。
方法:ID3(信息增益)
C4.5(信息增益比)
CART(Gini基尼指数)
2)决策树的生成
根据选择特征评估标准,从上到下递归地生成子节点,直到数据集不可分。目标是使某个特征划分后各个子集纯度更高,不确定性更小。
3)决策树的裁剪
决策树容易过拟合(over-fitting)通过剪枝来缩小结构规模、缓解过拟合。
剪枝方法有:预剪枝:在结点划分前进行预判断,如果划分后能够使子集纯度更纯则进行,反之不进行。
后剪枝:先生成一棵完整的树,然后自底向上对非叶结点考察是否替换子树为叶节点。
决策树的优缺点:
优点:可读性强,分类速度快;
缺点:容易出现过拟合,对未知的测试数据结果不一定好。可采用剪枝或者随机森林。
ID3算法
1)决策树中的每一个非叶子结点对应一个特征属性,树枝代表这个属性的值,叶节点代表最终分类属性值。
2)每一个非叶子结点与属性中具有最大信息量的特征属性相关联。
3)熵通常用于测量一个非叶子结点的信息量大小。
实现步骤:
1、创建数据集
2、createTree创建决策树
1)判断生成叶子结点/结点
2)选择最佳属性划分方式:选择最大信息增益(重点)
emmmmmmmmmm---------------------------------
上python37代码
创建trees.py
#!/usr/bin/python # -*- coding: UTF-8 -*- from math import log import operator # 划分数据集,axis:按第几个属性划分,value:要返回的子集对应的属性值 def splitDataSet(dataSet,axis,value): retDataSet=[] featVec=[] for featVec in dataSet: if featVec[axis]==value: #将featVec[axis]单独分出去,剩下的数据集搞到reducedFeatVec中 reducedFeatVec=featVec[:axis] reducedFeatVec.extend(featVec[axis+1:]) retDataSet.append(reducedFeatVec) return retDataSet # 计算信息熵 def calcShannonEnt(dataSet): numEntries=len(dataSet)# 样本数 labelCounts={} for featVec in dataSet:# 遍历每个样本 currentLabel=featVec[-1]# 当前样本的类别 if currentLabel not in labelCounts.keys():# 生成类别字典 labelCounts[currentLabel]=0#初始化新类别中0个样本 labelCounts[currentLabel]+=1#对当前样本类别进行计数 shannonEnt=0.0 for key in labelCounts:#计算信息熵 prob=float(labelCounts[key])/numEntries shannonEnt=shannonEnt-prob*log(prob,2) return shannonEnt # 选择最好的数据集划分方式 def chooseBestFeatureToSplit(dataSet): numFeatures=len(dataSet[0])-1 # 属性个数 baseEntropy=calcShannonEnt(dataSet) bestInfoGain=0.0 bestFeature=-1 for i in range(numFeatures):#对每个属性计算信息增益 featList=[example[i] for example in dataSet] uniqueVals=set(featList)#该属性的取值集合 newEntropy=0.0 for value in uniqueVals:#对每一种取值计算信息增益 subDataSet=splitDataSet(dataSet,i,value) prob=len(subDataSet)/float(len(dataSet)) newEntropy+=prob*calcShannonEnt(subDataSet) infoGain=baseEntropy-newEntropy if (infoGain>bestInfoGain): bestInfoGain=infoGain bestFeature=i return bestFeature # 递归构建决策树 # 通过排序返回出现次数最多的类别 def majorityCnt(classList): classCount={} for vote in classList: if vote not in classCount.keys():classCount[vote]=0 classCount[vote]+=1 sortedClassCount=sorted(classCount.items(), key=operator.itemgetter(1),reverse=True) return sortedClassCount[0][0] # 递归构建决策树 def createTree(dataSet,labels): classList=[example[-1] for example in dataSet]#类别向量 if classList.count(classList[0])==len(classList):#如果只有一个类别返回 return classList[0] if len(dataSet[0])==1:#如果所有特征都被遍历完了,返回出现次数最多的类别 return majorityCnt(classList) bestFeat=chooseBestFeatureToSplit(dataSet) bestFeatLabel=labels[bestFeat]#最优划分属性的标签 myTree={bestFeatLabel:{}} del (labels[bestFeat])#已经选择的特征不再参与分类 featValues=[example[bestFeat]for example in dataSet] uniqueValue=set(featValues)#该属性所有可能取值,节点的分支 for value in uniqueValue: subLabels=labels[:] myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels) return myTree 创建Plotter.py
# -*- coding: cp936 -*- import matplotlib.pyplot as plt # 设置决策节点和叶节点的边框形状、边距和透明度,以及箭头的形状 decisionNode = dict(boxstyle="square,pad=0.5", fc="0.9") leafNode = dict(boxstyle="round4, pad=0.5", fc="0.9") arrow_args = dict(arrowstyle="<-", connectionstyle="arc3", shrinkA=0, shrinkB=16) # 获得树的叶子结点数目 def getNumLeafs(myTree): numLeafs = 0 firstStr = list(myTree.keys())[0]# 获得当前第一个根节点 secondDict = myTree[firstStr] # 获取该根下的子树 for key in secondDict.keys(): # 获得所有子树的根节点进行遍历 if type(secondDict[key]).__name__ == 'dict': # 如果子节点是dict类型则不是子节点需要继续遍历 numLeafs += getNumLeafs(secondDict[key]) else: numLeafs += 1 return numLeafs # 获得树的深度 def getTreeDepth(myTree): maxDepth = 0 firstStr = list(myTree.keys())[0]# 获得当前第一个根节点 secondDict = myTree[firstStr] # 获取该根下的子树 for key in secondDict.keys(): # 获取所有子树节点,进行遍历 if type(secondDict[key]).__name__ == 'dict':# 如果子树类型为dict则不是叶子结点 thisDepth = 1 + getTreeDepth(secondDict[key]) else: thisDepth = 1 if thisDepth > maxDepth: maxDepth = thisDepth return maxDepth # 计算父节点到子节点的中点坐标,在该点上标注txt信息 def plotMidText(cntrPt, parentPt, txtString): xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0] yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1] createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30) def plotTree(myTree, parentPt, nodeTxt): numLeafs = getNumLeafs(myTree) depth = getTreeDepth(myTree) firstStr = list(myTree.keys())[0] cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff) plotMidText(cntrPt, parentPt, nodeTxt) plotNode(firstStr, cntrPt, parentPt, decisionNode) secondDict = myTree[firstStr] plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict': plotTree(secondDict[key], cntrPt, str(key)) else: plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD def createPlot(inTree): fig = plt.figure(1, facecolor='white') fig.clf() axprops = dict(xticks=[], yticks=[]) createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) plotTree.totalW = float(getNumLeafs(inTree)) plotTree.totalD = float(getTreeDepth(inTree)) plotTree.xOff = -0.5 / plotTree.totalW plotTree.yOff = 1.0 plotTree(inTree, (0.5, 1.0), '') plt.show() # 给createPlot子节点绘图添加注释 def plotNode(nodeTxt, centerPt, parentPt, nodeType): createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
测试数据集
#!/usr/bin/python # -*- coding: UTF-8 -*- import ID3 import json import Plotter fr = open(r'C:\Users\LMQ\untitled\activityData.txt') listWm = [inst.strip().split('\t') for inst in fr.readlines()] labels = ['天气', '温度', '湿度', '风速'] Trees = ID3.createTree(listWm, labels) print(json.dumps(Trees, ensure_ascii=False)) Plotter.createPlot(Trees)
PS:最终在Pycharm上生成的决策图上仍然没有注释