Python实现决策树算法 C4.5和ID3算法

本文用 Python 实现了 C4.5 和 ID3 算法，默认使用 C4.5。若要改用 ID3，只需修改函数 entropy() 末尾的返回语句：注释掉标有 C4.5 的那一行，并启用标有 ID3 的那一行。

将源代码保存为 Python 文件（例如命名为 c45.py）。第一个命令行参数是数据文件路径，可自由设置；省略时默认读取 c45_data.txt。数据文件每行一条样本，各列以空白分隔，最后一列为类别标签：值为 Yes 的样本视为正类，其余均视为负类。运行方式参考如下：

python c45.py data.txt

特别感谢：（原文此处附有参考文章的链接“点击打开链接”，转载时链接地址已丢失。）

源代码如下:

#!/usr/bin/python
# -*- coding: UTF-8 -*-
__author__ = 'Administrator'
########      C4.5  ID3  finished!!        ######
#################  (tm_year=2016, tm_mon=3, tm_mday=15, tm_hour=22, tm_min=56, tm_sec=56, tm_wday=1, tm_yday=75, tm_isdst=0)  ################
import re
import math
import sys

# Nodes holding at most this many rows are turned into leaves, even when
# their rows still carry both class labels.
mini_size = 1
# Length of the fixed-size bookkeeping arrays defined just below.
DataLength = 100
used = [0] * DataLength   # per attribute: split-usage flag (0 = unused)
ended = [0] * DataLength  # per node: "will be split further" flag
tp = [-1] * DataLength    # per node: 1 = Yes leaf, 0 = No leaf, -1 = not set

class node:
    """A single vertex of the decision tree.

    Nodes are referenced by their integer position in the tree list built
    by the script entry point, so links between nodes are stored as indices.
    """

    def __init__(self):
        self.items = set()  # indices of the data rows that reach this node
        self.com = 0        # attribute index the parent split on ("comes from")
        self.father = 0     # index of the parent node
        self.value = ''     # attribute value that routes rows into this node

# Row count (lg) and column count (wd) of the loaded data set; both are
# reassigned by the script entry point once the data file has been read.
lg = 0
wd = 0
def entropy(dt,values,node,i):
    """Score attribute ``i`` as a split candidate for ``node`` (lower is better).

    The rows in ``node.items`` are grouped by their value of attribute ``i``.
    A row counts as positive when its last column is exactly 'Yes'. With
    natural logarithms:

        s  = sum_v |S_v|/|S| * H(S_v)             conditional class entropy
        sp = -sum_v |S_v|/|S| * ln(|S_v|/|S|)     split information

    By default the function returns ``s / sp`` (the "C4.5" setting). To get
    the ID3 behaviour, swap the two ``return`` lines at the bottom.

    NOTE(review): textbook C4.5 *maximizes* the gain ratio (H(S) - s) / sp;
    minimizing s / sp is not equivalent -- confirm this is intended.

    Args:
        dt: list of rows (lists of strings); the last column is the label.
        values: list of sets; ``values[i]`` must contain every value that
            rows of this node take for attribute ``i``.
        node: object whose ``items`` attribute is an iterable of row indices.
        i: column index of the attribute to score.

    Returns:
        The score as a float, or ``float('inf')`` when the split information
        is zero (empty node, or all of its rows share one value of attribute
        ``i``). Such an attribute cannot split the node, and dividing by
        ``sp`` would raise ZeroDivisionError.
    """
    # Map each attribute value to its bucket once, instead of list.index per row.
    index = dict((v, k) for k, v in enumerate(values[i]))
    n = len(index)
    pos = [0] * n
    neg = [0] * n
    for j in node.items:
        a = index[dt[j][i]]
        if dt[j][-1] == 'Yes':
            pos[a] += 1
        else:
            neg[a] += 1
    total = float(sum(pos) + sum(neg))

    sp = 0.0
    for j in range(n):
        cnt = pos[j] + neg[j]
        if cnt == 0:
            continue
        p = cnt / total
        sp -= p * math.log(p)
    if sp == 0:
        # A single non-empty bucket (or none at all): no usable split.
        return float('inf')

    s = 0.0
    for j in range(n):
        # Pure buckets contribute zero entropy (and log(0) is undefined).
        if pos[j] == 0 or neg[j] == 0:
            continue
        cnt = float(pos[j] + neg[j])
        py = pos[j] / cnt
        pn = neg[j] / cnt
        s -= cnt / total * (py * math.log(py) + pn * math.log(pn))
    return s/sp  ### C4.5
    #return s   ### ID3

