问题:编程实现基于信息熵(信息增益)进行划分选择的决策树算法,并为表4.3(西瓜数据集3.0)中数据生成一棵决策树。
代码生成结果与书本结果基本一致,唯有(触感=硬滑)和(触感=软粘)时我的答案分别是(坏瓜)和(好瓜),而书本答案恰好相反。这里应为书本错误,因为根据数据人肉眼判定,稍糊硬滑的数据都为否,稍糊软粘数据都为是。如果有和我结论不一致的同学,欢迎指正!感谢ICS_的指出,在周老师的主页勘误表里已经修改此处错误,第一版第五次印刷及以后的书应该都没有此问题。
数据为中文,因此需要处理中文编码问题,代码中为此用到了不少技巧(trick),部分已在代码注释中指出。
python代码如下:
#coding: utf-8
import math
class Node:
    """One node of the decision tree.

    build_tree stores nodes in a flat list and refers to children through
    the values kept in ``sons``.
    """

    def __init__(self, divided_by=None, condition=None, sons=None, label=None):
        # Attribute this node splits on; None if the node is a leaf.
        self.divided_by = divided_by
        # Branch condition: a float threshold when the split attribute is
        # continuous, otherwise a list of attribute values aligned with sons.
        self.condition = condition
        # References to this node's children. A fresh list is created when
        # the caller passes none, so instances never share one default list.
        self.sons = [] if sons is None else sons
        # Classification result of a leaf; None for an internal node.
        self.label = label
def cal_ent(count, total_cnt):
    """Return the base-2 entropy of a class-count distribution.

    count     -- dict mapping each class to its number of samples
    total_cnt -- total number of samples; 0 gives an entropy of 0

    Classes whose count is zero are skipped (0 * log 0 is taken as 0),
    because math.log(0, 2) would raise ValueError.
    """
    if total_cnt == 0:
        return 0
    probs = [float(n) / total_cnt for n in count.values() if n]
    return -sum(math.log(p, 2) * p for p in probs)
def cal_info_gain(entd, divided_res, total_cnt, Y):
    """Return the information gain obtained by a candidate partition.

    entd        -- entropy of the sample set before the split
    divided_res -- list of index lists, one per branch of the split
    total_cnt   -- number of samples before the split; 0 gives a gain of 0
    Y           -- class labels, indexed by sample index
    """
    if total_cnt == 0:
        return 0
    # Entropy of each branch, weighted by the share of samples it receives.
    weighted_ent = 0
    for subset in divided_res:
        size = len(subset)
        weighted_ent += float(size) / total_cnt * cal_ent(cal_count(subset, Y), size)
    return entd - weighted_ent
def cal_count(data_rem, Y):
    """Count how many of the samples in ``data_rem`` fall into each class.

    data_rem -- iterable of sample indices
    Y        -- class labels, indexed by sample index

    Returns a dict mapping class to number of occurrences.
    """
    tally = {}
    for idx in data_rem:
        cls = Y[idx]
        if cls in tally:
            tally[cls] += 1
        else:
            tally[cls] = 1
    return tally
