本节使用数据依旧是之前生成的三种球类数据,刚进入这篇文章的小伙伴可以回头看下。链接如下:
机器学习入门之k近邻算法_俺从头开始的博客-CSDN博客
百度百科讲决策树:“决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种图解法。由于这种决策分支画成图形很像一棵树的枝干,故称决策树。在机器学习中,决策树是一个预测模型,他代表的是对象属性与对象值之间的一种映射关系。
决策树是一种树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。”
本质上来讲,决策树还是一个分类模型,所以它的中心工作是利用一种度量手段来区分各个类别的数据。
一位二十世纪的天才——克劳德·香农,提出了一种名叫“熵”的度量标准,至今依旧被广泛应用于信息领域。
熵公式:
其中,为分类的概率。
与信息熵一起诞生的产物。
计算公式:
一个系统越是有序,信息熵就越低,一个系统越是混乱,信息熵就越高,所以信息熵被认为是一个系统有序程度的度量。
主要就是敲代码实现的过程了,展示如下:
from math import log
import operator
import matplotlib.pyplot as plt
# 计算给定数据集的香农熵
def calcShannonEnt(dataSet):
numEntries = len(dataSet) # 计算实例总数
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1] # 键值是最后一列数值
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1 # 将不存在的类别加入字典
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob * log(prob, 2) # 香农公式
return shannonEnt
# 划分数据集
def splitDataSet(dataSet, axis, value): # dataSet:待划分的数据集,axis:划分数据集特征,value:需要返回的特征值
retDataSet = [] # 防止破坏原始数据
for featVec in dataSet:
if featVec[axis] == value: # 符合要求
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec) # 抽取
return retDataSet
# 选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0; bestFeature = -1
for i in range(numFeatures):
featList = [example[i] for example in dataSet]
uniqueVals = set(featList) # 创建唯一的分类标签列表
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy # 计算每种方式的信息熵
if(infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i # 计算最好的信息增益
return bestFeature
# 取最大值标签
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), \
key=operator.itemgetter(1), reverse=True) # 根据字典的值降序排列
return sortedClassCount[0][0]
# 创建树
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0] # classList只剩下一种值
if len(dataSet[0]) == 1: # dataSet中属性使用完毕,但没有分配完毕
return majorityCnt(classList) # 取数量最多作为分类
bestFeat = chooseBestFeatureToSplit(dataSet)
labels2 = labels.copy()
bestFeatLabel = labels2[bestFeat]
myTree = {bestFeatLabel: {}}
del(labels2[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels2[:] # 剩余属性列表
myTree[bestFeatLabel][value] = createTree(splitDataSet \
(dataSet, bestFeat, value), subLabels)
return myTree
#导入数据
def sentDataSet(filename):
with open(filename, 'r', encoding='utf-8') as file:
arrayOLines = file.readlines() #列表型
numberOfLines = len(arrayOLines)
dataSet = numpy.zeros((numberOfLines, 4))
index = 0
for line in arrayOLines:
line = line.strip()
listFromLine = line.split('\t\t')
dataSet[index, :] = listFromLine[0: 5]
labels = ['圆周长', '重量', '材料', '花纹']
return dataSet, labels
myDat, labels = sentDataSet('./data1')
myTree = createTree(myDat, labels)
print(myTree)
结果如下:
数据集采用了上一节中已生成的数据 ,所以还是分类球类的问题。
主要遇到的问题是,连续的数据没有进行处理,以至于每一组数据都成了一个单独的类别。连续数组的处理将在下一节提到。
导入数据主要就是将之前生成的数据集导入进来,方便后续操作。这一块代码可能有些简陋,毕竟博主水平还有限,不像之前的代码可以照着书敲【手动狗头】。不过嘛,代码毕竟是调通了,皆大欢喜,放心食用。
当然,这个问题也不能一直放着。于是乎,博主决定去掉数据集里的连续型变量,只用后两列数据进行分类,结果如下:
至于代码嘛,倒也不用大改,切片取出数据时修改一下参数即可。
如下:
# 导入数据
def sentDataSet(filename):
with open(filename, 'r', encoding='utf-8') as file:
arrayOLines = file.readlines() # 列表型
dataSet = []
for line in arrayOLines:
line = line.strip()
listFromLine = line.split('\t\t')
dataSet.append(listFromLine[2:5]) # 修改切片
labels = ['材料', '花纹'] # 修改这里
return dataSet, labels
修改部分已注释。
这一块是希望能过够将分类好的树可视化的输出,也就是直观的看到这棵树。
代码如下:
# 用Matplotlib注解绘制树形图
# 定义文本框和箭头格式
decisionNode = dict(boxstyle="square", fc="0.8") # boxstyle文本框样式、fc=”0.8” 是颜色深度
leafNode = dict(boxstyle="round4", fc="0.8") # 叶子节点
arrow_args = dict(arrowstyle="<-") # 定义箭头
# 绘制带箭头的注解
def plotNode(nodeTxt, centerPt, parentPt, nodeType): # 此函数执行绘制功能
# createPlot.ax1是表示: ax1是函数createPlot的一个属性
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt,
textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
# 获取叶节点的数目和树的层数
def getNumLeafs(myTree):
numLeafs = 0 # 初始化
firstStr = list(myTree.keys())[0] # 获得第一个key值(根节点)
secondDict = myTree[firstStr] # 获得value值
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict': # 测试节点的数据类型是否为字典
numLeafs += getNumLeafs(secondDict[key]) # 递归调用函数
else:
numLeafs += 1
return numLeafs
# 获取树的深度
def getTreeDepth(myTree):
maxDepth = 0 # 初始化
firstStr = list(myTree.keys())[0] # 获得第一个key值(根节点)
secondDict = myTree[firstStr] # 获得value值
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict': # 测试节点的数据类型是否为字典
thisDepth = 1 + getTreeDepth(secondDict[key]) # 递归调用
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
# 在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
# 画树
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree) # 获取树高
depth = getTreeDepth(myTree) # 获取树深度
firstStr = list(myTree.keys())[0] # 这个节点的文本标签
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
plotTree.yOff) # plotTree.totalW, plotTree.yOff全局变量,追踪已经绘制的节点,以及放置下一个节点的恰当位置
plotMidText(cntrPt, parentPt, nodeTxt) # 标记子节点属性
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD # 减少y偏移
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
# 绘制决策树
def createPlot(inTree):
fig = plt.figure(1, facecolor='white') # 创建一个新图形
fig.clf() # 清空绘图区
font = {'family': 'MicroSoft YaHei'}
plt.rc("font", **font)
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW;
plotTree.yOff = 1.0;
plotTree(inTree, (0.5, 1.0), '')
plt.show()
Createplot = createPlot(myTree)
运行结果:
材料这个分类分的稀烂,但这主要是因为这个类别本身就不适合分类。但由于去掉了前两列的数据,防止生成树过于单调,我最后还是决定将它加上。
本身就是字体的选择,但需要你的电脑上有这个可以编译的字体,这里推荐使用
font = {'family': 'SimHei'},常规电脑应该还是没有问题的。
下一节将会对树进行优化,也有连续性变量的处理方式,欢迎追订哦!