1. Introduction to decision trees
A decision tree is a tree-shaped model: starting from the root node, a sample is routed step by step down to a leaf node (this is the decision process), and every sample eventually ends up in some leaf. The algorithm can be used for both classification and regression.
A decision tree consists of nodes and directed edges. Nodes come in three kinds: the root node (the first attribute chosen for splitting), internal nodes (the non-leaf nodes below the root), and leaf nodes (which hold the final decision result).
This article focuses on the classification decision-tree model.
2. The basic flow of building a decision tree
Basic flow of the algorithm (a minimal recursive sketch is given after this list):
- Put all the data in the root node.
- Choose an optimal feature and use it to split the training data into subsets, so that each subset is classified as well as possible under the current conditions.
- Recurse until every data subset is essentially classified correctly, or no suitable feature remains.
- The recursion returns in three cases:
  - (1) all samples in the current node belong to the same class;
  - (2) the attribute set is empty, or every sample takes the same value on all remaining attributes, so no further split is possible;
  - (3) the sample set of the current node is empty.
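Written out as code, the flow above is a short recursive procedure. The sketch below is only an outline: `choose_best_feature` and `split_by_feature` are hypothetical placeholders, and the full entropy-based versions are implemented in Section 4.

```python
# Minimal sketch of the recursive flow; choose_best_feature and
# split_by_feature are hypothetical helpers, implemented for real in Section 4.
def build_tree(dataset, features):
    labels = [row[-1] for row in dataset]
    # Stop (1): every sample in this node has the same class
    if labels.count(labels[0]) == len(labels):
        return labels[0]
    # Stop (2): no attributes left -> fall back to the majority class
    if not features:
        return max(set(labels), key=labels.count)
    best = choose_best_feature(dataset)            # e.g. by information gain
    tree = {features[best]: {}}
    remaining = features[:best] + features[best + 1:]
    for value, subset in split_by_feature(dataset, best):
        # Stop (3): an empty subset would become a leaf labelled with the parent's
        # majority class; here only values that actually occur are iterated,
        # so that case does not arise
        tree[features[best]][value] = build_tree(subset, remaining)
    return tree
```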
3. Choosing the best feature
The key to decision-tree learning is how to choose the optimal splitting attribute. Generally speaking, as the splitting process proceeds, we want the samples contained in each branch node to belong to a single class as far as possible, i.e. the "purity" of the nodes should keep increasing.
Measuring the purity of a sample set: information entropy. Entropy measures the uncertainty of a random variable, i.e. the degree of disorder inside a system. Suppose the proportion of class-k samples in the current sample set D is p_k (k = 1, 2, ..., |Y|). Then the information entropy of D is defined as

Ent(D) = - Σ_{k=1}^{|Y|} p_k · log2(p_k)

The smaller Ent(D) is, the higher the purity of D.
A quick example: set A = [1,1,1,1,1,1,1,1,2,2] and set B = [1,2,3,4,5,6,7,8,9,1]. Clearly A has the lower entropy: it contains only two classes and is comparatively stable, while B contains many more classes, so its entropy is much larger. How, then, do we decide which feature a node should split on?
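As a sanity check, both entropies can be computed directly from the definition above. This is a small self-contained snippet, independent of the implementation in Section 4:

```python
from collections import Counter
from math import log

def entropy(values):
    """Ent(D) = -sum(p_k * log2(p_k)) over the classes present in values."""
    n = len(values)
    return -sum((c / n) * log(c / n, 2) for c in Counter(values).values())

A = [1, 1, 1, 1, 1, 1, 1, 1, 2, 2]
B = [1, 2, 3, 4, 5, 6, 7, 8, 9, 1]
print(entropy(A))  # ~0.722: two classes, fairly pure
print(entropy(B))  # ~3.122: nine classes, very impure
```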
4. Implementing a decision tree in Python
The dataset used here records the salary grades of graduates.
professional: major code; gender: 1 = male, 0 = female; age: age; socialSkill: social-skill score; professionalSkill: professional-skill score;
isJob: salary-grade label
Only the first ten training samples are used.
4.1 Preparing the dataset
```python
def createDataSet():
    data = pd.read_csv('movie.csv')
    print(np.array(data[0:10]))
    # Convert the pandas frame into a plain Python list
    dataSet = np.array(data[0:10]).tolist()
    print(type(dataSet))
    labels = ['professional', 'gender', 'age', 'socialSkill', 'professionalSkill']
    return dataSet, labels
```
4.2 Computing the information entropy
```python
# Compute the information entropy Ent(D) of the dataset
def calcShannonEnt(dataset):
    numexamples = len(dataset)
    labelCounts = {}
    # Count the occurrences of each class label with a dict
    for featVec in dataset:
        currentlabel = featVec[-1]
        if currentlabel not in labelCounts.keys():
            labelCounts[currentlabel] = 0
        labelCounts[currentlabel] += 1
    # Apply the entropy formula to get Ent(D)
    shannonEnt = 0
    for key in labelCounts:
        prop = float(labelCounts[key]) / numexamples
        shannonEnt -= prop * log(prop, 2)
    return shannonEnt
```
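For a quick illustration of the call, here is a tiny made-up dataset (not the salary data) in which the last column of each row is the class label:

```python
# Hypothetical toy data: each row is [feature, label]
toy = [[1, 'yes'], [1, 'yes'], [0, 'no']]
print(calcShannonEnt(toy))  # ~0.918, since p(yes) = 2/3 and p(no) = 1/3
```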
4.3 Splitting the dataset on a feature
```python
# Split the dataset according to the given rule
def splitDataSet(dataset, axis, val):
    retDataSet = []
    for featVec in dataset:
        # Keep the samples whose value on the given feature matches val
        if featVec[axis] == val:
            # reducedFeatVec holds the sample with the axis feature removed
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
```
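For example, on a made-up three-column dataset [gender, age, label] (again, not the real salary data), splitting on axis=0 with val=1 keeps only the rows whose gender is 1 and removes that column:

```python
toy = [[1, 25, 'high'], [1, 30, 'low'], [0, 25, 'low']]
print(splitDataSet(toy, 0, 1))  # [[25, 'high'], [30, 'low']]
```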
4.4 Computing the information gain and returning the feature with the largest gain
```python
# Choose the best feature, i.e. the one with the largest information gain
def chooseBestFeatureToSplit(dataset):
    # Number of features
    numFeatures = len(dataset[0]) - 1
    baseEntropy = calcShannonEnt(dataset)
    # Initialize the best information gain
    bestInfoGain = 0
    bestFeature = -1
    for i in range(numFeatures):
        # All values in the i-th feature column of dataset
        featList = [example[i] for example in dataset]
        # print(featList)
        # Convert the list to a set; a set holds no duplicates
        uniqueVals = set(featList)
        newEntropy = 0
        # Split the dataset on this feature and compute its information gain
        for val in uniqueVals:
            subDataSet = splitDataSet(dataset, i, val)
            # print(subDataSet)
            # print(len(subDataSet))
            prob = len(subDataSet) / float(len(dataset))
            # print(prob)
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        # Keep the largest information gain and its feature index
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
```
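The quantity compared in the loop is the information gain of attribute a, Gain(D, a) = Ent(D) - Σ_v |D^v|/|D| · Ent(D^v), where D^v is the subset of D that takes value v on a. A small made-up example of the call (gender separates the labels perfectly while age does not, so column 0 is returned):

```python
# Hypothetical toy data: [gender, age, label]
toy = [[1, 25, 'high'], [1, 30, 'high'], [0, 25, 'low'], [0, 30, 'low']]
# Splitting on gender gives gain 1.0, splitting on age gives gain 0.0
print(chooseBestFeatureToSplit(toy))  # 0
```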
4.5 Building the decision tree recursively
```python
# Build the decision tree
def createTree(dataset, labels, featLabels):
    # Class labels of the dataset
    classList = [example[-1] for example in dataset]
    print(classList)
    # If all remaining samples share the same label, return it directly
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # If no features are left, return the majority class
    if len(dataset[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataset)
    bestFeatLabel = labels[bestFeat]
    featLabels.append(bestFeatLabel)
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]
    featValue = [example[bestFeat] for example in dataset]
    uniqueVals = set(featValue)
    for value in uniqueVals:
        sublabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataset, bestFeat, value),
                                                  sublabels, featLabels)
    return myTree


# Return the element that occurs most often in classList
def majorityCnt(classList):
    classCount = {}
    # Count the occurrences of each element in classList
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    # Sort the dict in descending order of counts
    sortedclassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedclassCount[0][0]
```
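The returned tree is a nested dictionary: each internal node is a dict whose single key is the chosen attribute, and whose values map each attribute value to a subtree or, at a leaf, a class label. With the same made-up toy data as above:

```python
toy = [[1, 25, 'high'], [1, 30, 'high'], [0, 25, 'low'], [0, 30, 'low']]
featLabels = []
print(createTree(toy, ['gender', 'age'], featLabels))
# {'gender': {0: 'low', 1: 'high'}}
```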
4.6 Drawing the decision tree with matplotlib
```python
# Number of leaf nodes of the tree
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


# Depth of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


# Draw a single node
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    arrow_args = dict(arrowstyle="<-")
    font = FontProperties(fname=r"c:\windows\fonts\simsunb.ttf", size=14)
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args, fontproperties=font)


# Draw the attribute value on the edge between parent and child
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


# Draw the decision tree
def plotTree(myTree, parentPt, nodeTxt):
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")
    leafNode = dict(boxstyle="round4", fc="0.8")
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = next(iter(myTree))
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')                        # create the figure
    fig.clf()                                                     # clear the figure
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)   # hide the x and y axes
    plotTree.totalW = float(getNumLeafs(inTree))                  # number of leaf nodes
    plotTree.totalD = float(getTreeDepth(inTree))                 # number of tree levels
    plotTree.xOff = -0.5 / plotTree.totalW                        # x offset
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')                              # draw the decision tree
    plt.show()
```
Complete code:
```python
# -*- coding: UTF-8 -*-
import numpy as np
import pandas as pd
from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt
from math import log
import operator


def createDataSet():
    data = pd.read_csv('movie.csv')
    print(np.array(data[0:10]))
    # Convert the pandas frame into a plain Python list
    dataSet = np.array(data[0:10]).tolist()
    print(type(dataSet))
    labels = ['professional', 'gender', 'age', 'socialSkill', 'professionalSkill']
    return dataSet, labels


# Compute the information entropy Ent(D) of the dataset
def calcShannonEnt(dataset):
    numexamples = len(dataset)
    labelCounts = {}
    # Count the occurrences of each class label with a dict
    for featVec in dataset:
        currentlabel = featVec[-1]
        if currentlabel not in labelCounts.keys():
            labelCounts[currentlabel] = 0
        labelCounts[currentlabel] += 1
    # Apply the entropy formula to get Ent(D)
    shannonEnt = 0
    for key in labelCounts:
        prop = float(labelCounts[key]) / numexamples
        shannonEnt -= prop * log(prop, 2)
    return shannonEnt


# Choose the best feature, i.e. the one with the largest information gain
def chooseBestFeatureToSplit(dataset):
    # Number of features
    numFeatures = len(dataset[0]) - 1
    baseEntropy = calcShannonEnt(dataset)
    # Initialize the best information gain
    bestInfoGain = 0
    bestFeature = -1
    for i in range(numFeatures):
        # All values in the i-th feature column of dataset
        featList = [example[i] for example in dataset]
        # print(featList)
        # Convert the list to a set; a set holds no duplicates
        uniqueVals = set(featList)
        newEntropy = 0
        # Split the dataset on this feature and compute its information gain
        for val in uniqueVals:
            subDataSet = splitDataSet(dataset, i, val)
            # print(subDataSet)
            # print(len(subDataSet))
            prob = len(subDataSet) / float(len(dataset))
            # print(prob)
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        # Keep the largest information gain and its feature index
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


# Split the dataset according to the given rule
def splitDataSet(dataset, axis, val):
    retDataSet = []
    for featVec in dataset:
        # Keep the samples whose value on the given feature matches val
        if featVec[axis] == val:
            # reducedFeatVec holds the sample with the axis feature removed
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


# Build the decision tree
def createTree(dataset, labels, featLabels):
    # Class labels of the dataset
    classList = [example[-1] for example in dataset]
    print(classList)
    # If all remaining samples share the same label, return it directly
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # If no features are left, return the majority class
    if len(dataset[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataset)
    bestFeatLabel = labels[bestFeat]
    featLabels.append(bestFeatLabel)
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]
    featValue = [example[bestFeat] for example in dataset]
    uniqueVals = set(featValue)
    for value in uniqueVals:
        sublabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataset, bestFeat, value),
                                                  sublabels, featLabels)
    return myTree


# Return the element that occurs most often in classList
def majorityCnt(classList):
    classCount = {}
    # Count the occurrences of each element in classList
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    # Sort the dict in descending order of counts
    sortedclassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedclassCount[0][0]


# Number of leaf nodes of the tree
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


# Depth of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


# Draw a single node
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    arrow_args = dict(arrowstyle="<-")
    font = FontProperties(fname=r"c:\windows\fonts\simsunb.ttf", size=14)
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args, fontproperties=font)


# Draw the attribute value on the edge between parent and child
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


# Draw the decision tree
def plotTree(myTree, parentPt, nodeTxt):
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")
    leafNode = dict(boxstyle="round4", fc="0.8")
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = next(iter(myTree))
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')                        # create the figure
    fig.clf()                                                     # clear the figure
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)   # hide the x and y axes
    plotTree.totalW = float(getNumLeafs(inTree))                  # number of leaf nodes
    plotTree.totalD = float(getTreeDepth(inTree))                 # number of tree levels
    plotTree.xOff = -0.5 / plotTree.totalW                        # x offset
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')                              # draw the decision tree
    plt.show()


if __name__ == '__main__':
    dataset, labels = createDataSet()
    featLabels = []
    myTree = createTree(dataset, labels, featLabels)
    createPlot(myTree)
```
Running result:
Why this happens: information gain has the drawback of favoring attributes with many distinct values. Consider the watermelon dataset as an example: if we also treat the "ID" column as a candidate attribute, its information gain works out to 0.998, far larger than that of any other candidate, because "ID" has 17 distinct values and therefore produces 17 branches, each branch node containing exactly one sample, so the purity of those branch nodes is maximal. A tree like this, however, has no generalization ability at all. In the dataset I chose, 9 of the first 10 values of socialSkill are distinct, which likewise produces a decision tree with poor generalization ability.
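The bias is easy to reproduce with the functions above on a made-up toy set: add an ID-like column in which every value is distinct. Splitting on it yields one-sample subsets with zero entropy, so its information gain equals Ent(D) and it is always chosen:

```python
# Hypothetical toy data: [id, gender, label]; the id column is unique per sample
toy = [[1, 1, 'high'], [2, 1, 'low'], [3, 0, 'high'], [4, 0, 'low']]
base = calcShannonEnt(toy)                 # 1.0
cond = sum(len(splitDataSet(toy, 0, v)) / len(toy) * calcShannonEnt(splitDataSet(toy, 0, v))
           for v in set(row[0] for row in toy))
print(base - cond)                         # 1.0: the maximum possible gain
print(chooseBestFeatureToSplit(toy))       # 0 -> the id column is selected
```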
The decision tree obtained after modifying the dataset: