代码参考了 CSDN 博客文章《python实现ID3决策树分类算法》(作者:aoanng),并在其基础上有所精简。
'''
function:ID3决策树生成算法
author:baomi
date: 2021/11/01
reference: https://blog.csdn.net/colourful_sky/article/details/82056125
'''
import math
def splitDataSet(dataSet, i, value):
    """Select the rows whose column ``i`` equals ``value`` and drop that column.

    Each selected row is copied before column ``i`` is removed, so the rows
    of ``dataSet`` themselves are left untouched.

    Returns a new list holding the reduced copies, in their original order.
    """
    retDataSet = []
    for row in dataSet:
        if row[i] != value:
            continue
        reduced = row[:]
        del reduced[i]
        retDataSet.append(reduced)
    return retDataSet
def calcEntropy(dataSet):
    """Compute the base-2 entropy of the class labels in ``dataSet``.

    The label of each row is its last element. An empty data set yields 0.0.
    """
    # Map each label to the number of rows carrying it.
    labelCounts = {}
    for row in dataSet:
        labelCounts[row[-1]] = labelCounts.get(row[-1], 0) + 1

    total = len(dataSet)
    entropy = 0.0
    for count in labelCounts.values():
        prob = float(count) / total  # relative frequency of this label
        entropy -= prob * math.log(prob, 2)
    return entropy
def calcInfoGain(dataSet, i):
    """Compute the information gain obtained by splitting ``dataSet`` on column ``i``.

    The gain is the entropy of the whole data set minus the weighted sum of
    the entropies of the subsets produced by each distinct value of column
    ``i``, each subset weighted by its share of the rows.
    """
    total = len(dataSet)
    conditionalEntropy = 0.0
    # One subset per distinct value found in column i.
    for featureValue in {row[i] for row in dataSet}:
        subset = splitDataSet(dataSet, i, featureValue)
        conditionalEntropy += len(subset) / total * calcEntropy(subset)
    return calcEntropy(dataSet) - conditionalEntropy
def getMaxInfoGainNode(dataSet, featureNameList):
    """Decide whether ``dataSet`` becomes a leaf or is split, and on which feature.

    ``featureNameList`` holds the names of the feature columns of ``dataSet``
    (every column except the trailing label column).

    Returns one of two kinds of result:

    * a bare class label when the node is a leaf, which happens when the
      entropy is 0 (all labels are identical) or when no features are left
      to split on, in which case the most common label is returned;
    * otherwise a tuple ``(index, name, gain)`` describing the feature with
      the largest information gain.

    Raises IndexError if ``dataSet`` is empty.
    """
    if calcEntropy(dataSet) == 0:
        # Zero entropy means every row carries the same label.
        return dataSet[0][-1]

    if not featureNameList:
        # Features are exhausted but the labels are still mixed (rows with
        # identical feature values and different labels). Indexing the empty
        # name list would raise IndexError, so fall back to a majority vote;
        # ties go to the label seen first.
        from collections import Counter
        return Counter(row[-1] for row in dataSet).most_common(1)[0][0]

    maxInfoGain = 0
    maxInfoGainIndex = 0
    # Scan every feature and remember the one with the largest gain. If no
    # gain is strictly positive, feature 0 is kept so splitting can continue.
    for i in range(len(featureNameList)):
        infoGain = calcInfoGain(dataSet, i)
        if infoGain > maxInfoGain:
            maxInfoGain = infoGain
            maxInfoGainIndex = i
    return maxInfoGainIndex, featureNameList[maxInfoGainIndex], maxInfoGain
