在进行案例分析前,先对决策树算法的分类函数进行测试。考虑到构造决策树非常耗时,为了节省计算时间,最好能够在每次执行分类时调用已经构造好的决策树。这就需要利用python模块pickle序列化对象将决策树分类算法保存在磁盘中,并在需要的时候读取出来。
1、测试决策树分类算法性能
######################################
#功能:决策树的分类函数
#输入变量:input_tree, feat_labels, test_vec
# 决策树,分类标签,测试数据
#输出变量:class_label 类标签
######################################
def classify(input_tree, feat_labels, test_vec):
first_str = input_tree.keys()[0]
second_dict = input_tree[first_str]
class_label = -1
# index方法用于查找当前列表中第一个匹配first_str变量的索引
feat_index = feat_labels.index(first_str)
for key in second_dict.keys():
if test_vec[feat_index] == key:
if type(second_dict[key]).__name__ == 'dict':
class_label = classify(second_dict[key], feat_labels, test_vec)
else:
class_label = second_dict[key]
return class_label
2、对决策树算法进行存储
######################################
#功能:将决策树存储到磁盘中
#输入变量:input_tree, filename 决策树,存储的文件名
######################################
def store_tree(input_tree, filename):
import pickle
fw = open(filename, 'w')
pickle.dump(input_tree, fw) # 序列化,将数据写入到文件中
fw.close()
3、对决策树算法进行读取
######################################
#功能:从磁盘中读取决策树信息
#输入变量:filename 存储的文件名
######################################
def grab_tree(filename):
import pickle
fr = open(filename, 'r')
return pickle.load(fr) # 反序列化
4、代码测试
def main():
my_data, my_labels = create_data_set()
print 'my_data=', my_data
print 'my_labels=', my_labels
class_label = classify(my_tree, my_labels, [1, 1])
print 'class_label=', class_label
store_tree(my_tree, 'classifierStorage.txt')
tree = grab_tree('classifierStorage.txt')
print 'tree=', tree
if __name__ == '__main__':
main()
案例分析:使用决策树预测隐形眼镜类型
隐形眼镜类型包括硬材质、软材质以及不适合佩戴隐形眼镜。而眼科医生需要从age、prescript、astigmatic和tearRate这四个方面对患者进行询问,以此来判断患者佩戴的镜片类型。利用决策树算法,我们甚至也可以帮助人们判断需要佩戴的镜片类型。
在构造决策树前,我们需要获取隐形眼镜数据集,从lenses.txt文件读取。还需要获取特征属性(或者说决策树的决策结点),从代码输入。将数据集和特征属性代入决策树分类算法,就能构造出隐形眼镜决策树,沿着不同分支,我们可以得到不同患者需要的眼镜类型。
代码如下:
fr = open('lenses.txt', 'r')
lenses = [line.strip().split('\t') for line in fr.readlines()]
lenses_labels = ['age', 'prescript', 'astigmatic', 'tearRate']
lenses_tree = create_tree(lenses, lenses_labels)
print 'lenses_tree=', lenses_tree