import math
import matplotlib.pyplot as plt
# Training samples: one row per sample. The first six columns are attribute
# values (column names are listed in A below, same order); the last column is
# the class label ('是' = yes / '否' = no). The helpers below read the label
# as the last element of each row.
# NOTE(review): this looks like the "watermelon 2.0" data set from Zhou
# Zhihua's textbook *Machine Learning* -- confirm before citing it.
D = [
['青绿','蜷缩','浊响','清晰','凹陷','硬滑','是'],
['乌黑','蜷缩','沉闷','清晰','凹陷','硬滑','是'],
['乌黑','蜷缩','浊响','清晰','凹陷','硬滑','是'],
['青绿','蜷缩','沉闷','清晰','凹陷','硬滑','是'],
['浅白','蜷缩','浊响','清晰','凹陷','硬滑','是'],
['青绿','稍蜷','浊响','清晰','稍凹','软粘','是'],
['乌黑','稍蜷','浊响','稍糊','稍凹','软粘','是'],
['乌黑','稍蜷','浊响','清晰','稍凹','硬滑','是'],
['乌黑','稍蜷','沉闷','稍糊','稍凹','硬滑','否'],
['青绿','硬挺','清脆','清晰','平坦','软粘','否'],
['浅白','硬挺','清脆','模糊','平坦','硬滑','否'],
['浅白','蜷缩','浊响','模糊','平坦','软粘','否'],
['青绿','稍蜷','浊响','稍糊','凹陷','硬滑','否'],
['浅白','稍蜷','沉闷','稍糊','凹陷','硬滑','否'],
['乌黑','稍蜷','浊响','清晰','稍凹','软粘','否'],
['浅白','蜷缩','浊响','模糊','平坦','硬滑','否'],
['青绿','蜷缩','沉闷','稍糊','稍凹','硬滑','否']
]
# Column names for D: colour, root, knock sound, texture, navel, touch, and
# the label column ("good melon"). Decision nodes are named from indices 0-5;
# the label column's name is never chosen as a split attribute.
A = ['色泽','根蒂','敲声','纹理','脐部','触感','好瓜']
def getEnt(D):
    """Return the base-2 Shannon entropy of the class labels in D.

    The class label of a row is its last element. An empty D gives 0.
    """
    counts = {}
    for row in D:
        label = row[-1]
        counts[label] = counts.get(label, 0) + 1
    total = len(D)
    # Accumulate p * log2(p) in first-seen label order, then negate.
    return -sum(n / total * math.log2(n / total) for n in counts.values())
def getGain(D, aIndex):
    """Return the information gain of splitting D on column aIndex.

    Gain = getEnt(D) - sum(|Dv| / |D| * getEnt(Dv)), where each Dv is the
    group of rows sharing one value in column aIndex. An empty D gives 0.
    """
    groups = {}
    for row in D:
        groups.setdefault(row[aIndex], []).append(row)
    size = len(D)
    # Weighted entropy of the groups, summed in first-seen value order.
    conditional = sum(len(group) / size * getEnt(group)
                      for group in groups.values())
    return getEnt(D) - conditional
def getMaxtGainIndex(D):
    """Return the index of the attribute column with the highest gain.

    Every column except the last (the class label) is scored with getGain
    and the score is printed. Ties keep the earliest column; if no column
    has a gain strictly above 0, index 0 is returned. D must be non-empty.
    """
    bestIndex, bestGain = 0, 0
    for column in range(len(D[0]) - 1):
        gain = getGain(D, column)
        print("第:", column, "列Gain为:", gain)
        if gain > bestGain:
            bestIndex, bestGain = column, gain
    return bestIndex
def sameCategory(D):
    """Check whether every row of D carries the same class label.

    Returns a tuple (allSame, label). label is the first non-None label
    encountered (None when D is empty). allSame becomes False at the first
    row whose label differs from it.
    """
    reference = None
    for row in D:
        label = row[len(row) - 1]
        if reference is None:
            reference = label
            continue
        if label != reference:
            return False, reference
    return True, reference
