http://cn.akinator.com/ “神灯猜名人”这个游戏很多人都玩过吧,问很多问题,然后逐步猜测你想的名人是谁。决策树的工作原理与这个类似,输入一系列数据,然后给出游戏答案。决策树也是最经常使用的数据挖掘算法。书上给了一个流程图决策树,很简单易懂。
这里,椭圆形就是判断模块,方块就是终止模块。kNN 方法也可以完成分类任务,但是缺点是无法给出数据的内在含义。决策树的主要优势就在于数据形式容易理解。
==============================================================================
决策树
优点:计算复杂度不高,输出结果容易理解,对中间值缺失不敏感,可以处理不相关特征数据。
缺点:可能会产生过度匹配问题。
适用数据类型:数值型和标称型。
伪代码:
creatBranch():
if so return 类标签:
else:
寻找划分数据集的最好特征
划分数据集
创建分支节点
for 每个划分的子集
调用函数 creatBranch() 并增加返回结果到分支节点中
return 分支节点
==============================================================================
决策树的一般流程:
一些决策树算法采用二分法,我们不用这种方法。我们可能会遇到更多的选项,比如四个,然后创立四个不同分支。本书将使用 ID3 算法划分数据集。
==============================================================================
信息增益:
划分数据集的大原则是:将无序的数据变得更加有序。
在划分数据集之前之后信息发生的变化称为信息增益。
集合信息的度量方式称为香农熵或者简称为熵。熵定义为信息的期望值。公式略过不表。
给一段代码,计算给定数据集的熵:
# -*- coding:utf-8 -*-
from math import log
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys(): # 为所有可能分类创建字典
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob *log(prob, 2) # 以 2 为底求对数
return shannonEnt
然后自己利用 createDataSet() 函数来得到35页表 3-1 的鱼类鉴定数据集。
def createDataSet():
dataSet = [[1,1,'yes'],
[1,1,'yes'],
[1,0,'no'],
[0,1,'no'],
[0,1,'no']]
labels = ['no surfacing','flippers']
return dataSet, labels
# -*- coding:utf-8 -*-
# run_trees.py
import trees
myDat,labels = trees.createDataSet()
print myDat
print trees.calcShannonEnt(myDat)
熵越多高,说明混合数据越多。这里添加一个 “maybe” 分类,表示可能为鱼类。
测试:
# -*- coding:utf-8 -*-
import trees
myDat,labels = trees.createDataSet()
print myDat
print trees.calcShannonEnt(myDat)
print '*********************************'
myDat[0][2] = 'maybe' # 0 指的是dataSet第一个[],-1 指[]里面倒数第一个元素
print myDat
print trees.calcShannonEnt(myDat)
====================================================================================================
3.1.2 划分数据集
分类算法除了需要测量信息熵,还要划分数据集。
我们对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集是最好的方式。
按照给定特征划分数据集,代码接着 trees.py 写:
def splitDataSet(dataSet, axis, value):
retDataSet= [] # 创建新的 list 对象
for featVec in dataSet:
if featVec[axis] == value:
# 抽取
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
>>>a = [1,2,3]
>>>b = [4,5,6]
>>>a.append(b)
>>>a
[1,2,3,[4,5,6]]
>>>a = [1,2,3]
>>>a.extend(b)
>>>a
[1,2,3,4,5,6]
=============================================================================================
现在可以在前面的简单样本数据上测试函数 splitDataSet()
在 run_trees.py 里面加些代码:
# -*- coding:utf-8 -*-
# run_trees.py
import trees
myDat,labels = trees.createDataSet()
print '>>> myDat'
print myDat
print '>>> trees.calcShannonEnt(myDat)'
print trees.calcShannonEnt(myDat)
print '*********************************'
myDat[0][2] = 'maybe' # 0 指的是dataSet第一个[],-1 指[]里面倒数第一个元素
print '>>> myDat'
print myDat
print '>>> trees.calcShannonEnt(myDat)'
print trees.calcShannonEnt(myDat)
print '*********************************'
reload(trees)
myDat,labels = trees.createDataSet()
print '>>> myDat'
print myDat
print '>>> trees.splitDataSet(myDat,0,1)'
print trees.splitDataSet(myDat,0,1)
print '>>> trees.splitDataSet(myDat,0,0)'
print trees.splitDataSet(myDat,0,0)
结果如下:
=======================================================================
接下来我们要遍历整个数据集,循环计算香农熵和 splitDataSet() 函数,找到最好的特征划分方式。熵计算将会告诉我们如何划分数据集是最好的组织方式。
加代码:
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0; bestFeature = -1
for i in range(numFeatures): # 遍历数据集中的所有特征
# 创建唯一的分类标签列表
featList = [example[i] for example in dataSet]
uniqueVals = set(featList) # set 是一个集合
newEntropy = 0.0
for value in uniqueVals: # 遍历当前特征中的所有唯一属性值
# 计算每种划分方式的信息熵
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet) # 对所有唯一特征值得到的熵求和
infoGain = baseEntropy - newEntropy
if (infoGain > bestInfoGain):
# 计算最好的收益
bestInfoGain = infoGain
bestFeature = i
return bestFeature
在在 run_trees.py 里面加些代码:
print '*********************************'
reload(trees)
print '>>> myDat, labels = trees.createDataSet()'
myDat, labels = trees.createDataSet()
print '>>> trees.chooseBestFeatureToSplit(myDat)'
print trees.chooseBestFeatureToSplit(myDat)
print '>>> myDat'
print myDat
代码的意义在于,告诉我们第0个特征(不浮出水面是否可以生存)是最好的用于划分数据集的特征。
如果不相信这个结果,可以修改 calcShannonEnt(dataSet) 函数来测试不同特征分组的输出结果。
===============================================================================
3.1.3 递归构建决策树
从数据集构造决策树算法所需要的子功能模块,原理如下:得到原始数据集,然后基于最好的属性值划分数据集,由于特征值可能多于两个,因此可能存在大于两个分支的数据集划分。第一次划分后,数据将被乡下传递到树分支的下一个节点,在这个节点上,我们可以再次划分数据。我们可以采用递归的原则处理数据集。
在添加代码前,在 trees.py 顶部加上一行代码:
import operator
def majorityCnt(classList):
classCount = {} # 创建键值为 classList 中唯一值的数据字典
for vote in classList:
if vote not in classCount.keys():classCount[vote] = 0
classCount[vote] += 1 # 储存了 classList 中每个类标签出现的频率
sortedClassCount = sorted(classCount.iteritems(),\
key = operator.itemgetter(1), reverse = True) # 操作兼职排序字典
return sortedClassCount[0][0] # 返回出现次数最多的分类名称
# 创建树的函数代码
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList): # 类别完全相同则停止继续划分
return classList[0]
if len(dataSet[0]) == 1: # 遍历完所有特征时返回出现次数最多的
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet] # 得到列表包含的所有属性值
uniqueVals = set(featValues)
for value in uniqueVals:
# 为了保证每次调用函数 createTree() 时不改变原始列表类型,使用新变量 subLabels 代替原始列表
subLabels = labels[:] # 这行代码复制了类标签,并将其存储在新列表变量 subLabels 中
myTree[bestFeatLabel][value] = createTree(splitDataSet\
(dataSet,bestFeat,value),subLabels)
return myTree
下一步开始创建树,使用 字典 类型来保存树的信息,当然也可以声明特殊的数据类型储存树,但是这里没有必要。
当前数据集选取的最好特征存储在变量 bestFeat 中,得到列表包含的所有属性值。
现在运行代码,在 run_trees.py 里面添加:
print '*********************************'
reload(trees)
print '>>> myDat, labels = trees.createDataSet()'
myDat, labels = trees.createDataSet()
print '>>> myTree = trees. createTree(myDat, labels)'
myTree = trees. createTree(myDat, labels)
print '>>> myTree'
print myTree
结果:
变量 myTree 包含了很多代表树结构信息的嵌套字典。