1. tree.py builds a decision tree with the ID3 algorithm.
2. treePlotter.py draws the resulting tree with matplotlib.
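A minimal usage sketch tying the two modules together (the driver script itself is hypothetical; the functions it calls are all defined in the listings below):

    # demo.py -- hypothetical driver, not part of either module
    import tree
    import treePlotter

    my_dat, labels = tree.create_dataset()
    my_tree = tree.create_tree(my_dat, labels[:])   # pass a copy: create_tree deletes used labels in place
    print(tree.classify(my_tree, labels, [1, 1]))   # expected: 'yes'
    treePlotter.create_plot(my_tree)                # opens the matplotlib window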
# tree.py
from math import log
import treePlotter
def calc_shannon_ent(dataset):
    """Compute the Shannon entropy of the class labels (last column) in dataset."""
    num_entries = len(dataset)
    label_counts = {}
    for feat_vec in dataset:
        current_label = feat_vec[-1]
        if current_label not in label_counts:
            label_counts[current_label] = 0
        label_counts[current_label] += 1
    shannon_ent = 0.0
    for key in label_counts:
        prob = float(label_counts[key]) / num_entries
        shannon_ent -= prob * log(prob, 2)
    return shannon_ent
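# For reference, calc_shannon_ent implements the usual Shannon entropy over the
# class labels: H(D) = -sum_k p_k * log2(p_k), where p_k is the fraction of
# samples in class k.  For the toy dataset below (2 'yes', 3 'no') this gives
# H = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.971, which is the first value test() prints.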
def create_dataset():
    """Return a small toy dataset and the matching feature labels."""
    dataset = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataset, labels
def split_dataset(dataset, axis, value):
    """Return the rows whose feature `axis` equals `value`, with that column removed."""
    ret_dataset = []
    for feat_vec in dataset:
        if feat_vec[axis] == value:
            reduced_feat_vec = feat_vec[:axis]
            reduced_feat_vec.extend(feat_vec[axis + 1:])
            ret_dataset.append(reduced_feat_vec)
    return ret_dataset
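# On the toy dataset, split_dataset(my_dat, 0, 1) should return
# [[1, 'yes'], [1, 'yes'], [0, 'no']] and split_dataset(my_dat, 0, 0) should
# return [[1, 'no'], [1, 'no']] -- exactly what test2() prints.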
def choose_best_feature_to_split(dataset):
    """
    Iterate over every feature, compute the entropy of each candidate split with
    split_dataset(), and return the feature with the highest information gain.
    :param dataset: list of feature vectors, class label in the last column
    :return: index of the best feature to split on
    """
    num_features = len(dataset[0]) - 1
    base_entropy = calc_shannon_ent(dataset)
    best_info_gain = 0.0
    best_feature = -1
    for i in range(num_features):
        feat_list = [example[i] for example in dataset]
        unique_vals = set(feat_list)
        new_entropy = 0.0
        for value in unique_vals:
            sub_dataset = split_dataset(dataset, i, value)
            prob = len(sub_dataset) / float(len(dataset))
            new_entropy += prob * calc_shannon_ent(sub_dataset)
        info_gain = base_entropy - new_entropy
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_feature = i
    return best_feature
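# The quantity being maximised is the information gain
# Gain(D, A) = H(D) - sum_v (|D_v| / |D|) * H(D_v),
# where D_v is the subset of D whose feature A takes value v.  Worked by hand on
# the toy dataset: splitting on 'no surfacing' gives a gain of about
# 0.971 - 0.551 = 0.420, splitting on 'flippers' only about 0.171, so
# choose_best_feature_to_split returns 0 (as test3() shows).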
def majority_cnt(class_list):
    """Return the class label that occurs most often in class_list."""
    class_count = {}
    for vote in class_list:
        if vote not in class_count:
            class_count[vote] = 0
        class_count[vote] += 1
    sorted_class_count = sorted(class_count.items(), key=lambda kv: kv[1], reverse=True)
    return sorted_class_count[0][0]
def create_tree(dataset, labels):
    """Recursively build the ID3 tree as nested dicts; note that labels is modified in place."""
    class_list = [example[-1] for example in dataset]
    if class_list.count(class_list[0]) == len(class_list):
        # all samples share one class: return it as a leaf
        return class_list[0]
    if len(dataset[0]) == 1:
        # no features left: fall back to the majority class
        return majority_cnt(class_list)
    best_feat = choose_best_feature_to_split(dataset)
    best_feat_label = labels[best_feat]
    my_tree = {best_feat_label: {}}
    del labels[best_feat]
    feat_values = [example[best_feat] for example in dataset]
    unique_vals = set(feat_values)
    for value in unique_vals:
        sub_labels = labels[:]
        my_tree[best_feat_label][value] = create_tree(split_dataset(dataset, best_feat, value), sub_labels)
    return my_tree
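# For the toy dataset, create_tree produces the nested-dict tree
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
# the same structure that is stored as retrieve_tree(0) in treePlotter.py.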
def classify(input_tree, feat_labels, test_vec):
    """Walk the tree, following the feature values in test_vec, and return the class label."""
    class_label = None  # stays None if test_vec carries a value the tree never saw
    first_str = list(input_tree.keys())[0]
    second_dict = input_tree[first_str]
    feat_index = feat_labels.index(first_str)
    for key in second_dict.keys():
        if test_vec[feat_index] == key:
            if isinstance(second_dict[key], dict):
                class_label = classify(second_dict[key], feat_labels, test_vec)
            else:
                class_label = second_dict[key]
    return class_label
def store_tree(input_tree, filename):
    """Pickle the tree to disk so it does not have to be rebuilt on every run."""
    import pickle
    with open(filename, 'wb') as fw:
        pickle.dump(input_tree, fw)
def grab_tree(filename):
    """Load a tree previously saved with store_tree()."""
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
def test():
    my_dataset, labels = create_dataset()
    print(calc_shannon_ent(my_dataset))
    my_dataset[0][-1] = 'maybe'
    print(calc_shannon_ent(my_dataset))
def test2():
    my_dat, labels = create_dataset()
    print(split_dataset(my_dat, 0, 1))
    print(split_dataset(my_dat, 0, 0))
def test3():
    my_dat, labels = create_dataset()
    print(choose_best_feature_to_split(my_dat))
    print(my_dat)
def test4():
    my_dat, labels = create_dataset()
    my_tree = create_tree(my_dat, labels)
    print(my_tree)
def test5():
    my_dat, labels = create_dataset()
    my_tree = treePlotter.retrieve_tree(0)
    print(classify(my_tree, labels, [1, 0]))
    print(classify(my_tree, labels, [1, 1]))
def test6():
    my_tree = treePlotter.retrieve_tree(0)
    store_tree(my_tree, 'classifierStorage.txt')
    tree = grab_tree('classifierStorage.txt')
    print(tree)
def test7():
    with open('./lenses.txt') as fr:
        lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    lenses_labels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lenses_tree = create_tree(lenses, lenses_labels)
    print(lenses_tree)
    treePlotter.create_plot(lenses_tree)
if __name__ == '__main__':
    test7()
# treePlotter.py
import matplotlib.pyplot as plt

# styling for decision nodes, leaf nodes, and the connecting arrows
decision_node = dict(boxstyle='sawtooth', fc='0.8')
leaf_node = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle='<-')
def plot_node(node_txt, center_pt, parent_pt, node_type):
    """Draw a single annotated node with an arrow pointing to it from its parent."""
    create_plot.ax1.annotate(node_txt, xy=parent_pt, xycoords='axes fraction',
                             xytext=center_pt, textcoords='axes fraction',
                             va="center", ha="center", bbox=node_type, arrowprops=arrow_args)
def create_plot():
    """Early demo that just draws two example nodes; superseded by create_plot(in_tree) below."""
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    create_plot.ax1 = plt.subplot(111, frameon=False)
    plot_node('decision node', (0.5, 0.1), (0.1, 0.5), decision_node)
    plot_node('leaf node', (0.8, 0.1), (0.3, 0.8), leaf_node)
    plt.show()
def get_num_leafs(my_tree):
    """Count the leaf nodes of a nested-dict tree (used to size the plot horizontally)."""
    num_leafs = 0
    first_str = list(my_tree.keys())[0]
    second_dict = my_tree[first_str]
    for key in second_dict.keys():
        if isinstance(second_dict[key], dict):
            num_leafs += get_num_leafs(second_dict[key])
        else:
            num_leafs += 1
    return num_leafs
def get_tree_depth(my_tree):
    """Return the number of decision levels in the tree (used to size the plot vertically)."""
    max_depth = 0
    first_str = list(my_tree.keys())[0]
    second_dict = my_tree[first_str]
    for key in second_dict.keys():
        if isinstance(second_dict[key], dict):
            this_depth = 1 + get_tree_depth(second_dict[key])
        else:
            this_depth = 1
        if this_depth > max_depth:
            max_depth = this_depth
    return max_depth
def retrieve_tree(i):
    """
    Return one of two pre-built trees, so the plotting code can be tested
    without rebuilding a tree from data every time.
    :param i: index of the stored tree (0 or 1)
    :return: nested-dict tree
    """
    list_of_trees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                     {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return list_of_trees[i]
def test():
    my_tree = retrieve_tree(0)
    print(get_num_leafs(my_tree))
    print(get_tree_depth(my_tree))
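# For retrieve_tree(0) the two helpers should report 3 leaves and a depth of 2,
# since that tree splits on 'no surfacing' and then on 'flippers'.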
def plot_mid_text(cntr_pt, parent_pt, txt_string):
    """Write the branch value (edge label) halfway between a node and its parent."""
    x_mid = (parent_pt[0] - cntr_pt[0]) / 2.0 + cntr_pt[0]
    y_mid = (parent_pt[1] - cntr_pt[1]) / 2.0 + cntr_pt[1]
    create_plot.ax1.text(x_mid, y_mid, txt_string)
def plot_tree(my_tree, parent_pt, node_txt):
    """Recursively plot the tree, tracking the current x/y offsets as function attributes."""
    num_leafs = get_num_leafs(my_tree)
    depth = get_tree_depth(my_tree)
    first_str = list(my_tree.keys())[0]
    # centre this decision node above the leaves it spans
    cntr_pt = (plot_tree.x_off + (1.0 + float(num_leafs)) / 2.0 / plot_tree.total_w, plot_tree.y_off)
    plot_mid_text(cntr_pt, parent_pt, node_txt)
    plot_node(first_str, cntr_pt, parent_pt, decision_node)
    second_dict = my_tree[first_str]
    plot_tree.y_off = plot_tree.y_off - 1.0 / plot_tree.total_d
    for key in second_dict.keys():
        if isinstance(second_dict[key], dict):
            plot_tree(second_dict[key], cntr_pt, str(key))
        else:
            plot_tree.x_off = plot_tree.x_off + 1.0 / plot_tree.total_w
            plot_node(second_dict[key], (plot_tree.x_off, plot_tree.y_off), cntr_pt, leaf_node)
            plot_mid_text((plot_tree.x_off, plot_tree.y_off), cntr_pt, str(key))
    plot_tree.y_off = plot_tree.y_off + 1.0 / plot_tree.total_d
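# Working retrieve_tree(0) through this layout by hand (purely illustrative):
# total_w = 3 leaves and total_d = 2 levels, so x_off starts at -1/6 and each
# leaf advances it by 1/3.  The root lands at (1/2, 1.0), the 'no' leaf at
# (1/6, 0.5), the 'flippers' node at (2/3, 0.5), and its two leaves at
# (1/2, 0.0) and (5/6, 0.0) -- the leaves end up evenly spaced across the axis.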
def create_plot(in_tree):
    """Set up the figure, initialise the layout bookkeeping, and plot the whole tree."""
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    create_plot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plot_tree.total_w = float(get_num_leafs(in_tree))
    plot_tree.total_d = float(get_tree_depth(in_tree))
    plot_tree.x_off = -0.5 / plot_tree.total_w
    plot_tree.y_off = 1.0
    plot_tree(in_tree, (0.5, 1.0), '')
    plt.show()
def test2():
    my_tree = retrieve_tree(0)
    create_plot(my_tree)
    my_tree['no surfacing'][3] = 'maybe'
    create_plot(my_tree)
if __name__ == '__main__':
    test2()