def build_tree(attr_rem, data_rem, X, Y, attr_dict, trees, fa_label):
    """Recursively build a subtree and return the index of its root in ``trees``.

    attr_rem  -- attributes still available for splitting at this node
    data_rem  -- indices (into X and Y) of the samples reaching this node
    X         -- samples, each a dict mapping attribute name to value
    Y         -- class labels, indexed like X
    attr_dict -- dict mapping each attribute to the list of its values
    trees     -- flat list of Node objects; children are appended before
                 their parent, and ``sons`` holds indices into this list
    fa_label  -- majority class of the parent, used when no sample reaches
                 this node
    """
    total_cnt = len(data_rem)
    if total_cnt == 0:
        # Empty branch: inherit the parent's majority class. This must be
        # checked first, because max() over an empty count dict raises.
        trees.append(Node(None, None, [], fa_label))
        return len(trees) - 1

    count = cal_count(data_rem, Y)  # number of samples of every class
    max_kind = max(count, key=lambda x: count[x])  # majority class here
    if not attr_rem or len(count) == 1:
        # No attribute left (None or empty) or all samples share one class.
        trees.append(Node(None, None, [], max_kind))
        return len(trees) - 1

    ent = cal_ent(count, total_cnt)
    best_divided_res = []
    best_condition = []
    best_divided_by = None
    best_is_float = False
    best_info_gain = -100
    for attr in attr_rem:
        if isinstance(X[data_rem[0]][attr], float):
            # Continuous attribute: try the midpoint of every pair of
            # adjacent values as a binary threshold.
            points = sorted(attr_dict[attr])
            divided_points = [(points[i] + points[i + 1]) / 2 for i in range(len(points) - 1)]
            for divided_point in divided_points:
                divided_res = [
                    [i for i in data_rem if X[i][attr] <= divided_point],
                    [i for i in data_rem if X[i][attr] > divided_point],
                ]
                info_gain = cal_info_gain(ent, divided_res, total_cnt, Y)
                if info_gain > best_info_gain:
                    best_info_gain = info_gain
                    best_divided_res = divided_res
                    best_divided_by = attr
                    best_condition = divided_point
                    best_is_float = True
        else:
            # Discrete attribute: one branch per known value.
            divided_res = [[j for j in data_rem if X[j][attr] == name] for name in attr_dict[attr]]
            info_gain = cal_info_gain(ent, divided_res, total_cnt, Y)
            if info_gain > best_info_gain:
                best_info_gain = info_gain
                best_divided_res = divided_res
                best_divided_by = attr
                best_condition = attr_dict[attr]
                best_is_float = False

    # No usable split was found, or the best continuous split sends every
    # sample to one child. Continuous attributes are never removed, so that
    # child would repeat this exact call forever; stop with a leaf instead.
    no_progress = best_is_float and any(len(part) == total_cnt for part in best_divided_res)
    if best_divided_by is None or no_progress:
        trees.append(Node(None, None, [], max_kind))
        return len(trees) - 1

    if not best_is_float:
        # A discrete attribute is used at most once per path. Build a new
        # list instead of removing in place, so the caller's list (shared
        # with sibling branches) is left untouched.
        attr_rem = [a for a in attr_rem if a != best_divided_by]
    sons = [build_tree(attr_rem, part, X, Y, attr_dict, trees, max_kind) for part in best_divided_res]
    trees.append(Node(best_divided_by, best_condition, sons, None))
    return len(trees) - 1
def fit(X, Y):
    """Build a decision tree from a table with a header row and print it.

    X -- list of rows; X[0] holds the attribute names and the remaining
         rows hold the attribute values
    Y -- class labels indexed by the same row numbers as X (row 0 is the
         header and is never used as a sample)
    """
    header = X[0]
    # All values observed for every attribute, taken from the data rows.
    attr_dict = {}
    for col, attr in enumerate(header):
        attr_dict[attr] = list(set([row[col] for row in X[1:]]))
    sample_ids = range(1, len(X))  # row 0 is the header, so samples start at 1
    # One dict per row (attribute name -> value) for lookup by name.
    records = [dict(zip(header, row)) for row in X]
    trees = []  # flat storage for the nodes created by build_tree
    root = build_tree(header, sample_ids, records, Y, attr_dict, trees, None)
    display(trees, root)
def display(trees, root):
    """Print every node reachable from ``root`` in breadth-first order.

    trees -- list of nodes exposing divided_by, condition, sons and label
    root  -- index in ``trees`` of the node to start from

    Each node is printed as its id followed by either its label (leaf) or
    its split attribute, branch condition and child ids, then a blank line.
    The print calls take a single argument so the output is the same under
    Python 2 and Python 3.
    """
    que = [root]
    head = 0  # position of the next node to visit; avoids O(n) pops from the front
    while head < len(que):
        pos = que[head]
        head += 1
        node = trees[pos]
        print('id:')
        print(pos)
        if node.divided_by is None:
            print('label')
            print(node.label)
        else:
            print('divided_by')
            print(node.divided_by)
            print('condition:')
            if isinstance(node.condition, list):
                # Discrete split: branch values on one line, space separated.
                print(' '.join('%s' % (value,) for value in node.condition))
            else:
                print(node.condition)
            print('sons:')
            print(' '.join('%s' % (son,) for son in node.sons))
            que.extend(node.sons)
        print('')
input_path = "西瓜数据集3.csv"
# The with-block closes the file as soon as it has been read.
with open(input_path.decode('utf-8')) as csv_file:
    # Strip both '\r' and '\n' so CRLF files do not leave '\r' on the last
    # column (the class label), and skip blank lines, which would otherwise
    # become empty rows.
    filedata = [line.rstrip('\r\n').split(',') for line in csv_file if line.strip()]
# Fields containing a '.' are treated as decimals and converted to float.
filedata = [[float(i) if '.' in i.decode('utf-8') else i for i in row] for row in filedata]
# First and last columns are dropped from the attributes; the first is
# presumably a row id -- verify against the CSV.
X = [row[1:-1] for row in filedata]  # attributes
Y = [row[-1] for row in filedata]  # class label
fit(X, Y)