决策树在周志华的西瓜书里面已经介绍的很详细了(西瓜书P73-P79),那也是我看过讲的最清楚的决策树讲解了,我这里就不献丑了,这篇文章主要是分享决策树的代码。
在西瓜书中介绍了三种决策树,分别为ID3,C4.5和CART三种决策树,三种树出了分裂的计算方法不一样之外,其余的都一样,大家可以多看看书,如果有什么不清楚的可以看看我的代码,决策树的代码算是很简单的了,我有朋友面试的时候就被要求写决策树代码。话不多说了,下面开始吧。
本篇文章的数据来自于西瓜中的西瓜数据,也是我纯手打下来的。下面一起看看代码吧
import pandas as pd
import numpy as np
from math import log
dataset=pd.read_csv('西瓜数据集.csv',encoding='utf-8')
导入所需要的包,本人比较懒直接用Pandas库来实现了,用这个库的话很方便。
def calShanEnt(dataset,col):
tarset=set(dataset[col])
res=0
for i in tarset:
pi=np.sum(dataset[col] == i)/len(dataset)
res=res-pi* log(pi, 2)
return res
第一个函数,计算香农值。
def splitData(Data,fea,value):
res=Data[Data[fea]==value].copy()
res=res.drop(fea,axis=1)
return res
第二个函数,分裂函数
def ID3(value_set,dataset,fea):
baseEnt = calShanEnt(dataset, "target")
newEnt = 0
for v in value_set:
newEnt += np.sum(dataset[fea] == v) / len(dataset) * calShanEnt(dataset[dataset[fea] == v],"target")
return baseEnt-newEnt
ID3树的计算函数
def C4_5(value_set,dataset,fea):
gain=ID3(value_set,dataset,fea)
IVa=calShanEnt(dataset,fea)
return gain/IVa
C4_5树的计算函数,从这个函数中可以很清楚的看出,ID3树和C4_5树分裂函数的区别。
def Gini(dataset,col):
tarset = set(dataset[col])
gini=1
for i in tarset:
gini=gini-(np.sum(dataset[col] == i)/len(dataset))**0.5
return gini
def CART(value_set,dataset,fea):
Gini_index = 0
for v in value_set:
Gini_index += np.sum(dataset[fea] == v) / len(dataset) * Gini(dataset[dataset[fea] == v],"target")
return Gini_index
CART树的分裂计算函数,上面是Gini指数的计算,下面是计算从该特征分裂的收益。
def chooseBestFea(dataset):
features=[i for i in dataset.columns if i!='target']
bestFet=features[0]
bestInfoGain=-1
for fea in features:
value_set=set(dataset[fea])
# gain=ID3(value_set,dataset,fea)##这是调用ID3增益的函数
# gain=C4_5(value_set,dataset,fea)
gain=CART(value_set,dataset,fea)
if gain>bestInfoGain:
bestInfoGain=gain
bestFet=fea
return bestFet
分裂特征搜索,可以在这里设置使用哪个树(ID3,C4_5或者CART)进行建模
def creatTree(dataset):
if len(dataset.columns)==1:
return dataset['target'].value_counts().index[0]
if len(set(dataset['target']))==1:
return list(dataset['target'])[0]
bestFea=chooseBestFea(dataset)
myTree={bestFea:{}}
for i in set(dataset[bestFea]):
myTree[bestFea][i]=creatTree(splitData(dataset,bestFea,i))
return myTree
T=creatTree(dataset)
建立一个决策树的主流程。本次关于决策树的代码分享到这里,关注challengeHub公众号回复决策树即可获得西瓜数据和完整的决策树代码。
想了解更多算法,数据分析,数据挖掘等方面的知识,欢迎关注ChallengeHub公众号,添加微信进入微信交流群,或者进入QQ交流群。关注公众号或者进群可以直接获得上面的代码,而且也有很多机器学习的资料可以获取。