Reference article (on drawing a decision tree with matplotlib):
Required modules: math, matplotlib.pyplot, operator
Build the tree
Plot the tree
'''
ID3 algorithm
For now it only handles discrete (categorical) features.
'''
from math import log
import operator
## Information gain: compute the Shannon entropy of a dataset
def info_shan(dataset):
    label_counts = {}
    num_data = len(dataset)
    # Count how often each class label appears
    for vet in dataset:
        current_label = vet[-1]
        if current_label not in label_counts:
            label_counts[current_label] = 0
        label_counts[current_label] += 1
    shan = 0.0
    for key in label_counts:
        probability = float(label_counts[key]) / num_data
        shan -= probability * log(probability, 2)
    return shan
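# A quick sanity check on hypothetical toy data (not from the referenced
# article): entropy H(D) = -sum_k p_k*log2(p_k) peaks when the classes are
# evenly mixed and drops to zero when only one class remains.
demo_data = [[1, 'yes'], [1, 'yes'], [0, 'no'], [0, 'no']]
print(info_shan(demo_data))      # two classes, 50/50 -> 1.0 bit
print(info_shan([[1, 'yes']]))   # a single class   -> 0.0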
## Split the dataset (the second argument is the feature's column index)
def splitdata(dataset, feature, value):
    sub_dataset = []
    # Keep the rows whose given feature equals `value`
    for vector in dataset:
        if vector[feature] == value:
            # Copy the matching row, dropping the feature column itself
            fea_vector = vector[:feature]
            fea_vector.extend(vector[feature+1:])
            sub_dataset.append(fea_vector)
    return sub_dataset
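# Continuing the hypothetical toy data above: splitting on feature 0 with
# value 1 keeps the two matching rows and removes column 0 from each.
print(splitdata(demo_data, 0, 1))   # [['yes'], ['yes']]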
## Choose the best feature to split on (by information gain)
def best_feature(dataset):
    # By convention the last column is the class label
    num_feature = len(dataset[0]) - 1
    shan = info_shan(dataset)
    best_info_gain = 0.0
    best_fea = -1
    # Iterate over every feature in the dataset
    for i in range(num_feature):
        fea_list = [x[i] for x in dataset]
        # Build a set to get the distinct values of this feature
        fea_vals = set(fea_list)
        fea_shan = 0.0
        # Split into one subset per distinct value
        for value in fea_vals:
            sub_dataset = splitdata(dataset, i, value)
            probability = len(sub_dataset) / float(len(dataset))
            fea_shan += probability * info_shan(sub_dataset)
        info_gain = shan - fea_shan
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_fea = i
    return best_fea
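# On a slightly richer hypothetical dataset, feature 0 separates the classes
# perfectly, so it wins the information-gain comparison:
demo_data2 = [[1, 1, 'yes'], [1, 0, 'yes'], [0, 1, 'no'], [0, 0, 'no']]
print(best_feature(demo_data2))   # 0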
## Decide a leaf node's class by majority vote
def majority_vote(classlist):
    class_count = {}
    # Count the frequency of each class label
    for vote in classlist:
        if vote not in class_count:
            class_count[vote] = 0
        class_count[vote] += 1
    # Sort by frequency, highest first, and return the winner
    class_label = sorted(class_count.items(), key=operator.itemgetter(1),
                         reverse=True)
    return class_label[0][0]
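# A quick check:
print(majority_vote(['yes', 'no', 'yes', 'no', 'no']))   # 'no'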
## Build the tree recursively
def create_tree(dataset, labels):
    classlist = [x[-1] for x in dataset]
    # Stop splitting once every instance has the same class
    if classlist.count(classlist[0]) == len(classlist):
        return classlist[0]
    # When all features are used up, return the most common class
    if len(dataset[0]) == 1:
        return majority_vote(classlist)
    # Pick the best feature to split on
    best_fea = best_feature(dataset)
    best_fea_label = labels[best_fea]
    mytree = {best_fea_label: {}}
    del(labels[best_fea])   # note: this mutates the caller's list
    fea_val = [x[best_fea] for x in dataset]
    vals = set(fea_val)
    for val in vals:
        # Copy the labels so the recursive calls don't interfere
        sublabels = labels[:]
        mytree[best_fea_label][val] = create_tree(splitdata(dataset, best_fea, val),
                                                  sublabels)
    return mytree
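# Putting the pieces together on the hypothetical toy data (a copy of the
# label list is passed in because create_tree deletes entries from it):
demo_labels = ['feature A', 'feature B']
demo_tree = create_tree(demo_data2, demo_labels[:])
print(demo_tree)   # {'feature A': {0: 'no', 1: 'yes'}}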
## Classify a test vector with a built tree
def classify(input_tree, feature_label, testvec):
    keys = list(input_tree.keys())
    firststr = keys[0]
    second_dict = input_tree[firststr]
    # Find which feature this node tests
    fea_index = feature_label.index(firststr)
    classlabel = None   # stays None if no branch matches the test value
    for key in second_dict.keys():
        # Match the test vector's value for this feature against each branch
        if testvec[fea_index] == key:
            if isinstance(second_dict[key], dict):
                # The matched branch is a subtree: recurse on the next feature
                classlabel = classify(second_dict[key], feature_label, testvec)
            else:
                # The matched branch is a leaf: that's the class
                classlabel = second_dict[key]
    return classlabel
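# Classification needs the original (unmutated) label list to locate features:
print(classify(demo_tree, demo_labels, [1, 0]))   # 'yes'
print(classify(demo_tree, demo_labels, [0, 1]))   # 'no'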
## Persist a built tree (so later runs can classify without rebuilding it)
def store_tree(tree, filename):
    import pickle
    # pickle writes bytes, so in Python 3 the file must be opened in binary
    # mode; the with-statement also closes it automatically
    with open(filename, 'wb') as filewrite:
        pickle.dump(tree, filewrite)
## Load a stored tree
def grab_tree(filename):
    import pickle
    with open(filename, 'rb') as fileread:
        return pickle.load(fileread)
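# Round trip with a hypothetical filename: the loaded tree compares equal.
store_tree(demo_tree, 'demo_tree.pkl')
print(grab_tree('demo_tree.pkl') == demo_tree)   # True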
import matplotlib.pyplot as plt
# Style of a decision node: boxstyle is the text-box shape ('sawtooth' is
# jagged), fc is the face (fill) color, here a gray level
decision_node = dict(boxstyle='sawtooth', fc='0.8')
# Style of a leaf node
leaf_node = dict(boxstyle='round4', fc='0.8')
# Arrow properties
arrow_args = dict(arrowstyle='<-')
## Draw a node annotation with an arrow
'''
nodetxt  - the node's text
centerpt - where the text is centered
dotpt    - the parent point; the arrow runs from it to the text
nodetype - node style (decision node or leaf node)
'''
def plot_node(nodetxt, centerpt, dotpt, nodetype):
    # annotate() places a text label with an optional arrow
    '''
    nodetxt    - the annotation text
    xy         - the point the arrow attaches to
    xycoords   - coordinate system for xy ('axes fraction': fractions of the
                 axes size, measured from its lower-left corner)
    xytext     - where to place the text
    textcoords - coordinate system for the text
    va/ha      - vertical/horizontal alignment of the text
                 (va: top, bottom, center, baseline; ha: right, center, left)
    bbox       - box drawn around the text
    arrowprops - arrow properties (as a dict)
    '''
    create_plot.ax1.annotate(nodetxt, xy=dotpt, xycoords='axes fraction',
                             xytext=centerpt, textcoords='axes fraction',
                             va='center', ha='center', bbox=nodetype,
                             arrowprops=arrow_args)
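# plot_node can be tried on its own; a minimal sketch, assuming the rest of
# the listing below has been loaded (the axes must live on create_plot.ax1):
def plot_node_demo():
    fig = plt.figure(num=1, facecolor='white')
    fig.clf()
    create_plot.ax1 = plt.subplot(111, frameon=False)
    plot_node('decision node', (0.5, 0.1), (0.1, 0.5), decision_node)
    plot_node('leaf node', (0.8, 0.1), (0.3, 0.8), leaf_node)
    plt.show()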
## Count the leaf nodes and the depth of the tree
def get_leafnum_depth(mytree):
    leafnum = 0
    depth_all = 0
    keys = list(mytree.keys())
    firststr = keys[0]
    second_dict = mytree[firststr]
    for key in second_dict.keys():
        if isinstance(second_dict[key], dict):
            # A subtree: recurse, then add one level for this node
            leafnum0, depth0 = get_leafnum_depth(second_dict[key])
            leafnum += leafnum0
            depth = 1 + depth0
        else:
            # A leaf
            leafnum += 1
            depth = 1
        if depth > depth_all:
            depth_all = depth
    return leafnum, depth_all
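# On the toy tree built above: one decision node with two leaf children.
print(get_leafnum_depth(demo_tree))   # (2, 1)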
## Add text halfway along the line between two nodes
def plot_midtext(centerpt, dotpt, stringtxt):
    xmid = (dotpt[0] - centerpt[0]) / 2.0 + centerpt[0]
    ymid = (dotpt[1] - centerpt[1]) / 2.0 + centerpt[1]
    # Write the text onto ax1 (x position, y position, text)
    create_plot.ax1.text(xmid, ymid, stringtxt)
## Draw the tree
'''Divide the x axis evenly into one slot per leaf node,
and the y axis evenly into one slot per tree level.'''
def plot_tree(mytree, dotpt, nodetxt):
    leafnum, depth = get_leafnum_depth(mytree)
    keys = list(mytree.keys())
    firststr = keys[0]
    # Position this node: xoff (the last drawn position) plus an offset;
    # yoff stays on the current level
    '''
    To place a node, first count the leaves hanging under it.
    Those leaves span a total width of float(leafnum)/plot_tree.totalw,
    and the node sits at the middle of that span,
    float(leafnum)/2.0/plot_tree.totalw. Because plot_tree.xoff starts
    half a slot to the left, 1/2/plot_tree.totalw is added back as well.
    '''
    centerpt = (plot_tree.xoff + (1.0 + float(leafnum)) / 2.0 / plot_tree.totalw,
                plot_tree.yoff)
    plot_midtext(centerpt, dotpt, nodetxt)
    plot_node(firststr, centerpt, dotpt, decision_node)
    second_dict = mytree[firststr]
    # Step down one level before drawing the children (offset = 1/tree depth)
    plot_tree.yoff = plot_tree.yoff - 1.0 / plot_tree.totald
    for key in second_dict.keys():
        if isinstance(second_dict[key], dict):
            # An internal node: recurse
            plot_tree(second_dict[key], centerpt, str(key))
        else:
            # A leaf: draw it directly, one slot to the right
            plot_tree.xoff = plot_tree.xoff + 1.0 / plot_tree.totalw
            plot_node(second_dict[key], (plot_tree.xoff, plot_tree.yoff),
                      centerpt, leaf_node)
            plot_midtext((plot_tree.xoff, plot_tree.yoff), centerpt, str(key))
    # Step back up once this subtree is finished
    plot_tree.yoff = plot_tree.yoff + 1.0 / plot_tree.totald
## Render the figure
def create_plot(intree):
    # Open a new figure window; facecolor is the background color
    figure1 = plt.figure(num=1, facecolor='white')
    # Clear the drawing area
    figure1.clf()
    # Remove the axis ticks
    axprops = dict(xticks=[], yticks=[])
    # Store the axes as an attribute of this function (ax1)
    '''Create a single subplot (position 1 of a 1x1 grid);
    frameon controls whether the axes frame is drawn'''
    create_plot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # totalw: number of leaves; totald: number of levels
    leafnum, depth = get_leafnum_depth(intree)
    plot_tree.totalw = float(leafnum)
    plot_tree.totald = float(depth)
    # xoff, yoff mark where the first node goes
    '''1/number-of-leaves is the spacing between adjacent leaves;
    multiplying by -1/2 starts half a slot to the left (purely cosmetic)'''
    plot_tree.xoff = (1.0 / plot_tree.totalw) * (-1.0 / 2.0)
    plot_tree.yoff = 1.0
    # Draw the tree; the root node anchors at (0.5, 1.0)
    plot_tree(intree, (0.5, 1.0), '')
    plt.show()
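# End-to-end run on the hypothetical toy data: build the tree, then draw it
# (a fresh copy of the labels again, since create_tree mutates its argument).
create_plot(create_tree(demo_data2, demo_labels[:]))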