C4.5 Code Implementation

A simple implementation of the C4.5 decision tree.

I have recently been teaching myself machine learning algorithms from books, and the first algorithm I studied was the C4.5 decision tree. I took the UCI soybean dataset, read some blog posts, referred to other people's implementations, and then tried writing my own version. Feedback and corrections are welcome.

Dataset URL

http://archive.ics.uci.edu/ml/machine-learning-databases/soybean/

The code is as follows

# -*- coding: utf-8 -*-
"""
Created on Sun Dec 22 11:06:47 2019
"""

# C4.5
# dataset: soybean
# information gain: gain = entropy - conditional entropy
# gain ratio: gain_ratio = gain / split_info(A)
# Gini index
# pruning

# Compute the empirical entropy H(D) of a set of class labels
import math
# Dataset format: the last column is the class label
# Input: data is a list of class labels
def calc_ent(data):
    result=0.0
    label_set=set(data)
    for label in label_set:
        p=data.count(label)/len(data)
        result-=p*math.log2(p)
    return result
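
# Quick sanity checks on calc_ent (toy labels, made up for illustration):
# a 50/50 split carries exactly 1 bit of entropy, a pure set carries 0.
assert abs(calc_ent(['yes','no','yes','no'])-1.0)<1e-9
assert calc_ent(['yes','yes','yes'])==0.0
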
# Conditional entropy
# Compute the conditional entropy of the class labels given one attribute
def calc_conditional_ent(dataset,attribute):
    result=0.0
    values=set(instance[attribute] for instance in dataset)
    for value in values:
        data=[instance[-1] for instance in dataset if instance[attribute]==value]
        result +=len(data)/len(dataset)*calc_ent(data)
    return result
# Information gain
def gain(dataset,attribute):
    classes=[instance[-1] for instance in dataset]
    return calc_ent(classes)-calc_conditional_ent(dataset,attribute)
# Split information
# The split information of an attribute is used to normalize its information gain;
# it is computed over the partition induced by a discrete attribute
def split_info(dataset,attribute):
    result=0.0
    values=set(instance[attribute] for instance in dataset)
    for value in values:
        data=[instance[-1] for instance in dataset if instance[attribute]==value]
        result-=len(data)/len(dataset)*math.log2(len(data)/len(dataset))
    return result

# Gain ratio
# (note the (attribute, dataset) argument order, which differs from the functions above)
def gain_ratio(attribute,dataset):
    temp=split_info(dataset,attribute)
    if temp==0:# avoid division by zero when the attribute takes a single value
        return gain(dataset,attribute)
    return gain(dataset,attribute)/temp
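
# Illustration on a tiny made-up dataset: an attribute with four distinct values
# has a large split_info (2 bits here), so its gain ratio (0.5) is smaller than
# its raw gain (1.0); this is exactly the penalty C4.5 introduces.
_demo=[[0,'yes'],[1,'yes'],[2,'no'],[3,'no']]
assert gain(_demo,0)>gain_ratio(0,_demo)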

# Gini index: 1 minus the sum of squared class probabilities
def gini(dataset):
    result=0.0
    classes=[instance[-1] for instance in dataset]
    labels=set(classes)
    for label in labels:
        p=classes.count(label)/len(dataset)
        result+=p*p# accumulate sum of p_i^2
    return 1-result
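
# Sanity check (made-up rows, class in the last column):
# a 50/50 binary split has Gini index 1 - (0.25 + 0.25) = 0.5.
assert abs(gini([['a'],['b']])-0.5)<1e-9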

# Split the dataset on a discrete attribute:
# return the subset of instances with attribute == value, with that column removed
def split_dataset(dataset,attribute,value):
    result=[]
    for instance in dataset:
        if instance[attribute]==value:
            result.append(instance[:attribute]+instance[attribute+1:])
    return result
# Split the dataset on a continuous attribute
# attribute: index of the feature to split on, value: the threshold,
# gr_or_le: 'g' for greater than, 'l' for less than or equal
def split_dataset_c(dataset,attribute,value,gr_or_le):
    result=[]
    if gr_or_le=='l':
        for instance in dataset:
            if instance[attribute]<=value:
                result.append(instance[:attribute]+instance[attribute+1:])
    else:
        for instance in dataset:
            if instance[attribute]>value:
                result.append(instance[:attribute]+instance[attribute+1:])
    return result
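
# Example usage on made-up two-row datasets (the split column is removed
# from the returned instances):
assert split_dataset([[1,'yes'],[0,'no']],0,1)==[['yes']]
assert split_dataset_c([[0.2,'yes'],[0.6,'no']],0,0.35,'l')==[['yes']]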
# Choosing the splitting criterion:
# ID3 splits on maximum information gain, C4.5 on maximum gain ratio,
# and CART on the Gini index.
# Loop over all features and return the index of the one maximizing the gain ratio
def choose_best_attribute(dataset):
    num_features=len(dataset[0])-1
    max_ratio=0
    result=-1# -1 signals that no attribute yields a positive gain ratio
    for attribute in range(num_features):
        temp=gain_ratio(attribute,dataset)
        if temp>max_ratio:
            max_ratio=temp
            result=attribute
    return result
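
# Sanity check (made-up rows): the first column separates the classes
# perfectly, so it should be chosen.
assert choose_best_attribute([[1,0,'yes'],[1,1,'yes'],[0,1,'no']])==0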
# Choose the best attribute when continuous values are allowed
# (this version compares raw information gain for both attribute types)
def choose_best_attribute_c(dataset,label_property):
    num_features=len(dataset[0])-1
    best_feat=-1
    best_value=None
    max_ratio=0
    ent_dataset=calc_ent([i[-1] for i in dataset])
    for attribute in range(num_features):
        best_value_i=None# reset per attribute so a discrete winner does not keep a stale threshold
        if label_property[attribute]==0:# discrete attribute
            info_gain=gain(dataset,attribute)
        else:# continuous attribute:
            # sort the distinct values, take the midpoint of each adjacent pair
            # as a candidate split point, and keep the split with the largest gain
            new_gain=0
            attribute_vals=[instance[attribute] for instance in dataset]
            uniq_vals=set(attribute_vals)
            sorted_uniq_vals=sorted(uniq_vals)
            for j in range(len(sorted_uniq_vals)-1):
                part_val=(sorted_uniq_vals[j]+sorted_uniq_vals[j+1])/2 # candidate split point
                dataset_left=split_dataset_c(dataset,attribute,part_val,'l')
                dataset_right=split_dataset_c(dataset,attribute,part_val,'g')
                p_left=len(dataset_left)/len(dataset)
                p_right=len(dataset_right)/len(dataset)
                entropy=p_left*calc_ent([i[-1] for i in dataset_left])+p_right*calc_ent([i[-1] for i in dataset_right])
                temp=ent_dataset-entropy
                if temp>new_gain:
                    new_gain=temp
                    best_value_i=part_val
            info_gain=new_gain
        if info_gain>max_ratio:
            max_ratio=info_gain
            best_feat=attribute
            best_value=best_value_i
    return best_feat,best_value
# Majority vote over a list of class labels
def majority(class_list):
    class_count={}
    for value in class_list:
        class_count[value]=class_count.get(value,0)+1
    return max(class_count.items(),key=lambda item:item[1])[0]
# Build the decision tree (discrete attributes only).
# Each split removes one attribute column; recursion stops when the remaining
# instances all belong to one class or no attributes are left
def create_tree(dataset,labels):
    class_list=[instance[-1] for instance in dataset]
    if len(set(class_list))==1:# only one class left
        return class_list[0]
    # if all features are used up but the classes are still mixed,
    # fall back to the most frequent class
    if len(dataset[0])==1:
        return majority(class_list)
    best_attribute=choose_best_attribute(dataset)
    if best_attribute==-1:# no attribute yields a positive gain ratio
        return majority(class_list)
    best_feature = labels[best_attribute]
    my_tree={best_feature:{}}
    # copy labels rather than deleting in place, so sibling branches are unaffected
    new_labels=labels[:best_attribute]+labels[best_attribute+1:]
    feature_values=[instance[best_attribute] for instance in dataset]
    unique_values=set(feature_values)
    for v in unique_values:
        subdataset=split_dataset(dataset,best_attribute,v)
        my_tree[best_feature][v]=create_tree(subdataset,new_labels)
    return my_tree
# The code above only handles discrete-valued attributes.
# An attribute that takes continuous values needs special treatment:
# it can be discretized, e.g. with the bi-partition method.
# Suppose continuous attribute a takes n distinct values in sample set D.
# Sort these values in ascending order; the midpoint of each interval
# (a_i, a_{i+1}) is a candidate split point, giving a set of n-1 candidates.
# Each split point t partitions the samples into D_t+ (>t) and D_t- (<=t);
# compute the information gain (or gain ratio) of each split point and keep the best.
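
# A worked example of the candidate split points (values made up for
# illustration): the distinct sorted values [0.1, 0.2, 0.5, 0.6, 0.8]
# yield the four midpoints evaluated by choose_best_attribute_c above.
_vals=[0.1,0.2,0.5,0.6,0.8]
_mids=[round((a+b)/2,2) for a,b in zip(_vals,_vals[1:])]
assert _mids==[0.15,0.35,0.55,0.7]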
def create_tree_c(dataset,labels,label_property):
    class_list=[instance[-1] for instance in dataset]
    if len(set(class_list))==1:# only one class left
        return class_list[0]
    # if all features are used up but the classes are still mixed,
    # fall back to the most frequent class
    if len(dataset[0])==1:
        return majority(class_list)
    best_attribute,best_val=choose_best_attribute_c(dataset,label_property)
    if best_attribute==-1:# no attribute is useful for splitting
        return majority(class_list)
    best_feature = labels[best_attribute]
    # keep labels and label_property aligned with the remaining columns
    new_labels=labels[:best_attribute]+labels[best_attribute+1:]
    new_label_property=label_property[:best_attribute]+label_property[best_attribute+1:]
    if label_property[best_attribute]==0:# discrete attribute
        my_tree={best_feature:{}}
        feature_values=[instance[best_attribute] for instance in dataset]
        unique_values=set(feature_values)
        for v in unique_values:
            subdataset=split_dataset(dataset,best_attribute,v)
            my_tree[best_feature][v]=create_tree_c(subdataset,new_labels,new_label_property)
    else:# continuous attribute: two branches, <= threshold ('not') and > threshold ('yes')
        new_feature=best_feature+'>'+str(best_val)
        my_tree={new_feature:{}}
        sub_dataset_left=split_dataset_c(dataset,best_attribute,best_val,'l')
        sub_dataset_right=split_dataset_c(dataset,best_attribute,best_val,'g')
        my_tree[new_feature]['not']=create_tree_c(sub_dataset_left,new_labels,new_label_property)
        my_tree[new_feature]['yes']=create_tree_c(sub_dataset_right,new_labels,new_label_property)
    return my_tree

# A toy dataset: one discrete attribute, one continuous attribute, then the class
dataset=[[1,0.2,'yes'],[1,0.1,'yes'],[1,0.5,'no'],[0,0.6,'no'],[0,0.8,'no']]
labels=['no surfacing','flippers','fish']
label_property=[0,1,0]# 0 = discrete, 1 = continuous
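
# Building a tree from the toy data exercises the continuous-attribute path;
# with these values it should split 'flippers' at the midpoint 0.35, giving
# {'flippers>0.35': {'not': 'yes', 'yes': 'no'}}
toy_tree=create_tree_c(dataset,labels,label_property)
print(toy_tree)
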
# Load the soybean dataset
soybean_data = []
soybean_labels =["date","plant-stand","precip","temp","hail","crop-hist","area-damaged","severity",
                 "seed-tmt","germination","plant-growth",
                 "leaves","leafspots-halo","leafspots-marg","leafspot-size","leaf-shread","leaf-malf",
                 "leaf-mild","stem","lodging","stem-cankers",
                 "canker-lesion","fruiting-bodies","external decay","mycelium","int-discolor","sclerotia","fruit-pods",
                 "fruit spots","seed","mold-growth",
                 "seed-discolor","seed-size","shriveling","roots"]
label_property =[0 for i in range(35)]# all 35 soybean attributes are discrete
with open(r"./soybean-large.data") as f:
    for line in f.readlines():
        temp=line.strip().split(",")
        # the class name is the first field in the raw file;
        # missing values ('?') are coded as 0
        instance=[int(x) if x!='?' else 0 for x in temp[1:]]
        instance.append(temp[0])# the class label goes in the last column
        soybean_data.append(instance)
soybean_tree_one=create_tree_c(soybean_data,soybean_labels,label_property)
# Without pruning the tree has 95 leaf nodes and 147 nodes in total

# Evaluate on the held-out test file
def test():
    soybean_test=[]
    err=0
    with open(r"./soybean-large.test") as f:
        for line in f.readlines():
            temp=line.strip().split(",")
            instance=[int(x) if x!='?' else 0 for x in temp[1:]]
            instance.append(temp[0])# the class label goes in the last column
            soybean_test.append(instance)
    for instance in soybean_test:
        predict_result=tree_result(instance,soybean_labels,label_property,soybean_tree_one)
        if predict_result != instance[-1]:# branches unseen in training return None and count as errors
            err+=1
    print("error rate is ",err/len(soybean_test))
# Using the test file as a validation set, accuracy is about 80%
def tree_result(instance,labels,label_property,tree):
    # Classify one instance by walking down the tree.
    # This walker only handles discrete splits; a continuous split is stored
    # under a key like 'attr>value' and would need extra parsing.
    label_str=list(tree.keys())[0]
    index = labels.index(label_str)
    value = instance[index]
    if value in tree[label_str].keys(): 
        next_node = tree[label_str][value]
    else: # a branch the decision tree never covered
        return None
    if type(next_node).__name__=='dict':
        return tree_result(instance,labels,label_property,next_node)
    else:
        return next_node
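
# Typical driver code, assuming soybean_tree_one was built above:
# test()   # prints the error rate on soybean-large.test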

# Draw the tree nodes with matplotlib text annotations
import matplotlib.pyplot as plt
decision_node =  dict(boxstyle="sawtooth",fc="0.8")
leaf_node = dict(boxstyle="round4",fc="0.8") # box and arrow styles
arrow_args=dict(arrowstyle="<-")


def plot_node(nodetxt,center_pt,parent_pt,nodetype):
    createPlot.ax1.annotate(nodetxt,xy=parent_pt,xycoords='axes fraction',
                            xytext=center_pt,textcoords='axes fraction',
                            va="center",ha="center",bbox=nodetype,arrowprops=arrow_args)
    
def createPlot(intree):
    fig=plt.figure(1,facecolor="white")
    fig.clf()
    axprops=dict(xticks=[],yticks=[])
    createPlot.ax1=plt.subplot(111,frameon=False,**axprops) # frameon=False removes the outer frame
    plot_tree.totalW=float(get_num_leafs(intree))
    plot_tree.totalD=float(get_tree_depth(intree))
    plot_tree.xOff=-0.5/plot_tree.totalW
    plot_tree.yOff=1.0
    plot_tree(intree,(0.5,1.0),'')
    plt.show()
# To lay out the plot we need the tree's depth and its number of leaf nodes
def get_num_leafs(mytree):
    num_leafs=0
    first_str=list(mytree.keys())[0]
    second_dict=mytree[first_str]
    for key in second_dict.keys():
        if type(second_dict[key]).__name__=='dict':
            num_leafs =num_leafs+ get_num_leafs(second_dict[key])
        else:
            num_leafs+=1
    return num_leafs# count leaves only; internal nodes do not add plot width
# Tree depth
def get_tree_depth(mytree):
    max_depth=0
    first_str=list(mytree.keys())[0]
    second_dict=mytree[first_str]
    for key in second_dict.keys():
        if type(second_dict[key]).__name__=='dict':
            this_depth=1+get_tree_depth(second_dict[key])
        else:
            this_depth=1
        if this_depth>max_depth:
            max_depth=this_depth
    return max_depth
# Draw the decision tree
def plot_mid_text(cntr_pt,parent_pt,text_string):
    # place the edge label midway between parent and child
    x_mid=(parent_pt[0]-cntr_pt[0])/2.0+cntr_pt[0]
    y_mid=(parent_pt[1]-cntr_pt[1])/2.0+cntr_pt[1]
    createPlot.ax1.text(x_mid,y_mid,text_string)
    
def plot_tree(mytree,parent_pt,node_txt):
    num_leafs=get_num_leafs(mytree)# width of this subtree, in leaf units
    first_str=list(mytree.keys())[0]
    first_str=list(mytree.keys())[0]
    cntr_pt=(plot_tree.xOff+(1.0+float(num_leafs))/2.0/plot_tree.totalW,plot_tree.yOff)
    plot_mid_text(cntr_pt,parent_pt,node_txt)
    plot_node(first_str,cntr_pt,parent_pt,decision_node)
    secondDict=mytree[first_str]
    plot_tree.yOff=plot_tree.yOff-1.0/plot_tree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plot_tree(secondDict[key],cntr_pt,str(key))
        else:
            plot_tree.xOff=plot_tree.xOff+1.0/plot_tree.totalW
            plot_node(secondDict[key],(plot_tree.xOff,plot_tree.yOff),cntr_pt,leaf_node)
            plot_mid_text((plot_tree.xOff,plot_tree.yOff),cntr_pt,str(key))
    plot_tree.yOff=plot_tree.yOff+1.0/plot_tree.totalD
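
# Example: render the (unpruned) soybean tree; with 95 leaves the plot is quite dense.
# createPlot(soybean_tree_one)
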
# Serialize the tree with the pickle module.
# A pickled object can be saved to disk and read back when needed;
# almost any Python object can be serialized this way.
def store_tree(intree,filename):
    import pickle
    with open(filename,'wb') as fw:# pickle requires binary mode
        pickle.dump(intree,fw)
def grab_tree(filename):
    import pickle
    with open(filename,'rb') as fr:
        return pickle.load(fr)
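
# Example round trip (the filename is arbitrary):
# store_tree(soybean_tree_one,'soybean_tree.pkl')
# soybean_tree_one=grab_tree('soybean_tree.pkl')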
