import math
import matplotlib.pyplot as plt
# Training samples. Each row holds six attribute values followed by the class
# label in the last column; the column order matches the names listed in A below.
D = [
['青绿','蜷缩','浊响','清晰','凹陷','硬滑','是'],
['乌黑','蜷缩','沉闷','清晰','凹陷','硬滑','是'],
['乌黑','蜷缩','浊响','清晰','凹陷','硬滑','是'],
['青绿','蜷缩','沉闷','清晰','凹陷','硬滑','是'],
['浅白','蜷缩','浊响','清晰','凹陷','硬滑','是'],
['青绿','稍蜷','浊响','清晰','稍凹','软粘','是'],
['乌黑','稍蜷','浊响','稍糊','稍凹','软粘','是'],
['乌黑','稍蜷','浊响','清晰','稍凹','硬滑','是'],
['乌黑','稍蜷','沉闷','稍糊','稍凹','硬滑','否'],
['青绿','硬挺','清脆','清晰','平坦','软粘','否'],
['浅白','硬挺','清脆','模糊','平坦','硬滑','否'],
['浅白','蜷缩','浊响','模糊','平坦','软粘','否'],
['青绿','稍蜷','浊响','稍糊','凹陷','硬滑','否'],
['浅白','稍蜷','沉闷','稍糊','凹陷','硬滑','否'],
['乌黑','稍蜷','浊响','清晰','稍凹','软粘','否'],
['浅白','蜷缩','浊响','模糊','平坦','硬滑','否'],
['青绿','蜷缩','沉闷','稍糊','稍凹','硬滑','否']
]
# Column names for D, indexed by column number; the last entry names the class
# label column. createTree uses A[index] as the display name of a decision node.
A = ['色泽','根蒂','敲声','纹理','脐部','触感','好瓜']
############################### 信息熵相关算法#####################################
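# p_k is the proportion of samples in the current set D that belong to class k (k = 1, 2, ..., |y|).
# Compute the information entropy of D: Ent(D) = -sum_k p_k * log2(p_k), using the last column of each row as the class label.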
def getEnt(D):
    """Return the information entropy of sample set D.

    The class label of each sample is its last element. The result is
    -sum(p_k * log2(p_k)) over the distinct labels, where p_k is the share of
    samples carrying label k. An empty D yields 0.
    """
    # label -> number of samples carrying that label
    labelCounts = dict()
    for sample in D:
        label = sample[-1]
        labelCounts[label] = labelCounts.get(label, 0) + 1

    total = len(D)
    # Accumulate in dictionary (first-seen) order.
    weighted = sum(
        (count / total) * math.log2(count / total)
        for count in labelCounts.values()
    )
    return -weighted
# 求信息增益,aIndex为属性列号
def getGain(D, aIndex):
    """Return the information gain obtained by splitting D on column aIndex.

    D is partitioned by the value found in column aIndex; the gain is
    getEnt(D) minus the size-weighted sum of the entropies of the partitions.
    """
    # attribute value -> samples having that value in column aIndex
    partitions = dict()
    for sample in D:
        partitions.setdefault(sample[aIndex], []).append(sample)

    total = len(D)
    conditionalEntropy = 0
    for subset in partitions.values():
        conditionalEntropy = conditionalEntropy + len(subset) / total * getEnt(subset)
    return getEnt(D) - conditionalEntropy
# 求信息增益最大的属性列号
def getMaxtGainIndex(D):
    """Return the index of the attribute column with the largest information gain.

    Every column except the last (the class label) is evaluated and its gain
    is printed. A column only replaces the current best when its gain is
    strictly larger, so column 0 is returned when no gain exceeds 0.
    """
    bestIndex = 0
    bestGain = 0
    attributeCount = len(D[0]) - 1
    for column in range(attributeCount):
        gain = getGain(D, column)
        print("第:" ,column , "列Gain为:" , gain)
        if gain > bestGain:
            bestGain, bestIndex = gain, column
    return bestIndex
############################### 辅助算法 #####################################
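# Check whether every sample in D has the same class label (all good melons or all bad melons). Returns the flag and the label; when the flag is False the returned label is not meaningful.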
def sameCategory(D):
    """Tell whether all samples in D carry the same class label.

    Returns a tuple (allSame, label). The label is taken from the last column.
    When allSame is False, label is simply the first label encountered.
    For an empty D the result is (True, None).
    """
    firstLabel = None
    for sample in D:
        label = sample[-1]
        if firstLabel is None:
            # First sample seen: remember its label as the reference.
            firstLabel = label
        elif firstLabel != label:
            # One mismatch is enough; the set still has to be split.
            return False, firstLabel
    return True, firstLabel
# Column index -> set of every value observed in that column.
# Filled once by initAStartV and read by getAStartV.
aAStartVMap = dict()


def initAStartV(D):
    """Populate aAStartVMap with the distinct values of every column of D.

    The map is only filled while it is still empty; later calls leave the
    already collected values untouched. All columns are recorded, including
    the last (class label) one.
    """
    if len(aAStartVMap) != 0:
        return
    for sample in D:
        for column, value in enumerate(sample):
            aAStartVMap.setdefault(column, set()).add(value)
# Return the set of all values that the attribute in the given column takes (collected beforehand by initAStartV).
def getAStartV(index):
    """Return the set of values recorded in aAStartVMap for column index.

    Raises KeyError when the column has not been recorded.
    """
    columnValues = aAStartVMap[index]
    return columnValues
############################### 节点类 #####################################
class Node:
    """A node of the decision tree.

    Attributes:
        name: text shown inside the node (an attribute name for a decision
            node, a class label for a leaf).
        desc: text shown on the edge leading to this node; empty for the root.
        children: child nodes; an empty list marks a leaf.
    """

    # Immutable defaults may live on the class. The mutable children list is
    # deliberately created per instance only, so it can never be shared
    # between nodes.
    name = "未命名节点"
    desc = ""

    def __init__(self, name=None, desc=None):
        """Create a node without children.

        name and desc are optional; when omitted the class defaults are used.
        """
        self.children = []
        if name is not None:
            self.name = name
        if desc is not None:
            self.desc = desc

    def addChild(self, node):
        """Append node to this node's children."""
        self.children.append(node)

    def __repr__(self):
        return "Node(name={!r}, desc={!r}, children={})".format(
            self.name, self.desc, len(self.children))
############################### 画树类 #####################################
class Tree:
    """Draws a tree of Node objects with matplotlib annotations.

    Nodes are laid out on a grid: every level is `step` units below its
    parent, and nodes sharing the same depth counter are placed left to right
    in `step`-wide slots in the order they are visited.
    """

    root = None
    # Text-box styles for decision nodes and leaves. boxstyle selects the box
    # shape; fc is the face (fill) colour of the box, here a grey level.
    decisionNode = dict(boxstyle="round4", fc="0.5")
    leafNode = dict(boxstyle="circle", fc="0.5")
    # Arrow style of the edge between parent and child.
    arrow_args = dict(arrowstyle="<-")
    # Horizontal and vertical distance between neighbouring nodes.
    step = 3
    # Current depth counter during drawing.
    deep = 0
    # Not used by this class; kept so existing attribute access keeps working.
    nowDeepIndex = 0

    def __init__(self, root):
        """Store the root node and prepare the matplotlib axes."""
        self.root = root
        # Layout state is per instance, so several trees never share counters.
        self.deep = 0
        self.deepIndex = dict()
        # Fixed coordinate range of the drawing area.
        plt.xlim(0, 20)
        plt.ylim(-18, 0)
        # Font settings so that Chinese labels can be rendered.
        plt.rcParams["font.sans-serif"] = ["SimHei"]
        plt.rcParams["axes.unicode_minus"] = False

    def _annotate(self, x1, y1, x2, y2, text, desc, box):
        """Draw a boxed node at (x2, y2), an arrow to (x1, y1) and the edge text."""
        plt.annotate(text,
                     xy=(x1, y1),
                     xytext=(x2, y2),
                     va='center',
                     ha='center',
                     xycoords="data",
                     textcoords='data',
                     bbox=box,
                     arrowprops=self.arrow_args)
        # Edge label at the midpoint between parent and child.
        plt.text((x1 + x2) / 2, (y1 + y2) / 2, desc)

    def drawLeaf(self, x1, y1, x2, y2, text, desc):
        """Draw a leaf node.

        x1, y1: coordinates of the parent (arrow start).
        x2, y2: coordinates of the node text (arrow target).
        text: node text.
        desc: text written on the connecting edge.
        """
        self._annotate(x1, y1, x2, y2, text, desc, self.leafNode)

    def drawDecision(self, x1, y1, x2, y2, text, desc):
        """Draw a decision node; parameters as for drawLeaf."""
        self._annotate(x1, y1, x2, y2, text, desc, self.decisionNode)

    def drawRoot(self, text):
        """Draw the root (a decision node without incoming edge) at the origin."""
        plt.annotate(text,
                     xy=(0, 0),
                     va='center',
                     ha='center',
                     xycoords="data",
                     textcoords='data',
                     bbox=self.decisionNode)

    def drawTree(self):
        """Draw the whole tree starting at the root and show the figure."""
        # Start from a clean layout so repeated calls give the same picture.
        self.deep = 0
        self.deepIndex = dict()
        self.draw0(self.root, 0, 0)
        plt.show()

    def draw0(self, node, x, y):
        """Recursively draw node and its subtree; (x, y) is the parent position."""
        # Number of nodes already placed under the current depth counter.
        slot = self.deepIndex.setdefault(self.deep, 0)
        # Nodes are lined up from x = 0 according to their slot number.
        x2 = slot * self.step
        y2 = y - self.step
        if len(node.children) == 0:
            self.drawLeaf(x, y, x2, y2, node.name, node.desc)
        elif len(node.desc) > 0:
            self.drawDecision(x, y, x2, y2, node.name, node.desc)
            self.deep += 1
            for child in node.children:
                self.draw0(child, x2, y2)
            self.deep -= 1
        else:
            # A node with children but no edge text is treated as the root;
            # its children are drawn without increasing the depth counter.
            self.drawRoot(node.name)
            for child in node.children:
                self.draw0(child, 0, 0)
        # One more node has been placed under the current depth counter.
        self.deepIndex[self.deep] = self.deepIndex[self.deep] + 1
############################### 建立决策树 #####################################
def _majorityLabel(D):
    """Return the most frequent class label (last column) in D; ties go to the label seen first."""
    counts = dict()
    for row in D:
        counts[row[-1]] = counts.get(row[-1], 0) + 1
    return max(counts, key=counts.get)


def createTree(D, node, aStart):
    """Build a decision (sub)tree for sample set D and attach it to node.

    Parameters:
        D: list of samples; the last element of each sample is the class label.
        node: parent Node to attach the new subtree to, or None to create the root.
        aStart: attribute value leading to this subtree (used as edge text),
            or None for the root.

    Returns:
        node, i.e. the parent that was passed in, or the newly created root
        when node was None. This holds for leaves as well.

    The global map read by getAStartV must have been filled (initAStartV)
    before calling; otherwise getAStartV raises KeyError.
    """
    newNode = Node()
    if aStart is not None:
        newNode.desc = aStart
    if node is None:
        node = newNode
    else:
        node.addChild(newNode)

    # No samples and no parent information: fall back to the label "是".
    # Empty branches below a decision node are handled further down and get
    # the majority label of their parent instead.
    if len(D) == 0:
        newNode.name = "是"
        return node

    # All samples agree on the class: this is a leaf.
    allFlag, nowJudge = sameCategory(D)
    if allFlag:
        newNode.name = nowJudge
        return node

    # Split on the attribute with the highest information gain.
    index = getMaxtGainIndex(D)
    subsets = dict()
    for row in D:
        subsets.setdefault(row[index], []).append(row)

    majority = _majorityLabel(D)

    # If every sample shares the same value of the chosen attribute, the split
    # would reproduce D unchanged and recursion would never end. Make a leaf
    # labelled with the majority class instead.
    if len(subsets) == 1:
        newNode.name = majority
        return node

    print("信息增益最高的列index:" , index, "newNode name:",A[index])
    newNode.name = A[index]

    # Non-empty branches: recurse. The attribute value becomes the edge text.
    for value, subset in subsets.items():
        createTree(subset, newNode, value)

    # Values of this attribute that exist in the full data set but not in D
    # still need a branch; it becomes a leaf carrying D's majority class.
    # Sorted so the child order is the same on every run.
    missingValues = [v for v in getAStartV(index) if v not in subsets]
    for value in sorted(missingValues):
        leaf = Node()
        leaf.desc = value
        leaf.name = majority
        newNode.addChild(leaf)

    return node
# Record the value set of every column from the complete data set before any splitting.
initAStartV(D)
# Build the tree; createTree is called with no parent and no edge text for the root.
root = createTree(D,None,None)
# Draw the tree; drawTree ends by calling plt.show().
treex = Tree(root)
treex.drawTree()