目录
一、决策树概述
1.1 什么是决策树
1.2 决策树分类实例
二、决策树的构造
2.1特征选择
2.1.1信息增益(ID3)
2.1.2 信息增益率
2.1.3 基尼指数
2.2 生成决策树
2.2.1 划分数据集
2.2.2 递归构建决策树
2.2.3 决策树的可视化
决策树(decision tree) 是一类常见的用于基本的分类与回归的机器学习方法。顾名思义,决策树是基于树结构来进行决策的。在分类问题中,表示基于特征对实例进行分类的过程。它可以认为是if-then规则的集合,也可以认为是定义在特征空间与类空间上的条件概率分布。
决策树是一种描述对实例进行分类的树形结构。一般的,一棵决策树包含一个根结点、若干个内部结点和若干个叶结点;叶结点对应于决策结果,其他每个结点则对应于一个属性测试;每个结点包含的样本集合根据属性测试的结果被划分到子结点中;根结点包含样本全集。从根结点到每个叶结点的路径对应了一个判定测试序列。
决策树是基于树结构来进行决策的,这恰是人类在面临决策问题时一种很自然的处理机制。
例如,我们要对“今天是否出去玩”这样的问题进行决策时,通常会进行一系列的判断或"子决策":我们先看今天的天气是晴天、多云还是下雨,如果是晴天再判断温度高还是正常,如果高就不出去,正常就出去;如果是多云就出去;如果是下雨再判断风强不强,风强就不出去,风弱就出去。基于上述思路,可构造一个如图所示决策树:
由图可以看出决策树由结点和有向边组成。结点有两种类型:内部结点和叶结点,内部结点表示一个特征或属性,叶结点表示一个类。从根结点到叶结点的每一条路径构成一条规则:路径上的内部结点对应着规则的条件,叶结点的类对应着规则的结论。
决策树学习的目的是为了产生一棵泛化能力强, 即处理未见示例能力强的决策树。
决策树是一种典型的分类方法:
- 首先对数据进行处理,利用归纳算法生成可读的规则和决策树
- 然后使用决策对新数据进行分析
本质上决策树是通过一系列规则对数据进行分类的过程。
决策树学习的关键在于如何选择最优划分属性。一般而言,随着划分过程不断进行,我们希望决策树的分支结点所包含的样本尽可能属于同一类别,即结点的“纯度”越来越高
经典的属性划分方法:
信息熵
信息熵是度量样本集合纯度最常用的一种指标,假定当前样本集合D中第k类样本所占的比例为 pk (K=1, 2, .. |y|) ,则D的信息熵定义为:
信息增益
离散属性a有V个可能的取值{ , , ..., },用a来进行划分,则会产生V个分支结点,其中第v个分支结点包含了D中所有在属性a上取值为 的样本,记为 。则可计算出用属性a对样本集D进行划分所获得的信息增益:
ID3(Iterative Dichotomiser,迭代二分器)决策树学习算法[Quinlan, 1986] 以信息增益为准则来选择划分属性
信息增益的算法
- 设训练数据集为D
- |D|表示其样本容量,即样本个数
- 设有K个类,k=1,2...K,||为属于类的样本个数
- 特征A有n个不同的取值{},根据特征A的取值将D划分为n个子集
- 为的样本个数
- 记子集中属于类的样本集合为,为的样本 个数
输入:训练数据集和特征A
输出:特征A对训练数据集D的信息增益G(D,A)
- 信息熵为:
- 计算信息增益:
以信息增益作为划分训练数据集的特征,存在偏向于选择取值较多的特征。
如上图,若把“编号”作为一个候选划分属性:
编号的信息增益远大于其他属性。显然,这样的决策树不具有泛化能力,无法对新样本进行有效预测。
使用信息增益率可以对这一问题进行校正。
信息增益率:特征a对训练数据集D的信息增益比定义为信息增益与训练数据集D关于特征A的值的熵之比:
其中
称为属性a的“固有值”(intrinsic value) [Quinlan, 1993]. 属性a的可能取值数目越多(即V越大),则IV(a) 的值通常会越大。
信息增益率对属性值较少的属性比较偏好。
C4.5算法并不是直接选择增益率最大的作为候选划分属性,而是使用一个启发式:先从候选划分属性中找出信息增益高于平均水平的属性(保证属性值多的属性被选择),再从中选择增益率最高的(在属性值多的属性中挑属性值少的属性),这样是保证了相对准确。
定义:分类问题中,假设D有K个类,样本点属于第k类的概率为, 则概率分布的基尼值定义为:
Gini(D)越小,数据集D的纯度越高
给定数据集D,属性a的基尼指数定义为:
在候选属性集合A中,选择那个使得划分后基尼指数最小的属性作为最有划分属性。
上面一共介绍了三种方法进行特征选择,以信息增益的示例来划分属性:
按照给定特征划分数据集:
'''
dataSet - 待划分的数据集
axis - 划分数据集的特征
value - 需要返回的特征的值
'''
def splitDataSet(dataSet, axis, value):
retDataSet = [] #创建返回的数据集列表
for featVec in dataSet: #遍历数据集
if featVec[axis] == value:
reducedFeatVec = featVec[:axis] #去掉axis特征
reducedFeatVec.extend(featVec[axis+1:]) #将符合条件的添加到返回的数据集
retDataSet.append(reducedFeatVec)
return retDataSet #返回划分后的数据集
计算给定数据集的信息熵:
def calcShannonEnt(dataSet):
numEntires = len(dataSet) #返回数据集的行数
labelCounts = {} #保存每个标签(Label)出现次数的字典
for featVec in dataSet: #对每组特征向量进行统计
currentLabel = featVec[-1] #提取标签(Label)信息
if currentLabel not in labelCounts.keys(): #如果标签(Label)没有放入统计次数的字典,添加进去
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1 #Label计数
shannonEnt = 0.0 #经验熵(香农熵)
for key in labelCounts: #计算香农熵
prob = float(labelCounts[key]) / numEntires #选择该标签(Label)的概率
shannonEnt -= prob * log(prob, 2) #利用公式计算
return shannonEnt #返回经验熵(香农熵)
选择最优特征进行划分:
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 #特征数量
baseEntropy = calcShannonEnt(dataSet) #计算数据集的香农熵
bestInfoGain = 0.0 #信息增益
bestFeature = -1 #最优特征的索引值
for i in range(numFeatures): #遍历所有特征
#获取dataSet的第i个所有特征
featList = [example[i] for example in dataSet]
uniqueVals = set(featList) #创建set集合{},元素不可重复
newEntropy = 0.0 #经验条件熵
for value in uniqueVals: #计算信息增益
subDataSet = splitDataSet(dataSet, i, value) #subDataSet划分后的子集
prob = len(subDataSet) / float(len(dataSet)) #计算子集的概率
newEntropy += prob * calcShannonEnt(subDataSet) #根据公式计算经验条件熵
infoGain = baseEntropy - newEntropy #信息增益
# print("第%d个特征的增益为%.3f" % (i, infoGain)) #打印每个特征的信息增益
if (infoGain > bestInfoGain): #计算信息增益
bestInfoGain = infoGain #更新信息增益,找到最大的信息增益
bestFeature = i #记录信息增益最大的特征的索引值
return bestFeature #返回信息增益最大的特征的索引值
显然,决策树的生成是一个递归过程。在决策树基本算法中,有三种情形会
导致递归返回:
代码实现如下:
#统计出现次数最多的元素(类标签)
def majorityCnt(classList):
classCount = {}
for vote in classList: #统计classList中每个元素出现的次数
if vote not in classCount.keys():classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True) #根据字典的值降序排序
return sortedClassCount[0][0] #返回classList中出现次数最多的元素
创建决策树:
'''
dataSet - 训练数据集
labels - 分类属性标签
featLabels - 存储选择的最优特征标签
'''
def createTree(dataSet, labels, featLabels):
classList = [example[-1] for example in dataSet] #取分类标签(是否找到工作:yes or no)
if classList.count(classList[0]) == len(classList): #如果类别完全相同则停止继续划分
return classList[0]
if len(dataSet[0]) == 1: #遍历完所有特征时返回出现次数最多的类标签
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet) #选择最优特征
bestFeatLabel = labels[bestFeat] #最优特征的标签
featLabels.append(bestFeatLabel)
myTree = {bestFeatLabel:{}} #根据最优特征的标签生成树
del(labels[bestFeat]) #删除已经使用特征标签
featValues = [example[bestFeat] for example in dataSet] #得到训练集中所有最优特征的属性值
uniqueVals = set(featValues) #去掉重复的属性值
for value in uniqueVals: #遍历特征,创建决策树。
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels, featLabels)
return myTree
创建测试数据集(构建平均绩点,竞赛等级和实习经历判断是否能找到工作:平均绩点手动离散化为三个等级;竞赛等级:0表示未获奖,1表示获得校级奖项,2表示获得省级及以上奖项;实习经历:1表示有实习经历,0表示没有;结果“yes”表示找到工作,“no”表示未找到工作)
数据集如图:
代码如下:
if __name__ == '__main__':
data = pd.read_csv("D:/syy/MachineLearning/data/data_td.csv")
dataSet = data.values.tolist()
labels = ['平均绩点', '竞赛等级', '实习经历']
featLabels = []
myTree = createTree(dataSet, labels, featLabels)
print(myTree)
运行结果如下:
import matplotlib
import matplotlib.pyplot as plt
# 定义文本框和箭头格式
decisionNode = dict(boxstyle="square", fc="0.8") #boxstyle文本框样式、fc=”0.8” 是颜色深度
leafNode = dict(boxstyle="round4", fc="0.8") #叶子节点
arrow_args = dict(arrowstyle="<-") #定义箭头
# 绘制带箭头的注解
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
#createPlot.ax1是表示: ax1是函数createPlot的一个属性
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',xytext=centerPt,
textcoords='axes fraction',va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
# 获取叶节点的数目和树的层数
def getNumLeafs(myTree):
numLeafs = 0 # 初始化
firstStr = list(myTree.keys())[0] # 获得第一个key值(根节点)
secondDict = myTree[firstStr] # 获得value值
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict': # 测试节点的数据类型是否为字典
numLeafs += getNumLeafs(secondDict[key]) # 递归调用
else:
numLeafs += 1
return numLeafs
# 获取树的深度
def getTreeDepth(myTree):
maxDepth = 0 # 初始化
firstStr = list(myTree.keys())[0] # 获得第一个key值(根节点)
secondDict = myTree[firstStr] # 获得value值
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict': # 测试节点的数据类型是否为字典
thisDepth = 1 + getTreeDepth(secondDict[key]) # 递归调用
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
# 在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
# 画树
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree) # 获取树高
depth = getTreeDepth(myTree) # 获取树深度
firstStr = list(myTree.keys())[0] # 这个节点的文本标签
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff) #plotTree.totalW, plotTree.yOff全局变量,追踪已经绘制的节点,以及放置下一个节点的恰当位置
plotMidText(cntrPt, parentPt, nodeTxt) #标记子节点属性
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD #减少y偏移
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
# 绘制决策树
def createPlot(inTree):
fig = plt.figure(1, facecolor='white') # 创建一个新图形
fig.clf() # 清空绘图区
font = {'family': 'MicroSoft YaHei'}
matplotlib.rc("font", **font)
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW;
plotTree.yOff = 1.0;
plotTree(inTree, (0.5, 1.0), '')
plt.show()
测试结果
createPlot(myTree)
运行结果如下:
结果分析:从决策树可以看出平均绩点对能否找到工作影响最大。在平均绩点处于中等时,再根据竞赛等级判断是否能找到工作;最后根据是否有实习经历判断。
完整代码链接:链接: https://pan.baidu.com/s/1mut3VgS0aPp5kyy8x5z9ZQ 提取码: diqb