# -*- coding:utf-8 -*-
from math import log
import operator
import matplotlib.pyplot as plt
def calc(dataset):
    """Return the base-2 Shannon entropy of the class labels in ``dataset``.

    Args:
        dataset: Sequence of rows; the last element of every row is treated
            as that row's class label.

    Returns:
        The entropy as a float. An empty ``dataset`` yields ``0.0``.
    """
    # Count the rows of the argument itself. The previous code read the
    # module-level ``dataSet``, which gave wrong probabilities for subsets
    # and a NameError when no such global existed.
    numentries = len(dataset)
    labelscounts = {}
    for featVec in dataset:
        currentlabel = featVec[-1]
        labelscounts[currentlabel] = labelscounts.get(currentlabel, 0) + 1
    shannon = 0.0
    for count in labelscounts.values():
        # float() keeps true division on Python 2 as well.
        prob = float(count) / numentries
        shannon -= prob * log(prob, 2)
    return shannon
def creatDataSet():
    """Build the small sample dataset used by the demos in this file.

    Returns:
        A ``(rows, feature_names)`` tuple. Each row holds two feature values
        followed by a class label string; ``feature_names`` names the two
        feature columns. Fresh lists are created on every call.
    """
    feature_names = ['no surfacing', 'flippers']
    rows = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    return rows, feature_names
# Demo: build the sample dataset and print the entropy calc() reports for it.
dataSet, labels = creatDataSet()
shannon = calc(dataSet)
print(shannon)
# Split the dataset on a given feature.
def splitdataset(dataSet, axis, value):
    """Select the rows whose ``axis`` column equals ``value``, minus that column.

    Args:
        dataSet: Iterable of list rows.
        axis: Index of the column to test and then drop.
        value: Value the column must equal for a row to be kept.

    Returns:
        A new list of new row lists; the input rows are not modified.
    """
    def without_axis(row):
        # Copy everything before the column, then append everything after it.
        kept = row[:axis]
        kept.extend(row[axis + 1:])
        return kept

    return [without_axis(row) for row in dataSet if row[axis] == value]
# Demo: rebuild the sample data, then print the subset where column 0 equals 1
# and the subset where column 1 equals 1.
dataSet, labels = creatDataSet()
retdataset = splitdataset(dataSet, 0, 1)
print(retdataset)
retdataset1 = splitdataset(dataSet, 1, 1)
print(retdataset1)
# Choose the best way to split the dataset.
def chooseBestFeatureToSplit(dataSet):
    """Pick the feature column with the largest information gain.

    Every column except the last is treated as a feature. For each one, the
    dataset is partitioned by that column's distinct values via
    ``splitdataset`` and the size-weighted sum of ``calc`` over the partitions
    is subtracted from ``calc(dataSet)``.

    Args:
        dataSet: Non-empty list of rows.

    Returns:
        Index of the feature with the highest gain, or ``-1`` when no feature
        has a gain strictly greater than zero.
    """
    feature_count = len(dataSet[0]) - 1
    base_entropy = calc(dataSet)
    total_rows = float(len(dataSet))
    best_gain, best_index = 0.0, -1
    for index in range(feature_count):
        split_entropy = 0.0
        for value in {row[index] for row in dataSet}:
            subset = splitdataset(dataSet, index, value)
            split_entropy += len(subset) / total_rows * calc(subset)
        gain = base_entropy - split_entropy
        # Strict comparison: on ties the earliest feature is kept.
        if gain > best_gain:
            best_gain, best_index = gain, index
    return best_index