def createID3Tree(dataSet, featureNameList):
    """Recursively build an ID3 decision tree for ``dataSet``.

    Returns a single node:

    * a class label when ``getMaxInfoGainNode`` reports a leaf;
    * otherwise a dict ``{feature_name: {feature_value: child_node, ...}}``
      where ``feature_name`` is the feature chosen for the split and each
      ``child_node`` is built the same way from the rows having that value.

    Class labels may be of any non-tuple type; a tuple result from
    ``getMaxInfoGainNode`` is always interpreted as a split description.
    """
    maxInfoGainNode = getMaxInfoGainNode(dataSet, featureNameList)
    # A split is reported as a tuple; anything else is a leaf label. Checking
    # for the tuple (rather than for str) keeps non-string labels working.
    if not isinstance(maxInfoGainNode, tuple):
        return maxInfoGainNode

    nodeIndex, nodeName = maxInfoGainNode[0], maxInfoGainNode[1]

    # The chosen feature is consumed by this split. The remaining names are
    # the same for every branch, so compute them once outside the loop.
    newFeatNameList = featureNameList[:]
    newFeatNameList.pop(nodeIndex)

    branches = {}
    for feature in {row[nodeIndex] for row in dataSet}:
        subDataSet = splitDataSet(dataSet, nodeIndex, feature)
        # Recurse on the rows that take this value of the chosen feature.
        branches[feature] = createID3Tree(subDataSet, newFeatNameList)
    return {nodeName: branches}
# Sample training set: each row holds four feature values followed by the
# class label in the last column.
dataSet = [['青年', '否', '否', '一般', '拒绝'],
['青年', '否', '否', '好', '拒绝'],
['青年', '是', '否', '好', '同意'],
['青年', '是', '是', '一般', '同意'],
['青年', '否', '否', '一般', '拒绝'],
['中年', '否', '否', '一般', '拒绝'],
['中年', '否', '否', '好', '拒绝'],
['中年', '是', '是', '好', '同意'],
['中年', '否', '是', '非常好', '同意'],
['中年', '否', '是', '非常好', '同意'],
['老年', '否', '是', '非常好', '同意'],
['老年', '否', '是', '好', '同意'],
['老年', '是', '否', '好', '同意'],
['老年', '是', '否', '非常好', '同意'],
['老年', '否', '否', '一般', '拒绝'], ]
# Names of the four feature columns, in column order.
featureNameList = ['年龄', '有工作', '有房子', '信贷情况']
# Build the tree from the sample data and print the nested-dict result.
ID3Tree = createID3Tree(dataSet, featureNameList)
print(ID3Tree)
输出结果:
下面使用 matplotlib 将上述结果绘制成树形图。
绘图代码来自 CSDN 博客文章《Matplotlib绘制树形图》(作者:wancongconghao)。
#绘制树形图
import matplotlib.pyplot as plt
# Box style passed as ``bbox`` for internal (decision) nodes.
decision_node = {"boxstyle": "sawtooth", "fc": "0.8"}
# Box style passed as ``bbox`` for leaf nodes.
leaf_node = {"boxstyle": "round4", "fc": "0.8"}
# Arrow style passed as ``arrowprops`` for the edge between two nodes.
arrow_args = {"arrowstyle": "<-"}
def get_leaf_num(tree):
    """Count the leaf nodes of a nested-dict tree (used to size the plot width).

    ``tree`` has the form ``{root_name: {branch_value: child, ...}}``. A
    child that is a dict (or an instance of a dict subclass) is a subtree of
    the same form; any other child counts as one leaf.

    Raises IndexError if ``tree`` is empty.
    """
    root_key = list(tree)[0]
    leaf_num = 0
    for child in tree[root_key].values():
        # isinstance also recognises dict subclasses such as OrderedDict,
        # which a comparison against the type name "dict" would treat as leaves.
        if isinstance(child, dict):
            leaf_num += get_leaf_num(child)
        else:
            leaf_num += 1
    return leaf_num
