Machine Learning in Action: The Decision Tree Algorithm

The decision tree in this article is built on the ID3 algorithm, so before reading it is recommended to review ID3 first: https://blog.csdn.net/qq_27396861/article/details/88226296

Contents

        • 1. Building the Decision Tree
        • 2. Plotting the Tree

Example: classifying marine animals (fish or not) by their attributes:

Figure 1: (image)
Figure 2: (image: the resulting decision tree)
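
For reference, the dataset behind these figures, as hard-coded in createDataSet in the listing below (1 = yes, 0 = no):

    no surfacing   flippers   fish?
    1              1          yes
    1              1          yes
    1              0          no
    0              1          no
    0              1          no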

1. Building the Decision Tree

Example:

# coding: utf-8

from math import log
import operator

def createDataSet():
	dataSet = [[1,1,'yes'], [1,1,'yes'], [1,0,'no'], [0,1,'no'], [0,1,'no']]
	labels = ['no surfacing', 'flippers']
	return dataSet, labels

def calcShannonEnt(dataSet):
	''' Compute the Shannon entropy of the whole dataset '''
	numEntries = len( dataSet ) # number of samples in the set
	labelCounts = {}
	
	# build a count dictionary for every possible class
	for featVec in dataSet:
		currentLabel = featVec[-1] # the last element is the class label
		if currentLabel not in labelCounts.keys():
			labelCounts[currentLabel] = 0
		labelCounts[currentLabel] += 1
		
	shannonEnt = 0.0
	for key in labelCounts:
		# divide the count of the current class by the total number of samples
		prob = float(labelCounts[key]) / numEntries
		# take the base-2 logarithm and accumulate
		shannonEnt -= prob * log(prob, 2)
		
	return shannonEnt
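
# The quantity computed above is the Shannon entropy H(D) = -sum_k p_k * log2(p_k).
# For the sample data (2 'yes', 3 'no'):
#   H = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.971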
	
def splitDataSet(dataSet, axis, value):
	''' Split the dataset on a given feature.
	dataSet: the dataset to split
	axis: index of the feature to split on
	value: the feature value to keep '''
	# Each returned vector drops the element at index axis: if axis is 0,
	# reduceFeatVec is the last two elements; if 1, the first and third.
	retDataSet = []
	for featVec in dataSet:
		if featVec[axis] == value:
			reduceFeatVec = featVec[:axis]
			reduceFeatVec.extend(featVec[axis+1:])
			retDataSet.append(reduceFeatVec)
	return retDataSet
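
# Example (hand-checked): splitDataSet(dataSet, 0, 1) keeps the rows whose
# feature 0 equals 1 and strips that column: [[1, 'yes'], [1, 'yes'], [0, 'no']].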
	
def chooseBestFeatureToSplit(dataSet):
	''' Choose the best feature to split the dataset on '''
	numFeatures = len(dataSet[0]) - 1 # features per sample (the last column is the label)
	baseEntropy = calcShannonEnt(dataSet) # Shannon entropy of the full dataset
	
	bestInfoGain = 0.0	# best information gain so far
	bestFeature = -1	# index of the best feature
	
	for i in range(numFeatures):
		featList = [example[i] for example in dataSet]
		uniqueVals = set(featList)
		newEntropy = 0.0 
		
		for value in uniqueVals:
			subDataSet = splitDataSet(dataSet, i, value)
			prob = len(subDataSet) / float(len(dataSet)) # matching samples / total samples
			# conditional entropy: weighted sum of the subset entropies
			newEntropy += prob * calcShannonEnt(subDataSet)
			
		# information gain
		infoGain = baseEntropy - newEntropy
		if infoGain > bestInfoGain:
			bestInfoGain = infoGain
			bestFeature = i
			
	# for the sample data, 0 means 'no surfacing' and 1 means 'flippers'
	return bestFeature	
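
# Worked numbers for the sample data (using the entropy value above):
#   gain('no surfacing') = 0.971 - (3/5 * 0.918 + 2/5 * 0.0) ≈ 0.420
#   gain('flippers')     = 0.971 - (4/5 * 1.000 + 1/5 * 0.0) ≈ 0.171
# so feature 0 ('no surfacing') wins and becomes the root.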
	
def majorityCnt(classList):
	''' Return the class that occurs most often '''
	# count occurrences of each class
	classCount = {}
	for vote in classList:
		if vote not in classCount.keys():
			classCount[vote] = 0
		classCount[vote] += 1
		
	# sort by count (the second item of each pair), descending
	sortedClassCount = sorted(classCount.items(), \
		key = operator.itemgetter(1), reverse = True)
		
	return sortedClassCount[0][0]
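
# Example: majorityCnt(['yes', 'no', 'no']) returns 'no'.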
	
def createTree(dataSet, labels):
	''' Recursively build the decision tree '''
	classList = [example[-1] for example in dataSet] # all class labels
	print("classList =", classList)
	
	# stop if every sample belongs to the same class
	if classList.count(classList[0]) == len(classList):
		return classList[0]
	
	# stop if only the label column is left (all features used up)
	if len(dataSet[0]) == 1:
		return majorityCnt(classList) # fall back to a majority vote
		
	bestFeat = chooseBestFeatureToSplit( dataSet ) # index of the best feature
	bestFeatLabel = labels[bestFeat] # name of the best feature
	
	myTree = {bestFeatLabel:{}}
	
	# build the subtrees recursively
	del( labels[bestFeat] )
	featValue = [example[bestFeat] for example in dataSet]
	uniqueVals = set(featValue) # all values taken by the best feature
	
	for value in uniqueVals:
		subLabels = labels[:] # copy of the remaining feature names
		myTree[bestFeatLabel][value] = createTree(splitDataSet\
			(dataSet, bestFeat, value), subLabels)
			
	return myTree
		
def main():
	dataSet, labels = createDataSet()
	myTree = createTree(dataSet, labels)
	print(myTree)
	
if __name__=="__main__":
	main()

The result matches Figure 2 above:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
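
Because the tree is just a nested dict, classifying a new sample is a simple dict walk. The classify helper below is not part of the original listing; it is a minimal sketch of that traversal:

def classify(inputTree, featLabels, testVec):
	''' Sketch: walk the nested-dict tree and return the predicted label '''
	firstStr = list(inputTree.keys())[0]   # feature tested at this node
	featIndex = featLabels.index(firstStr) # which column of testVec holds it
	branch = inputTree[firstStr][testVec[featIndex]] # follow the matching branch
	if type(branch).__name__ == 'dict':    # internal node: keep walking
		return classify(branch, featLabels, testVec)
	return branch                          # leaf: the class label

For example, classify(myTree, ['no surfacing', 'flippers'], [1, 0]) returns 'no'.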

2. Plotting the Tree

The plotting code divides the x and y axes into equal parts, with both axes running from 0 to 1. totalW records the number of leaf nodes, so 1/totalW is exactly the width allotted to each leaf node.

The code below renders the decision tree. There are two tricky spots:
(1) Initializing the x coordinate:
Each new leaf is placed by advancing xOff by 1/totalW, and every leaf should sit at the center of its slot, so the starting x position is set half a slot to the left of the origin, i.e. minus 0.5 / totalW; then each later step of 1/totalW lands exactly mid-slot.

plotTree.xOff = -0.5 / plotTree.totalW 

(2) Computing the x coordinate of a subtree's root node:
Take it step by step:
x coordinate = the current x offset + a distance derived from the leaf count

The leaves under this node span a width of numLeafs / plotTree.totalW.
The node should sit at the center of that span: numLeafs / 2 / plotTree.totalW.
And since xOff starts half a slot left of the origin, that half-slot must be added back: numLeafs / 2 / plotTree.totalW + 0.5 / plotTree.totalW; this is the offset.
So x = numLeafs / 2 / plotTree.totalW + 0.5 / plotTree.totalW + plotTree.xOff, which the code writes as:

plotTree.xOff + (1.0 + float(numLeafs) ) / 2.0 / plotTree.totalW
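
As a concrete check with the tree used in main below (4 leaves): totalW = 4, so xOff starts at -0.125; at the root numLeafs = 4, giving x = -0.125 + (1.0 + 4) / 2.0 / 4 = -0.125 + 0.625 = 0.5, which centers the root over all of its leaves.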

The full code is as follows:

# coding: utf-8

import matplotlib.pyplot as plt

# define box and arrow styles for the annotations
decisionNode = dict( boxstyle = "sawtooth", fc = "1.0" )
leafNode = dict( boxstyle = "round4", fc = "0.8" )
arrow_args = dict( arrowstyle = "<-" )

# draw a node with an arrow pointing to it from its parent
def plotNode( nodeTxt, centerPt, parentPt, nodeType ):
	createPlot.ax1.annotate( nodeTxt, xy = parentPt, xycoords = 'axes fraction',\
		xytext = centerPt, textcoords = 'axes fraction', va = 'center',\
		ha = 'center', bbox = nodeType, arrowprops = arrow_args )

'''		
def createPlot():
	fig = plt.figure( 1, facecolor = 'white' )
	fig.clf()
	
	createPlot.ax1 = plt.subplot( 111, frameon = False )
	
	plotNode( "decision node", (0.5, 0.1), (0.1, 0.5), decisionNode )
	plotNode( "leaf node", (0.8, 0.1), (0.3, 0.8), leafNode )
	plt.show()
'''
	
def getNumLeafs( myTree ):
	''' Count the leaf nodes of the tree '''
	numLeafs = 0
	firstStr = list(myTree.keys())[0] # the root key
	secondDict = myTree[ firstStr ] # the dict of branches under the root
	
	# a branch that is itself a dict is an internal node, otherwise a leaf
	for key in secondDict.keys():
		if type( secondDict[key] ).__name__=='dict':
			numLeafs += getNumLeafs( secondDict[key] )
		else:
			numLeafs += 1
	return numLeafs
	
def getTreeDepth( myTree ):
	''' Compute the depth of the tree '''
	maxDepth = 0
	firstStr = list(myTree.keys())[0] # the root key
	secondDict = myTree[ firstStr ] 
	
	for key in secondDict.keys():
		if type(secondDict[key]).__name__=='dict':
			thisDepth = 1 + getTreeDepth( secondDict[key] )
		else:
			thisDepth = 1
		
		if thisDepth > maxDepth:
			maxDepth = thisDepth
	
	return maxDepth
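
# For retrieveTree(1) below: getNumLeafs returns 4 and getTreeDepth returns 3.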
	
def retrieveTree(i):
	''' Return one of two pre-built trees, to avoid re-running createTree '''
	listOfTrees = [ {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},\
					{'no surfacing': {0: 'no', 1: {'flippers': {0: \
					{'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
	return listOfTrees[i]
	
def plotMidText( cntrPt, parentPt, txtString ):
	''' Write text halfway between a parent and a child node '''
	xMid = ( parentPt[0] - cntrPt[0] ) / 2.0 + cntrPt[0]
	yMid = ( parentPt[1] - cntrPt[1] ) / 2.0 + cntrPt[1]
	
	createPlot.ax1.text( xMid, yMid, txtString )
	
def plotTree( myTree, parentPt, nodeTxt ):
	''' parentPt is the coordinate of this subtree's parent node '''

	numLeafs = getNumLeafs( myTree ) # number of leaves under this node
	depth = getTreeDepth( myTree )	# depth of this subtree
	firstStr = list(myTree.keys())[0] # key of this subtree's root
	
	'''
	x coordinate = the current x offset + a distance derived from the leaf count:
	the leaves under this node span numLeafs / plotTree.totalW,
	the node sits at the center of that span, numLeafs / 2 / plotTree.totalW,
	and xOff starts half a slot left of the origin, so 0.5 / plotTree.totalW is added back;
	hence x = numLeafs / 2 / plotTree.totalW + 0.5 / plotTree.totalW + plotTree.xOff
	'''
	
	# coordinate of this subtree's root node
	x = plotTree.xOff + (1.0 + float(numLeafs) ) / 2.0 / plotTree.totalW
	y = plotTree.yOff
	print("x = %f, y = %f" % (x, y))
	cntrPt = ( x, y ) 
		
	# label the branch from the parent with the feature value
	plotMidText( cntrPt, parentPt, nodeTxt )
	
	plotNode( firstStr, cntrPt, parentPt, decisionNode )
	
	secondDict = myTree[firstStr]
	plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD # move down one level
	
	for key in secondDict.keys():
		if type(secondDict[key]).__name__=='dict':
			plotTree( secondDict[key], cntrPt, str(key) ) # internal node: recurse
		else:
			plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW # advance one leaf slot
			plotNode( secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode )
			plotMidText( (plotTree.xOff, plotTree.yOff), cntrPt, str(key) )
	
	plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD # restore the level on the way back up
	
def createPlot( inTree ):
	fig = plt.figure( 1, facecolor='white' )
	fig.clf()
	
	axprops = dict( xticks = [], yticks = [] )
	createPlot.ax1 = plt.subplot( 111, frameon = False, **axprops )
	
	# stash layout state as attributes on the plotTree function
	plotTree.totalW = float( getNumLeafs( inTree ) ) # total width = number of leaves
	plotTree.totalD = float( getTreeDepth( inTree ) ) # total height = tree depth
	plotTree.xOff = -0.5 / plotTree.totalW # so each later step of 1.0 / plotTree.totalW lands mid-slot
	plotTree.yOff = 1.0
	
	plotTree( inTree, (0.5, 1.0), '' )
	plt.show()
			
def main():
	# createPlot()
	myTree = retrieveTree(1)
	createPlot( myTree )
	
	
if __name__=="__main__":
	main()

The result is as follows:
Figure 3: (image: the rendered decision tree)
