python实现西瓜书《机器学习》习题4.4基尼指数决策树,预剪枝及后剪枝

大神代码:https://blog.csdn.net/Snoopy_Yuan/article/details/69223240
昨天画不出树有点烦躁,随便找了百度了一点点,还是画不出来。
今天这道题,其实就是把信息增益换成基尼指数,本质上的构造树逻辑是一致的。
不过源代码有个小错误,在上面链接里已经评论了,好奇宝宝可以自己去看

不过,奇葩的是前后剪枝算出来的准确率一毛一样,估计程序里还有问题,以后再扣吧。。。

主程序gini_decision_tree.py

# https://blog.csdn.net/Snoopy_Yuan/article/details/69223240

import pandas as pd

# data_file_encode="gb18030"
# NOTE(review): gb18030 is a variable-width (1/2/4-byte) encoding covering CJK and
# minority scripts; passing it to open() reportedly fails to decode this CSV,
# so the default encoding is used instead.
# Open the CSV read-only; the context manager guarantees the handle is closed.
with open("/Users/huatong/PycharmProjects/Data/watermelon_33.csv", mode="r") as fh:
    df = pd.read_csv(fh)

import decision_tree

# Train/test split: iloc picks the listed rows for training, drop keeps the rest
# for testing.  These indices match the training sample on p.80 of the book.
index_train = [0, 1, 2, 5, 6, 9, 13, 14, 15, 16]

df_train = df.iloc[index_train]
df_test = df.drop(index_train)


# Build an unpruned tree on the training split and score it on the hold-out split.
root = decision_tree.TreeGenerate(df_train)
# decision_tree.DrawPNG(root, "decision_tree_full.png")  # plotting disabled for now
full_acc = decision_tree.PredictAccuracy(root, df_test)
print("accuracy of full tree: %.3f" % full_acc)

# Pre-pruning: the tree is grown and pruned in one pass using the test split.
root = decision_tree.PrePurn(df_train, df_test)
# decision_tree.DrawPNG(root, "decision_tree_pre.png")
pre_acc = decision_tree.PredictAccuracy(root, df_test)
print("accuracy of pre-purning tree: %.3f" % pre_acc)

# Post-pruning: grow the full tree first, then prune bottom-up.
root = decision_tree.TreeGenerate(df_train)
decision_tree.PostPurn(root, df_test)
# decision_tree.DrawPNG(root, "decision_tree_post.png")
post_acc = decision_tree.PredictAccuracy(root, df_test)
print("accuracy of post-purning tree: %.3f" % post_acc)

# 5-fold cross-validation: train a post-pruned tree on each train split and
# record its accuracy on the matching test fold.
accuracy_scores = []
n = len(df.index)
k = 5
m = int(n / k)  # fold size; loop-invariant, hoisted out of the loop
# NOTE(review): when n is not divisible by k, the last n % k rows are never
# placed in any test fold — confirm this matches the intended protocol.
for fold in range(k):
    # Positional row numbers belonging to this test fold.
    test = list(range(fold * m, fold * m + m))

    df_train = df.drop(test)
    df_test = df.iloc[test]
    root = decision_tree.TreeGenerate(df_train)  # grow the full tree
    decision_tree.PostPurn(root, df_test)        # then post-prune it

    # Count correct predictions on the test fold.
    # BUG FIX: the original reused `i` here, shadowing the outer fold index.
    pred_true = 0
    for row_id in df_test.index:
        label = decision_tree.Predict(root, df[df.index == row_id])
        if label == df_test[df_test.columns[-1]][row_id]:
            pred_true += 1

    accuracy_scores.append(pred_true / len(df_test.index))

# Print the per-fold accuracies and their mean.
print("accuracy: ", end="")
for score in accuracy_scores:
    print("%.3f  " % score, end="")
print("\naverage accuracy: %.3f" % (sum(accuracy_scores) / k))


decision_tree.py

#被主程序执行treeGenerate时候调用,def用于定义函数
#节点类,包含①当前节点的属性,例如纹理清晰? ②节点所属分类,只对叶子节点有效 ③向下划分的属性取值例如色泽乌黑青绿浅白