def get_tree_depth(tree):
    """Return the depth of a nested-dict tree (used to size the plot height).

    ``tree`` has the form ``{root_name: {branch_value: child, ...}}``. A leaf
    child contributes a depth of 1; a child that is a dict (or an instance of
    a dict subclass) contributes 1 plus its own depth. The result is the
    largest contribution over all children, or 0 if the root has none.

    Raises IndexError if ``tree`` is empty.
    """
    root_key = list(tree)[0]
    depth = 0
    for child in tree[root_key].values():
        # isinstance also recognises dict subclasses such as OrderedDict,
        # which a comparison against the type name "dict" would treat as leaves.
        if isinstance(child, dict):
            child_depth = 1 + get_tree_depth(child)
        else:
            child_depth = 1
        depth = max(depth, child_depth)
    return depth
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one boxed node on ``createPlot.ax1`` with an arrow linking it to its parent.

    ``nodeTxt`` is the text shown in the box, ``centerPt`` the position of
    the box, ``parentPt`` the position of the parent, and ``nodeType`` the
    box style dict. Both positions are given in axes-fraction coordinates.
    """
    placement = dict(
        xycoords='axes fraction',
        textcoords='axes fraction',
        va="center",
        ha="center",
    )
    createPlot.ax1.annotate(
        nodeTxt,
        xy=parentPt,
        xytext=centerPt,
        bbox=nodeType,
        arrowprops=arrow_args,
        **placement
    )
def plotMidText(cntrPt, parentPt, txtString):
    """Write ``txtString`` halfway between a child node and its parent.

    ``cntrPt`` is the child position and ``parentPt`` the parent position;
    the text is drawn on ``createPlot.ax1``, centred and rotated 30 degrees.
    """
    childX = cntrPt[0]
    childY = cntrPt[1]
    xMid = (parentPt[0] - childX) / 2.0 + childX
    yMid = (parentPt[1] - childY) / 2.0 + childY
    createPlot.ax1.text(
        xMid, yMid, txtString, va="center", ha="center", rotation=30
    )
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw ``myTree`` below the node located at ``parentPt``.

    ``myTree`` has the form ``{root_name: {branch_value: child, ...}}``;
    ``nodeTxt`` is the label written on the edge from the parent to this
    subtree's root.

    The layout state lives in function attributes that must be set before
    the first call: ``plotTree.totalW`` and ``plotTree.totalD`` (divisors for
    the horizontal and vertical step sizes) and ``plotTree.xOff`` /
    ``plotTree.yOff`` (the current drawing cursor, updated while drawing).
    """
    numLeafs = get_leaf_num(myTree)
    firstStr = list(myTree.keys())[0]
    # Centre the root horizontally over the span taken up by its leaves.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decision_node)
    secondDict = myTree[firstStr]
    # Step down one level before drawing the children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        # isinstance also recognises dict subclasses as subtrees instead of
        # drawing them as leaves.
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))
        else:
            # A leaf takes the next horizontal slot.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leaf_node)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Restore the vertical cursor for the caller's remaining siblings.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
    """Draw the nested-dict tree ``inTree`` with matplotlib and show the figure.

    Figure 1 is cleared and given a single frameless axes without ticks,
    stored as ``createPlot.ax1``. The layout attributes read by ``plotTree``
    are initialised from the leaf count and depth of ``inTree`` before the
    recursive drawing starts at position (0.5, 1.0).
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])

    leafCount = float(get_leaf_num(inTree))
    plotTree.totalW = leafCount
    plotTree.totalD = float(get_tree_depth(inTree))
    # Start half a slot to the left of the axes; each leaf advances the
    # cursor by one slot before it is drawn.
    plotTree.xOff = -0.5 / leafCount
    plotTree.yOff = 1.0

    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
plt.rcParams['font.sans-serif']=['SimHei'] # use the SimHei font so that Chinese labels render correctly
plt.rcParams['axes.unicode_minus']=False # make the minus sign render correctly
# Nested-dict decision tree to draw: {feature_name: {feature_value: subtree_or_label}}.
tree = {'有房子': {'否': {'有工作': {'否': '拒绝', '是': '同意'}}, '是': '同意'}}
createPlot(tree)
结果: