返回目录
上一篇:k-近邻算法
决策树的类型有很多,有CART、ID3和C4.5等,其中CART是基于基尼不纯度(Gini)的,这里不做详解,而ID3和C4.5都是基于信息熵的,它们两个得到的结果都是一样的,本次定义主要针对ID3算法。下面我们介绍信息熵的定义。
设D为用类别对训练集进行的划分,则D的熵(entropy)表示为:
其中m表示训练集中标签种类的个数;pi表示第i个类别在整个训练集中出现的概率,可以用属于此类别元素的数量除以训练集合元素总数量作为估计;-log2(p(i))表示为事件i的不确定程度,称为i的自信息量。熵的实际意义表示是D中训练集的标签所需要的平均信息量。
现在我们假设将训练集D按特征A进行划分,则A对D划分的期望信息为:
其中v表示特征A所取值的个数,|Dj|表示当特征A为j时的训练集元素的个数,|D|表示训练集所有元素的总数。
则用特征A划分训练集D后所得的信息增益(gain)为:
从信息论知识中我们知道,期望信息越小,信息增益越大,从而纯度越高。所以ID3算法的核心思想就是以信息增益度量特征选择,选择分裂后信息增益最大的特征进行分裂。下面我们继续用SNS社区中不真实账号检测的例子说明如何使用ID3算法构造决策树。为了简单起见,我们假设训练集合包含10个元素。
其中s、m、l分别表示小、中、大。
先计算总的信息熵。总共有10个元素,标签类别有两种,一种是“yes”(7条记录),表示账号为真实,另一种是“no”(3条记录),表示账号不真实。则总的熵为:
设L、F、H和R表示日志密度、好友密度、是否使用真实头像和账号是否真实,下面分别计算各特征的信息增益。
以计算日志密度为例,对于特征日志密度,其取值有s(3条记录)、m(4条记录)和l(3条记录)三种;在日志密度为s的条件下,账号为真实的记录有1条,账号不真实的记录有2条;在日志密度为m的条件下,账号为真实的记录有3条,账号不真实的记录有1条;在日志密度为l的条件下,账号为真实的记录有3条,账号不真实的记录有0条。
用特征日志密度划分训练集D后所得的信息增益(gain)为:
同理,可计算出gain(F)=0.553,gain(H)=0.033。因为gain(F)最大,所以第一次分裂选择以F(好友密度)为分裂特征,分裂后的结果如下:
绿色结点表示判断条件,红色节点表示决策结果。
在上图的基础上,再递归使用这个方法计算子节点的分裂特征,最终就可以得到整个决策树。
优点:输出结果易于理解,对中间值的确实不敏感
缺点:容易产生过拟合
首先需要实现计算一个给定数据集的熵。
from math import log
# 计算给定数据集的熵
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet: #the the number of unique elements and their occurance
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob * log(prob,2) #log base 2
return shannonEnt
测试该方法:
# 创建数据集
def createDataSet():
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing','flippers']
#change to discrete values
return dataSet, labels
myDat,labels = createDataSet()
print 'myDat:',myDat
print 'entropy_myDat:',calcShannonEnt(myDat)
运行结果:
myDat: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
entropy_myDat: 0.970950594455
我们计算数据集的熵是为了让分裂数据集。第一步需要找到数据集的分裂特征,第二步需要根据找到的分裂特征对数据集进行分裂。这儿先讨论第二个问题,也就是说假定已经找到了分裂特征,如何根据它来分裂数据集呢?
# 根据指定的特征来分裂数据集
# dataSet:数据集(MxN),axis:特征的索引,即第几个特征:,value:所选特征的取值
# 返回一个数据集,该数据集以axis索引表示的特征为分裂特征,并且该分裂特征的值为value时得到的。
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis] #chop out axis used for splitting
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
测试该方法,使用先前的数据集:
print '以第0个特征为分裂特征进行分裂数据集,'
print '分裂特征值为1的子集合:',splitDataSet(myDat,0,1)
print '分裂特征值为0的子集合:',splitDataSet(myDat,0,0)
运行结果:
以第0个特征为分裂特征进行分裂数据集,
分裂特征值为1的子集合: [[1, 'yes'], [1, 'yes'], [0, 'no']]
分裂特征值为0的子集合: [[1, 'no'], [1, 'no']]
下来要做的是找到分裂特征:
# 返回分裂特征的索引
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1# 数据集中元素的最后一列为类别标签,所以需减1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0;bestFeature = -1# 初始化
for i in range(numFeatures):
featList = [element[i] for element in dataSet]# 得到数据集中第i个特征的所有取值
uniqueVals = set(featList)# 对featList去重,得到第i个特征的特征值集合
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet,i,value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if(infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature
这里需要注意的是,传入的数据集需要满足以下条件:
1)数据必须是由列表元素组成的列表,且所有的列表元素都必须具有相同的数据长度
2)数据的最后一列或者每个实例的最后一个元素是当前实例的类别标签。
测试该方法:
print '分裂特征的索引为:',chooseBestFeatureToSplit(myDat)
运行结果:
print '分裂特征的索引为:',chooseBestFeatureToSplit(myDat)
分裂特征的索引为: 0
现在已经能够对数据集按照分裂特征来分裂,接下来要完成的是通过不断分裂数据集来构成决策树。可以使用递归来完成,递归结束的条件是:遍历完所有子数据集的特征,或者每个分支下的所有实例都具有相同的分类,如果数据集已经处理完了所有特征,但是类标签依然不是唯一的,这时我们通常采用多数表决的方法决定该叶子结点的分类。
我们先实现下多数表决的代码:
import operator
#多数表决
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
接下来实现创建决策树代码:
# 创建决策树
def createTree(dataSet,labels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0]# 类别完全相同则停止继续划分
if len(dataSet[0]) == 1: # 遍历完所有特征时返回出现次数最多的
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:] #copy all of labels, so trees don't mess up existing labels
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
return myTree
测试下上面的效果:
myTree= createTree(myDat,labels)
print myTree
运行结果:
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
前面我们已经依靠训练数据构造了决策树,现在需要用它来进行实际数据的分类。在执行分类时,需要决策树以及构造决策树的标签向量。程序比较测试数据与决策树上的数值,递归执行该过程直到进入叶子结点;最后将测试数据定义为叶子结点所属的类型。
#使用决策树执行分类
def classify(inputTree, featLabels, testVec):
firstStr = inputTree.keys()[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr) #index方法查找当前列表中第一个匹配firstStr变量的元素的索引
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else: classLabel = secondDict[key]
return classLabel
测试结果:
myDat,labels = createDataSet()
print 'myTree:',myTree
print 'labels:',labels
print classify(myTree,labels,[1,0])
print classify(myTree,labels,[1,1])
classify也可以这么实现:
def classify(inputTree,featLabels,testVec):
firstStr = inputTree.keys()[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
key = testVec[featIndex]
valueOfFeat = secondDict[key]
if isinstance(valueOfFeat, dict):
classLabel = classify(valueOfFeat, featLabels, testVec)
else: classLabel = valueOfFeat
return classLabel
由于构造决策树是很耗时的,所以可以考虑将创建好的决策树存储到硬盘上。
#决策树的存储
def storeTree(inputTree,filename):
import pickle
fw = open(filename,'w')
pickle.dump(inputTree,fw)
fw.close()
def grabTree(filename):
import pickle
fr = open(filename)
return pickle.load(fr)
测试:
storeTree(myTree,'mytree.txt')
grabTree('mytree.txt')
输出:
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
下一篇:朴素贝叶斯算法
想学人工智能(Python、数据分析、机器学习、深度学习、推荐系统、强化学习),来公众号AI派看看吧!!