本文摘自“机器学习实战”中案例,在此对其进行了代码更新与简单注释。感兴趣者可回复资源需求!
现有一份海洋生物数据表,如下图所示:
不浮出水面是否可以生存 | 是否有脚蹼 | 是否鱼类 | |
1 | 是 | 是 | 是 |
2 | 是 | 是 | 是 |
3 | 是 | 否 | 否 |
4 | 否 | 是 | 否 |
5 | 否 | 是 | 否 |
因为没有大量样本存储于文档中,故在次没有将文档样本内的数据转成可以处理的数据形式,而是直接简单创造,如下所示:
将特征值“是”表示为1,“否”表示为0;标签中用“yes”、“no”表示。
def createDataSet():
dataset=[[1,1,'yes'],
[1,1,'yes'],
[1,0,'no'],
[0,1,'no'],
[0,1,'no']]
labels=['no surfing','flippers'] #特征名称
return dataset,labels
划分数据集的最大原则:将无序的数据变得更加有序。使用信息论度化信息量是将数据变得更加有序的方法之一。划分数据集前后信息发生的变化称为信息增益,获得信息增益最高的特征作为每次划分的依据。在计算每种划分方式的信息增益之前,需要计算相应数据集的香农熵。
以上信息有想深入了解者,可自行查询。不甚了解不影响解题。
(1)计算给定数据集的香农熵
#计算给定数据集的香农熵
from math import log
def calcShannongEnt(dataset):
numEntries=len(dataset)
labelCount={} #统计所有类标签的发生频率
for featVec in dataset:
currentLabel=featVec[-1]
if currentLabel not in labelCount.keys():
labelCount[currentLabel]=0
labelCount[currentLabel]+=1
shannongEnt=0.0 #该数据集的香农熵
for key in labelCount:
prob=float(labelCount[key])/numEntries
shannongEnt+=-prob*log(prob,2)
return shannongEnt
(2)按照给定的特征划分数据集
'''
例如将dataset=[[1,1,'yes'],
[1,1,'yes'],
[1,0,'no'],
[0,1,'no'],
[0,1,'no']]
根据第0个特征(axis=0)“no surfing”,以及特征值为1(value=1)进行划分数据集,结果为
[[1, 'yes'],
[1, 'yes'],
[0, 'no']]
'''
#按照给定的特征划分数据集
def spliteDataSet(dataset,axis,value): #axis为划分数据集的特征,value为划分数据集的特征值
retDataSet=[]
for featVec in dataset:
if featVec[axis]==value:
reducedFeatVec=featVec[:axis]
#[1,2].extend([3,4])结果为[1,2,3,4]
#[1,2].append([3,4]结果为[1,2,[3,4]]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
(3)选择最好的数据集划分方式,即选取当前数据集中信息增益最高的特征
#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataset):
# 条件:数据是由列表元素组成的列表,而且所有的列表元素具有相同的数据长度
numFeatures=len(dataset[0])-1
baseEntropy=calcShannongEnt(dataset)
bestinfogain=0.0 ; bestFeature=-1
for i in range(numFeatures):
featList=[example[i] for example in dataset]
#.set()将列表转化为每个值都不相同的集合
uniqueVals=set(featList)
newEntropy=0.0
#计算每种划分方式的信息熵
for value in uniqueVals:
subdataset=spliteDataSet(dataset,i,value)
prob=len(subdataset)/float(len(dataset))
newEntropy+=prob*calcShannongEnt(subdataset)
infogain=baseEntropy-newEntropy #信息增益是熵的减少
if infogain>bestinfogain:
bestinfogain=infogain
bestFeature=i
return bestFeature
#返回出现次数最多的分类名称
def majorityCnt(classList):
classCount={}
for vote in classList:
if vote not in classCount.keys():
classCount[vote]=0
classCount[vote]+=1
sortedClassCount=sorted(classCount.items(),key=lambda x:x[1],reverse=True)
return sortedClassCount[0][0]
#创建决策树
def createTree(dataset,labels):
classList=[example[-1] for example in dataset] #包含了数据集中所有类别的标签
if classList.count(classList[0])==len(classList):
return classList[0]
# 使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组,此时返回出现次数最多的类别
if len(dataset[0])==1:
return majorityCnt(classList)
bestFeat=chooseBestFeatureToSplit(dataset)
bestFeatLabel=labels[bestFeat]
myTree={bestFeatLabel:{}}
del(labels[bestFeat])
featValue=[example[bestFeat] for example in dataset]
uniqueValues=set(featValue)
for value in uniqueValues:
sublabels=labels[:]
myTree[bestFeatLabel][value]=createTree(spliteDataSet(dataset,
bestFeat,
value),sublabels)
return myTree
'''
例如将dataset=[[1,1,'yes'],
[1,1,'yes'],
[1,0,'no'],
[0,1,'no'],
[0,1,'no']]
构造成决策树,结果为
{'no surfing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
'''
#使用决策树的分类函数
def classify(inputTree,featLabels,testVec):
firstStr=list(inputTree.keys())[0] #决策树中的第一个特征名称
featIndex=featLabels.index(firstStr)
secondTree=inputTree[firstStr] #决策树中的第一个特征在所有特征中的索引
for key in secondTree:
if testVec[featIndex]==key:
if type(secondTree[key])==dict:
classLabel=classify(secondTree,featLabels,testVec)
else: classLabel=secondTree[key]
return classLabel
#为了避免每次分类时都需要重新创建决策树,将决策树用pickle模块存储
def storeTree(inputTree,filename):
import pickle
fw=open(filename,'wb') #以二进制格式打开一个文件只用于写入
pickle.dump(inputTree,fw)
fw.close()
def grabTree(filename):
import pickle
fr=open(filename,'rb') #以二进制格式打开一个文件用于只读
return pickle.load(fr)
预测功能,大家可以写成函数形式
myData,labels=createDataSet()
global_labels=labels[:] #将labels赋值给全局变量global_labels,若global_labels=labels则为引用传递
mytree=createTree(myData,labels) #labels在创造决策树时被修改
storeTree(mytree,'classifierStorage.txt')
#现有一种海洋生物:不浮出水面不可以生存(0),没有脚蹼(0)。预测其是否为鱼类
mytree=grabTree('classifierStorage.txt')
print(mytree)
if classify(mytree,global_labels,[0,0])=='no':
print('不是鱼类')
else: print('是鱼类')