class Node(object):
    """A decision-tree node.

    Attributes:
        attr: the attribute this node splits on (e.g. "texture");
              None marks a leaf.
        label: the class assigned to this node (meaningful for leaves,
               and used as the fallback prediction for internal nodes).
        attr_down: dict mapping each attribute value (or "<=v" / ">v" for a
                   continuous split) to the corresponding child Node.
    """

    def __init__(self, attr_init=None, label_init=None, attr_down_init=None):
        self.attr = attr_init
        self.label = label_init
        # BUG FIX: the original used a mutable default (attr_down_init={}),
        # so every Node created without the argument shared ONE dict and
        # children leaked between unrelated nodes.  A None sentinel gives
        # each instance its own fresh dict.
        self.attr_down = {} if attr_down_init is None else attr_down_init

#主函数,输入参数为数据集,输出参数为决策树根节点Node
def TreeGenerate(df):
    """Recursively build a Gini-index decision tree.

    @param df: dataframe; the last column is the class label, the columns
               between the first and last are candidate attributes.
    @return: Node, root of the (sub)tree.
    """
    new_node = Node(None, None, {})
    label_arr = df[df.columns[-1]]  # class-label column (last column)
    label_count = NodeLabel(label_arr)
    if label_count:  # non-empty sample set
        # Majority class; doubles as the prediction if this becomes a leaf.
        new_node.label = max(label_count, key=label_count.get)
        # Leaf cases: all samples share one class, or no samples remain.
        if len(label_count) == 1 or len(label_arr) == 0:
            return new_node
        # Choose the best split attribute by Gini index.
        # div_value == 0 flags a discrete attribute; otherwise it is the
        # numeric threshold for a continuous attribute.
        new_node.attr, div_value = OptAttr_Gini(df)
        if div_value == 0:  # discrete attribute: one branch per value
            value_count = ValueCount(df[new_node.attr])
            for value in value_count:
                df_v = df[df[new_node.attr].isin([value])]
                # BUG FIX: the original assigned the drop() result to `dv_v`
                # but recursed on the undropped `df_v`, so the chosen
                # attribute was never removed from the sub-table.
                df_v = df_v.drop(new_node.attr, axis=1)
                new_node.attr_down[value] = TreeGenerate(df_v)
        else:  # continuous attribute: binary split at the threshold
            value_l = "<=%.3f" % div_value
            value_r = ">%.3f" % div_value
            df_v_l = df[df[new_node.attr] <= div_value]  # left child
            df_v_r = df[df[new_node.attr] > div_value]   # right child
            new_node.attr_down[value_l] = TreeGenerate(df_v_l)
            new_node.attr_down[value_r] = TreeGenerate(df_v_r)
    return new_node


#统计样本包含的类别和每个分类的个数,输入参数是分类标签序列,输出序列中包含的类别和各类别总数
def NodeLabel(label_arr):
    label_count={}
    for label in label_arr:
        if label in label_count: label_count[label]+=1
        else:label_count[label]=1
    return label_count


#寻找最优划分属性,输入参数为数据集,输出参数为属性opt_attr和划分取值div_value,div_value对离散变量取值为0,对连续变量取实际值
def OptAttr_Gini(df):
    gini_index=float('Inf')
    for attr_id in df.columns[1:-1]:
        gini_index_tmp,div_value_tmp=GiniIndex(df,attr_id)
        if gini_index_tmp a0:  # need branching
                for value in value_count:
                    df_v = df_train[df_train[new_node.attr].isin([value])]  # get sub set
                    df_v = df_v.drop(new_node.attr, 1)
                    new_node.attr_down[value] = TreeGenerate(df_v)
            else:
                new_node.attr = None
                new_node.attr_down = {}

        else:  # continuous variable # left and right child
            value_l = "<=%.3f" % div_value
            value_r = ">%.3f" % div_value
            df_v_l = df_train[df_train[new_node.attr] <= div_value]  # get sub set
            df_v_r = df_train[df_train[new_node.attr] > div_value]

            # for child node
            new_node_l = Node(None, None, {})
            new_node_r = Node(None, None, {})
            label_count_l = NodeLabel(df_v_l[df_v_r.columns[-1]])
            label_count_r = NodeLabel(df_v_r[df_v_r.columns[-1]])
            new_node_l.label = max(label_count_l, key=label_count_l.get)
            new_node_r.label = max(label_count_r, key=label_count_r.get)
            new_node.attr_down[value_l] = new_node_l
            new_node.attr_down[value_r] = new_node_r

            # calculating to check whether need further branching
            a1 = PredictAccuracy(new_node, df_test)
            if a1 > a0:  # need branching
                new_node.attr_down[value_l] = TreeGenerate(df_v_l)
                new_node.attr_down[value_r] = TreeGenerate(df_v_r)
            else:
                new_node.attr = None
                new_node.attr_down = {}

    return new_node