def ens(dt,values,node):
    """Return sum(p * ln p) over the class proportions of ``node.items``.

    This is the *negative* of the node's class entropy (natural log); the
    caller negates it. A row counts as positive when its last column is
    exactly 'Yes'. ``values`` is not used and is kept only so the signature
    matches ``entropy``.

    Returns:
        0 for a pure or empty node, otherwise a negative float.
    """
    ps = ng = 0
    for j in node.items:
        if dt[j][-1] == 'Yes':
            ps += 1
        else:
            ng += 1
    if ps == 0 or ng == 0:
        # A pure (or empty) node has zero entropy, whichever label it holds.
        return 0
    p = float(ps) / (ps + ng)
    q = float(ng) / (ps + ng)
    return p * math.log(p) + q * math.log(q)

def _count_labels(dt, items):
    """Return (positive, negative) row counts for the row indices in ``items``.

    A row is positive when its last column is exactly 'Yes'.
    """
    ps = ng = 0
    for j in items:
        if dt[j][-1] == 'Yes':
            ps += 1
        else:
            ng += 1
    return ps, ng


def _load_rows(path):
    """Read whitespace-separated rows from ``path``, skipping blank lines."""
    rows = []
    with open(path, "r") as fp:
        for line in fp:
            fields = line.split()
            if fields:
                rows.append(fields)
    return rows


if __name__ == '__main__':
    # Data file: one sample per line, whitespace-separated columns, with the
    # class label in the last column. Defaults to c45_data.txt.
    path = "c45_data.txt"
    if len(sys.argv) > 1:
        path = sys.argv[1]
    dt = _load_rows(path)
    lg = len(dt)
    if lg == 0:
        sys.exit("no data rows found in %s" % path)
    wd = len(dt[0])
    for row_no, row in enumerate(dt):
        if len(row) != wd:
            sys.exit("row %d of %s has %d columns, expected %d"
                     % (row_no + 1, path, len(row), wd))

    # values[j]: every value that attribute j takes anywhere in the data.
    values = [set() for _ in range(wd)]
    for row in dt:
        for j, v in enumerate(row):
            values[j].add(v)

    root = node()
    root.father = -1
    root.com = -1
    root.items = set(range(lg))

    # The tree is a flat list grown breadth-first; children point to their
    # parent by index (node.father). Parallel per-node lists:
    #   tp[k]        -- 1 / 0 for a Yes / No leaf, -1 for an internal node
    #   path_used[k] -- attributes already tested on the path root..k. Only
    #                   these are excluded, so sibling subtrees may reuse an
    #                   attribute.
    tree = [root]
    tp = [-1]
    path_used = [frozenset()]

    order = -1
    while order + 1 < len(tree):
        order += 1
        cur = tree[order]

        if not cur.items:
            # No training row reached this branch: predict the parent's majority.
            ps, ng = _count_labels(dt, tree[cur.father].items)
            tp[order] = 1 if ps >= ng else 0
            continue

        ps, ng = _count_labels(dt, cur.items)
        majority = 1 if ps >= ng else 0
        if len(cur.items) <= mini_size or ps == 0 or ng == 0:
            tp[order] = majority
            continue

        rt = -ens(dt, values, cur)  # class entropy of this node
        candidates = []
        for i in range(wd - 1):
            if i in path_used[order]:
                continue
            if len(set(dt[j][i] for j in cur.items)) < 2:
                # Constant within this node: splitting on it gains nothing,
                # and its split information would be zero.
                continue
            candidates.append((entropy(dt, values, cur, i), i))
        # (score, index) tuples: ties go to the lowest attribute index.
        if not candidates or min(candidates)[0] >= rt:
            tp[order] = majority
            continue

        best = min(candidates)[1]
        ll = list(values[best])
        position = dict((v, k) for k, v in enumerate(ll))
        first = len(tree)
        child_used = path_used[order] | frozenset([best])
        for v in ll:
            child = node()
            child.value = v
            child.father = order
            child.com = best
            tree.append(child)
            tp.append(-1)
            path_used.append(child_used)
        for j in cur.items:
            tree[first + position[dt[j][best]]].items.add(j)

    print("%d nodes in all" % len(tree))
    for i, nd in enumerate(tree):
        print("%d \tfather: %d \tattribute:  %d \tvalue: %s"
              % (i, nd.father, nd.com, nd.value))
        print("%s %d \n" % (nd.items, tp[i]))




你可能感兴趣的:(数据挖掘,python,决策树,C4.5,ID3)