# Cache used by the two helpers below: column index -> set of every value
# that column takes in the rows passed to initAStartV.
aAStartVMap = dict()


def initAStartV(D):
    """Record, for each column index, the set of values appearing in D.

    Does nothing once the cache holds any entry, so only the first call
    that finds it empty populates it.
    """
    if aAStartVMap:
        return
    for row in D:
        for column, value in enumerate(row):
            aAStartVMap.setdefault(column, set()).add(value)


def getAStartV(index):
    """Return the cached value set for column *index* (KeyError if absent)."""
    return aAStartVMap[index]
class Node:
    """A node of the decision tree.

    ``name`` is the text drawn for the node (createTree stores an attribute
    name for inner nodes and a class label for leaves); ``desc`` is the
    attribute value on the edge from the parent (empty for the root).
    """

    # Class-level defaults, used until an instance assigns its own value.
    name = "未命名节点"
    desc = ""

    def __init__(self, name=None, desc=None):
        """Create a childless node, optionally setting name and desc."""
        # Created per instance: a class-level list would be one object
        # shared by every Node.
        self.children = []
        if name is not None:
            self.name = name
        if desc is not None:
            self.desc = desc

    def addChild(self, node):
        """Append *node* to this node's children."""
        self.children.append(node)

    def __repr__(self):
        return "Node(name=%r, desc=%r, children=%d)" % (
            self.name, self.desc, len(self.children))
class Tree:
    """Draw a tree of Node objects with matplotlib annotations.

    Layout: each depth level keeps its own counter in ``deepIndex``. A node
    is placed at ``x = counter * step`` and ``step`` below its parent, then
    that level's counter is advanced. A node with children and an empty
    ``desc`` (the root, as built by createTree) is drawn at (0, 0) and its
    children use the root's own level counter.
    """

    root = None
    # Box styles passed as ``bbox`` for inner (decision) nodes and leaves.
    decisionNode = dict(boxstyle="round4", fc="0.5")
    leafNode = dict(boxstyle="circle", fc="0.5")
    # Passed as ``arrowprops`` for every parent-child edge.
    arrow_args = dict(arrowstyle="<-")
    # Horizontal and vertical spacing between nodes, in data units.
    step = 3
    # Current depth while draw0 recurses.
    deep = 0
    # Not read by this class; kept so existing references still resolve.
    nowDeepIndex = 0

    def __init__(self, root):
        self.root = root
        # Per-instance layout state. As a class attribute, one deepIndex dict
        # was shared by every Tree and never reset between drawings.
        self.deepIndex = dict()
        self.deep = 0
        plt.xlim(0, 20)
        plt.ylim(-18, 0)
        # SimHei so the Chinese labels render; ASCII '-' for minus signs.
        plt.rcParams["font.sans-serif"] = ["SimHei"]
        plt.rcParams["axes.unicode_minus"] = False

    def _annotateEdge(self, x1, y1, x2, y2, text, desc, bbox):
        """Draw *text* in a *bbox* box at (x2, y2) with an arrow to (x1, y1),
        and write *desc* at the midpoint of that edge."""
        plt.annotate(text,
                     xy=(x1, y1),
                     xytext=(x2, y2),
                     va='center',
                     ha='center',
                     xycoords="data",
                     textcoords='data',
                     bbox=bbox,
                     arrowprops=self.arrow_args)
        plt.text((x1 + x2) / 2, (y1 + y2) / 2, desc)

    def drawLeaf(self, x1, y1, x2, y2, text, desc):
        """Draw a leaf box at (x2, y2) linked to its parent at (x1, y1)."""
        self._annotateEdge(x1, y1, x2, y2, text, desc, self.leafNode)

    def drawDecision(self, x1, y1, x2, y2, text, desc):
        """Draw a decision box at (x2, y2) linked to its parent at (x1, y1)."""
        self._annotateEdge(x1, y1, x2, y2, text, desc, self.decisionNode)

    def drawRoot(self, text):
        """Draw *text* in a decision box at (0, 0), with no incoming edge."""
        plt.annotate(text,
                     xy=(0, 0),
                     va='center',
                     ha='center',
                     xycoords="data",
                     textcoords='data',
                     bbox=self.decisionNode)

    def drawTree(self):
        """Lay out and draw the whole tree from ``root``, then show it."""
        # Start every drawing from a clean layout.
        self.deepIndex = dict()
        self.deep = 0
        self.draw0(self.root, 0, 0)
        plt.show()

    def draw0(self, node, x, y):
        """Draw *node* and its subtree; (x, y) is the parent's position."""
        level = self.deep
        if self.deepIndex.get(level) is None:
            self.deepIndex[level] = 0
        x2 = self.deepIndex[level] * self.step
        y2 = y - self.step
        if len(node.children) > 0:
            if len(node.desc) > 0:
                # Inner node reached through an edge labelled with its desc.
                self.drawDecision(x, y, x2, y2, node.name, node.desc)
                self.deep += 1
                for child in node.children:
                    self.draw0(child, x2, y2)
                self.deep -= 1
            else:
                # Root: drawn at the origin; its children hang off (0, 0)
                # without increasing the depth.
                self.drawRoot(node.name)
                for child in node.children:
                    self.draw0(child, 0, 0)
        else:
            self.drawLeaf(x, y, x2, y2, node.name, node.desc)
        self.deepIndex[level] = self.deepIndex[level] + 1
