# 使用Python语言实现决策树算法——ID3算法，代码如下。
import math
# Training set: 14 labelled weather records keyed by row id 1..14.  Each row
# holds four attribute columns (outlook, temp, hum, wind) and the class label
# 'play', which is either 'yes' or 'no'.
train = {
    row_id: dict(zip(('outlook', 'temp', 'hum', 'wind', 'play'), row))
    for row_id, row in enumerate(
        [
            ('sunny', 'hot', 'high', 'weak', 'no'),
            ('sunny', 'hot', 'high', 'strong', 'no'),
            ('overcast', 'hot', 'high', 'weak', 'yes'),
            ('rain', 'mild', 'high', 'weak', 'yes'),
            ('rain', 'cool', 'normal', 'weak', 'yes'),
            ('rain', 'cool', 'normal', 'strong', 'no'),
            ('overcast', 'cool', 'normal', 'strong', 'yes'),
            ('sunny', 'mild', 'high', 'weak', 'no'),
            ('sunny', 'cool', 'normal', 'weak', 'yes'),
            ('rain', 'mild', 'normal', 'weak', 'yes'),
            ('sunny', 'mild', 'normal', 'strong', 'yes'),
            ('overcast', 'mild', 'high', 'strong', 'yes'),
            ('overcast', 'hot', 'normal', 'weak', 'yes'),
            ('rain', 'mild', 'high', 'strong', 'no'),
        ],
        start=1,
    )
}
def info(train):
    """Summarise the yes/no 'play' label of *train* and its entropy.

    Args:
        train: dict of row id -> row dict; every row must have a 'play' key.

    Returns:
        ``[total, yes_count, no_count, entropy]`` where ``total`` is the number
        of rows (rows whose 'play' is neither 'yes' nor 'no' count here but in
        neither class), and ``entropy`` is the Shannon entropy in bits of the
        yes/no split, rounded to 3 decimals.  ``entropy`` is the int 0 when
        either class is empty.
    """
    labels = [row['play'] for row in train.values()]
    n_yes = labels.count('yes')
    n_no = labels.count('no')
    if not (n_yes and n_no):
        # A pure (or empty) set carries no information.
        return [len(labels), n_yes, n_no, 0]
    n_labelled = n_yes + n_no
    p_yes, p_no = n_yes / n_labelled, n_no / n_labelled
    # Sum of p*log2(p) is <= 0; negate after rounding to get the entropy.
    log_sum = p_yes * math.log2(p_yes) + p_no * math.log2(p_no)
    return [len(labels), n_yes, n_no, -round(log_sum, 3)]
def parttrain(train, targetattr, mainattr):
    """Return the rows of *train* whose column *mainattr* equals *targetattr*.

    Row ids are kept as keys and the row dicts are shared with *train* (not
    copied); insertion order follows *train*.
    """
    return {
        row_id: row
        for row_id, row in train.items()
        if row[mainattr] == targetattr
    }
def attrset(train, attr):
    """Return the set of distinct values column *attr* takes across *train*."""
    return {row[attr] for row in train.values()}
class Tree:
    """Decision-tree node: a label (``root``) plus a dict of children (``child``).

    In this file, id3 sets ``root`` to the attribute tested at the node and
    fills ``child`` with branch value -> subtree (or leaf label) entries.
    """

    def __init__(self, root):
        """Create a node labelled *root* with an empty child dict."""
        self.root = root
        self.child = {}

    def addchild(self, attr, dict):
        """Store *dict* under key *attr*, replacing any existing entry.

        *dict* may be a nested subtree or a plain leaf label.  The parameter
        name shadows the builtin ``dict``; it is kept for keyword callers.
        """
        children = self.child
        children[attr] = dict

    def show(self):
        """Return ``{root: child}``; the inner dict is this node's own child dict, not a copy."""
        return {self.root: self.child}
def maxinfo(train, attrs):
    """Return the attribute in *attrs* with the largest information gain on *train*.

    Gain(A) = info(train)[3]
              - sum over values v of A of (|train_v| / |train|) * info(train_v)[3],
    where train_v = parttrain(train, v, A).

    Ties go to the attribute listed later in *attrs* (``>=`` comparison).
    Returns '' only when *attrs* is empty; otherwise an attribute is always
    returned, even if every computed gain is zero or slightly negative.
    """
    # info() is pure, so the parent set's size and entropy are computed once.
    total, _, _, base_entropy = info(train)
    best_attr = ''
    # Start below any possible gain.  Rounding inside info() plus float error
    # can push a mathematically zero gain just below 0; starting at 0 could
    # then leave '' as the result, and id3 would fail on attrset(examples, '').
    best_gain = -math.inf
    for attr in attrs:
        gain = base_entropy
        # Sorted so the float subtraction order (and so near-tie results)
        # does not depend on per-run set iteration order.
        for value in sorted(attrset(train, attr), key=str):
            subset_total, _, _, subset_entropy = info(parttrain(train, value, attr))
            gain -= (subset_total / total) * subset_entropy
        # Compare exact gains; comparing an unrounded candidate against a
        # rounded running maximum could select a smaller gain.
        if gain >= best_gain:
            best_gain = gain
            best_attr = attr
    return best_attr
def id3(examples, target, attributes):
    """Build an ID3 (sub)tree for *examples*, splitting on *target* at this node.

    Args:
        examples: dict of row id -> row dict; rows carry a 'play' label
            ('yes'/'no') plus attribute columns.
        target: attribute tested at this node (chosen by the caller, e.g.
            via maxinfo).
        attributes: attributes still available for splitting.  This list is
            NOT modified; a fresh list without *target* is used internally.

    Returns:
        ``{target: children}`` from Tree.show().  For an internal node,
        children maps each value of *target* to a nested subtree of the same
        shape.  For a leaf (all examples share one label, or no attributes
        remain), children is ``{target: label}`` with label 'yes' or 'no'.
    """
    root = Tree(target)
    _, n_yes, n_no, _ = info(examples)
    if n_yes != 0 and n_no == 0:  # every example is 'yes'
        root.addchild(target, 'yes')
    elif n_yes == 0 and n_no != 0:  # every example is 'no'
        root.addchild(target, 'no')
    elif not attributes:
        # Nothing left to split on: majority label, ties resolved as 'yes'.
        root.addchild(target, 'yes' if n_yes >= n_no else 'no')
    else:
        # Build a new list instead of attributes.remove(target): the old
        # in-place removal changed the caller's list and raised ValueError
        # when target was not in it.
        remaining = [a for a in attributes if a != target]
        # Sorted for a deterministic child order in the printed tree.
        for value in sorted(attrset(examples, target), key=str):
            subset = parttrain(examples, value, target)
            next_target = maxinfo(subset, remaining)
            # Safe to share `remaining`: id3 never mutates its argument.
            root.addchild(value, id3(subset, next_target, remaining))
    return root.show()
# Candidate split attributes; 'play' is the class label, so it is not listed.
attrs = 'outlook temp hum wind'.split()
# The root splits on the attribute with the highest information gain.
target = maxinfo(train, attrs)
a = id3(train, target, attrs)
print(a)