# Demo: print the feature index chooseBestFeatureToSplit() selects for the sample data.
dataSet, labels = creatDataSet()
a = chooseBestFeatureToSplit(dataSet)
print(a)
def maj(classlist):
    """Return the class label that occurs most often in ``classlist``.

    Args:
        classlist: Iterable of hashable class labels.

    Returns:
        The label with the highest count.

    Raises:
        IndexError: If ``classlist`` is empty.
    """
    classcount = {}
    for vote in classlist:
        classcount[vote] = classcount.get(vote, 0) + 1
    # items() works on Python 2 and 3; iteritems() exists only on Python 2.
    # reverse=True sorts by count in descending order (the default is ascending).
    sortedclasscount = sorted(classcount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedclasscount[0][0]
# Build the decision tree.
def createtree(dataSet, labels):
    """Recursively build a decision tree as nested dicts.

    Args:
        dataSet: Non-empty list of rows whose last element is the class label.
        labels: Feature names, one per feature column. The sequence is not
            modified.

    Returns:
        A class label for a leaf, otherwise ``{feature_name: {value: subtree}}``.
    """
    classlist = [example[-1] for example in dataSet]
    # Every row has the same class: this branch is a leaf.
    if classlist.count(classlist[0]) == len(classlist):
        return classlist[0]
    # Only the class column is left: fall back to the majority class.
    if len(dataSet[0]) == 1:
        return maj(classlist)
    bestF = chooseBestFeatureToSplit(dataSet)
    # -1 means no feature gives a positive gain. Using it as an index would
    # split on the class-label column, so return the majority class instead.
    if bestF < 0:
        return maj(classlist)
    bestFlabel = labels[bestF]
    # Build the reduced label list without deleting from the caller's list.
    sublabels = labels[:bestF] + labels[bestF + 1:]
    myTree = {bestFlabel: {}}
    for value in set(example[bestF] for example in dataSet):
        myTree[bestFlabel][value] = createtree(splitdataset(dataSet, bestF, value), sublabels)
    return myTree
# Demo: build a tree from the sample data and print it.
dataSet, labels = creatDataSet()
c = createtree(dataSet, labels)
print(c)
# Box styles handed to plotNode() as its nodeType argument (used as ``bbox``).
decisionNode = dict(boxstyle='sawtooth', fc='0.8')
leafNode = dict(boxstyle='round4', fc='0.8')
# Arrow properties plotNode() passes to annotate() as ``arrowprops``.
arrow_args = dict(arrowstyle='<-')
def plotNode(nodeText, centerPt, parentPt, nodeType):
    """Annotate the axes stored on ``createPlot.ax1`` with one tree node.

    Args:
        nodeText: Text shown inside the node box.
        centerPt: Position of the text, in axes-fraction coordinates.
        parentPt: Point being annotated, in axes-fraction coordinates; the
            arrow described by ``arrow_args`` connects it with the text.
        nodeType: Box style dict passed as ``bbox``.
    """
    annotate_options = dict(
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va='center',
        ha='center',
        bbox=nodeType,
        arrowprops=arrow_args,
    )
    createPlot.ax1.annotate(nodeText, **annotate_options)
def createPlot():
    """Draw a demo figure with one decision node and one leaf node.

    Clears figure 1, stores a frameless subplot on ``createPlot.ax1`` (which
    ``plotNode`` draws on), plots the two sample nodes and calls ``plt.show()``.
    """
    figure = plt.figure(1, facecolor="white")
    figure.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode(
        nodeText=U'a decision node',
        centerPt=(0.5, 0.1),
        parentPt=(0.1, 0.5),
        nodeType=decisionNode,
    )
    plotNode(
        nodeText=U'a leaf node',
        centerPt=(0.8, 0.1),
        parentPt=(0.3, 0.8),
        nodeType=leafNode,
    )
    plt.show()
# Demo: draw the two sample nodes at import time (this calls plt.show()).
createPlot()
def getnumleafs(myTree):
    """Count the leaf nodes of a nested-dict decision tree.

    Args:
        myTree: Dict of the form ``{feature_name: {value: subtree_or_leaf}}``.
            Only the first key is examined.

    Returns:
        Number of non-dict values reachable under the first key.

    Raises:
        IndexError: If ``myTree`` is empty.
    """
    # list(...)[0] works on Python 2 and 3; keys()[0] fails on Python 3
    # because dict views are not subscriptable.
    firststr = list(myTree)[0]
    numleafs = 0
    for child in myTree[firststr].values():
        if isinstance(child, dict):
            numleafs += getnumleafs(child)
        else:
            numleafs += 1
    return numleafs
def gettreedepth(myTree):
    """Return the number of decision levels in a nested-dict decision tree.

    Args:
        myTree: Dict of the form ``{feature_name: {value: subtree_or_leaf}}``.
            Only the first key is examined.

    Returns:
        The longest chain of nested decision dicts under the first key; a
        tree whose branches are all leaves has depth 1.

    Raises:
        IndexError: If ``myTree`` is empty.
    """
    # list(...)[0] works on Python 2 and 3; keys()[0] fails on Python 3
    # because dict views are not subscriptable.
    firststr = list(myTree)[0]
    maxdepth = 0
    for child in myTree[firststr].values():
        if isinstance(child, dict):
            thisdepth = 1 + gettreedepth(child)
        else:
            thisdepth = 1
        if thisdepth > maxdepth:
            maxdepth = thisdepth
    return maxdepth