#后剪枝
def PostPurn(root, df_test):
    '''
    pre-purning to generating a decision tree

    @param root: Node, root of the tree
    @param df_test: dataframe, the testing set for purning decision
    @return accuracy score through traversal the tree
    '''
    # leaf node
    if root.attr == None:
        return PredictAccuracy(root, df_test)

    # calculating the test accuracy on children node
    a1 = 0
    value_count = ValueCount(df_test[root.attr])
    for value in list(value_count):
        df_test_v = df_test[df_test[root.attr].isin([value])]  # get sub set
        if value in root.attr_down:  # root has the value
            a1_v = PostPurn(root.attr_down[value], df_test_v)
        else:  # root doesn't have value
            a1_v = PredictAccuracy(root, df_test_v)
        if a1_v == -1:  # -1 means no pruning back from this child
            return -1
        else:
            a1 += a1_v * len(df_test_v.index) / len(df_test.index)

    # calculating the test accuracy on this node
    node = Node(None, root.label, {})
    a0 = PredictAccuracy(node, df_test)

    # check if need pruning
    if a0 >= a1:
        root.attr = None
        root.attr_down = {}
        return a0
    else:
        return -1

def DrawPNG(root, out_file):
    '''
    Visualize the decision tree from root and write it as a PNG.
    Requires the third-party pydotplus package (plus a graphviz install).

    @param root: Node, the root node of the tree
    @param out_file: str, name/path of the output file
    '''
    # BUG FIX: the original had a stray top-of-function `import graphviz`
    # (wrong module, immediately shadowed below) placed BEFORE the docstring,
    # which also demoted the docstring to a plain string expression.
    try:
        from pydotplus import graphviz
    except ImportError:
        # BUG FIX: the original printed and fell through, crashing with a
        # NameError on `graphviz.Dot()`; bail out cleanly instead.
        print("module pydotplus.graphviz not found")
        return

    g = graphviz.Dot()  # a fresh dot graph

    TreeToGraph(0, g, root)
    g2 = graphviz.graph_from_dot_data(g.to_string())

    g2.write_png(out_file)


def TreeToGraph(i, g, root):
    '''
    Recursively add the subtree rooted at `root` to graph g.

    @param i: node number assigned to this node
    @param g: pydotplus.graphviz.Dot() object
    @param root: the (sub)tree root Node

    @return i: highest node number used in this subtree
    @return g_node: the graphviz id of this subtree's root
    '''
    try:
        from pydotplus import graphviz  # both pydotplus and graphviz must be installed
    except ImportError:
        # BUG FIX: the original printed and continued, crashing below with a
        # NameError; re-raise so the caller sees the real ImportError.
        print("module pydotplus.graphviz not found")
        raise

    # Leaves show only the class; internal nodes also show the split attribute.
    if root.attr is None:
        g_node_label = "Node:%d\n好瓜:%s" % (i, root.label)
    else:
        g_node_label = "Node:%d\n好瓜:%s\n属性:%s" % (i, root.label, root.attr)
    g_node = i
    g.add_node(graphviz.Node(g_node, label=g_node_label))

    # Recurse into each branch; edges are labelled with the attribute value.
    for value in list(root.attr_down):
        i, g_child = TreeToGraph(i + 1, g, root.attr_down[value])
        g.add_edge(graphviz.Edge(g_node, g_child, label=value))

    return i, g_node

你可能感兴趣的:(python)