def _majorityLabel(D):
    """Return the most frequent class label (last column) in non-empty D.

    Ties go to the label seen first.
    """
    counts = {}
    for row in D:
        label = row[-1]
        counts[label] = counts.get(label, 0) + 1
    # max() returns the first maximal key, i.e. the first-seen label on ties.
    return max(counts, key=counts.get)


def createTree(D, node, aStart, defaultLabel="是"):
    """Grow an ID3 decision tree over the rows of D and attach it to *node*.

    Parameters:
        D: list of rows; every column but the last is an attribute value
           (named by the global A), the last column is the class label.
        node: parent Node to attach the new subtree to, or None to make the
              new node the root.
        aStart: value of the parent's split attribute leading to this
                subtree; stored as the new node's desc (None for the root).
        defaultLabel: leaf label used when D is empty. Recursive calls pass
                the parent's majority label; the default "是" keeps the
                previous result for a top-level call with empty D.

    Returns:
        *node* if one was given, otherwise the newly created root Node. This
        holds for leaf cases too, so an already-pure D yields a one-node tree
        rather than None.

    Relies on initAStartV having been called with the full data set so that
    every possible value of the split column gets a branch.
    """
    newNode = Node()
    if aStart is not None:
        newNode.desc = aStart
    if node is None:
        node = newNode
    else:
        node.addChild(newNode)

    # No samples reach this branch: use the label chosen by the caller.
    if len(D) == 0:
        newNode.name = defaultLabel
        return node

    allFlag, nowJudge = sameCategory(D)
    if allFlag:
        newNode.name = nowJudge
        return node

    index = getMaxtGainIndex(D)
    print("信息增益最高的列index:" , index, "newNode name:",A[index])

    # Partition D by its value in the chosen column, in first-seen order.
    aStartVMap = dict()
    for dLine in D:
        aStartVMap.setdefault(dLine[index], []).append(dLine)

    majority = _majorityLabel(D)
    # getMaxtGainIndex only lands on a column that is constant over D when no
    # column has positive gain; splitting on it would hand the same D back to
    # the recursion forever, so stop with a majority-class leaf.
    if len(aStartVMap) == 1:
        newNode.name = majority
        return node

    newNode.name = A[index]
    # Give every value the column can take a branch, even if absent from D.
    for aa in getAStartV(index):
        if aStartVMap.get(aa) is None:
            aStartVMap[aa] = []
    # Empty branches are labelled with this node's majority class.
    for value, subset in aStartVMap.items():
        createTree(subset, newNode, value, majority)
    return node
# Script entry: cache every column's possible values (createTree uses them to
# add branches for values missing from a subset), grow the tree from the full
# data set, then draw it and show the figure.
initAStartV(D)
root = createTree(D,None,None)
treex = Tree(root)
treex.drawTree()
