I've just started learning machine learning. After getting comfortable with Python syntax, I read the decision tree chapter in Li Hang's Statistical Learning Methods (统计学习方法) and derived the theory by hand, but when I tried to implement it in code I had no idea where to start.
So I followed along with the code in Machine Learning in Action. That book's code has very few comments and is not easy to understand, so here I've added detailed comments based on my own understanding. Let's learn together!
The ID3 algorithm selects the optimal splitting feature by the maximum-information-gain criterion.
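For reference, the quantities the code below computes, as defined in Li Hang's book: the entropy of dataset D is H(D) = -Σ_k (|C_k|/|D|) * log2(|C_k|/|D|), where the C_k are the classes; the conditional entropy given feature A is H(D|A) = Σ_i (|D_i|/|D|) * H(D_i), where the D_i are the subsets of D split by the values of A; and the information gain is g(D, A) = H(D) - H(D|A). ID3 splits on the feature with the largest g(D, A).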
The code is below:
from math import log
import operator

def calcShannonEnt(dataSet):
    """Calculate the Shannon entropy of the given dataset."""
    numEntries = len(dataSet)        # total number of instances in the dataset
    labelCounts = {}                 # counts of instances per class
    for featVec in dataSet:
        currentLabel = featVec[-1]   # the class label is the last element of each instance
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries  # probability of each class
        shannonEnt -= prob * log(prob, 2)            # entropy: H(D) = -sum(p * log2(p))
    return shannonEnt
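As a quick sanity check (my own addition, using the small fish dataset that createDataSet() at the end of this post returns): two of the five instances are 'yes' and three are 'no', so the entropy should be -(2/5)log2(2/5) - (3/5)log2(3/5) ≈ 0.971.

dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print(calcShannonEnt(dataSet))  # -> 0.9709505944546686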
def splitDataSet(dataSet, axis, value):
    """Split the dataset on a given feature.
    axis: index of the feature; value: the value of that feature to match."""
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:                    # keep instances whose axis-th feature equals value
            reducedFeatVec = featVec[:axis]           # cut out the axis-th feature...
            reducedFeatVec.extend(featVec[axis+1:])   # ...so the sub-dataset no longer contains it
            retDataSet.append(reducedFeatVec)
    return retDataSet
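A small check, reusing the dataSet from the entropy check above. Splitting on feature 0 with value 1 keeps the three instances whose first feature is 1 and strips that feature out:

print(splitDataSet(dataSet, 0, 1))  # -> [[1, 'yes'], [1, 'yes'], [0, 'no']]
print(splitDataSet(dataSet, 0, 0))  # -> [[1, 'no'], [1, 'no']]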
def chooseBestFeatureToSplit(dataSet):
    """Choose the best feature to split the dataset on."""
    numFeatures = len(dataSet[0]) - 1      # number of features per instance (last column is the class label)
    baseEntropy = calcShannonEnt(dataSet)  # entropy of the whole dataset, H(D)
    bestInfoGain = 0.0                     # best information gain found so far
    bestFeature = -1                       # index of the best splitting feature
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]  # dataSet is a 2-D list; collect the i-th feature of every instance
        uniqueVals = set(featList)                      # set() removes duplicate values of feature i
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)     # sub-dataset for this value of feature i
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)  # conditional entropy H(D|A)
        infoGain = baseEntropy - newEntropy  # information gain g(D, A) = H(D) - H(D|A)
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i                  # feature with the largest information gain
    return bestFeature
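On the sample dataset, feature 0 wins: splitting on feature 0 gives conditional entropy 3/5 * 0.918 + 2/5 * 0 ≈ 0.551 (gain ≈ 0.420), while feature 1 gives 4/5 * 1.0 + 1/5 * 0 = 0.8 (gain ≈ 0.171).

print(chooseBestFeatureToSplit(dataSet))  # -> 0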
def majorityCnt(classList):
    """Return the class label that occurs most often in classList."""
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)  # in Python 3, iteritems() became items()
    return sortedClassCount[0][0]  # the most frequent class label
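A quick check of majorityCnt:

print(majorityCnt(['yes', 'no', 'no']))  # -> 'no' (two votes vs. one)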
def createTree(dataSet, labels):
    """Build the decision tree recursively."""
    classList = [example[-1] for example in dataSet]     # class labels of all instances
    if classList.count(classList[0]) == len(classList):  # stop if every instance has the same class
        return classList[0]
    if len(dataSet[0]) == 1:            # all features used up but the classes are still mixed
        return majorityCnt(classList)   # fall back to the most frequent class label
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]    # use the feature's name (not its index) as the key, so the printed tree is readable
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])               # the bestFeat-th feature is consumed; remove its name
    featValues = [example[bestFeat] for example in dataSet]  # all values of the best feature
    uniqueVals = set(featValues)                             # remove duplicate values
    for value in uniqueVals:
        subLabels = labels[:]           # copy the label list so the recursion doesn't mutate it
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
        # myTree is a dict nested inside a dict: myTree[bestFeatLabel][value] is the value under
        # key value, which itself sits under key bestFeatLabel
    return myTree
The use of myTree[bestFeatLabel][value] near the end of the code took me a long time to understand: it is a multi-level nested dict. Here is an example.
def createDataSet():
    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no']
    ]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
mydata, labels = createDataSet()
mytree = createTree(mydata, labels)
print(mytree)
mytree['no surfacing'][1]['flippers'][0] = 'yes'
print(mytree)
Output:
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'yes', 1: 'yes'}}}}
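To close the loop on the nested-dict structure, here is a minimal sketch (my own addition, not part of the book listing above) of how one might walk the tree to classify a new sample. It assumes the createDataSet and createTree functions defined above; note that a copy of labels is passed to createTree, since createTree deletes entries from the list it is given.

def classify(inputTree, featLabels, testVec):
    """Minimal sketch: descend the nested dict until a leaf (a class label) is reached."""
    firstStr = list(inputTree.keys())[0]      # feature name tested at this node
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)    # map the feature name back to its column index
    branch = secondDict[testVec[featIndex]]   # follow the branch for this sample's feature value
    if isinstance(branch, dict):              # internal node: keep descending
        return classify(branch, featLabels, testVec)
    return branch                             # leaf: the class label

mydata, labels = createDataSet()
mytree = createTree(mydata, labels[:])        # pass a copy so labels survives for classify
print(classify(mytree, labels, [1, 1]))       # -> 'yes'
print(classify(mytree, labels, [1, 0]))       # -> 'no'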