There is also a more detailed write-up worth consulting: an ID3 algorithm implementation where every step is calculated in detail.
The two figures below are two pages from a PPT I put together when I first studied decision trees; they mainly cover ID3 and related concepts. C4.5, CART, GBDT, RandomForest and so on are not included here. The principle behind ID3 is quite simple and plenty of other material explains it, so the focus here is on programming practice.
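Since those PPT pages are not reproduced here, it is worth stating the two formulas ID3 actually relies on. For a data set D whose classes occur with probabilities p_i, the Shannon entropy is H(D) = -Σ_i p_i · log2(p_i), and the information gain of splitting on feature A is Gain(D, A) = H(D) - Σ_v (|D_v|/|D|) · H(D_v), where D_v is the subset of D with A = v. As a worked example on the five-sample fish data used in the code below (2 'yes', 3 'no'): H(D) = -(2/5)log2(2/5) - (3/5)log2(3/5) ≈ 0.971, and ID3 splits on whichever feature reduces this weighted entropy the most.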
The main reference link for the code below was already given above; see the original link for a detailed analysis. This code is about the simplest possible example: there is no pruning, no handling of non-discrete features, and no handling of incomplete data. It just builds a few samples and implements ID3 tree construction and classification.
# __author__ = 'czx'
# coding=utf-8
"""
Description:
ID3 Algorithm for fish classification task .
"""
import operator
from math import log
def createData():
    """
    :return:
        data: feature values plus a class value for each sample
        labels: descriptions of the features
    """
    data = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [1, 0, 'no'], [0, 1, 'no']]
    labels = ['no surfing', 'flippers']
    return data, labels
def calcShannonEnt(data):
    """
    :param
        data: the given data set
    :return:
        shannonent: Shannon entropy of the given data set
    """
    num = len(data)                               # data set size
    labelcount = {}                               # counts of each class label ['yes', 'no']
    for sample in data:
        templabel = sample[-1]                    # class label of this sample
        if templabel not in labelcount:           # first occurrence: add to dict
            labelcount[templabel] = 0             # initial value is 0
        labelcount[templabel] += 1                # count
    shannonent = 0.0                              # initial Shannon entropy
    for key in labelcount:                        # for all classes
        prob = float(labelcount[key]) / num       # P_i = N_i / N
        shannonent -= prob * log(prob, 2)         # Ent = -sum(P_i * log2(P_i)), here over classes ['yes', 'no']
    return shannonent                             # Shannon entropy of the given data
def splitDataSet(data, index, value):
    """
    :param
        data: the given data set
        index: index of the selected feature
        value: the feature value used to split the data set
    :return:
        resData: samples matching the value, with the selected feature removed
    """
    resData = []
    for sample in data:                           # for all samples in the data set
        # Note: ID3 can only handle features with discrete values
        if sample[index] == value:                # sample matches the selected feature value
            splitSample = sample[:index]          # the first `index` features
            splitSample.extend(sample[index+1:])  # all remaining features except feature[index]
            resData.append(splitSample)
    return resData
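# For example, with the data from createData():
#   splitDataSet(data, 0, 1) -> [[1, 'yes'], [1, 'yes'], [0, 'no'], [0, 'no']]
# i.e. the four samples whose feature 0 equals 1, with that feature column removed.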
def chooseBestFeatureToSplit(data):
    """
    :param
        data: the given data set
    :return:
        bestFeature: index of the feature with the largest information gain
    """
    num = len(data[0]) - 1                        # number of features (the final column is the class value)
    baseEntropy = calcShannonEnt(data)            # Shannon entropy of the whole data set
    bestInfoGain = 0.0
    bestFeature = 0
    for i in range(num):                          # feature indices [0, 1, ..., n-2]
        featList = [sample[i] for sample in data] # all values of feature i in the data set
        uniqueVals = set(featList)                # remove duplicates
        newEntropy = 0.0
        for v in uniqueVals:                      # for all values of feature i
            subData = splitDataSet(data, i, v)    # split the data set on (index=i) and (value=v)
            prob = len(subData) / float(len(data))
            newEntropy += prob * calcShannonEnt(subData)
        infoGain = baseEntropy - newEntropy       # information gain of splitting on feature i
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
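# With the data from createData(), baseEntropy ≈ 0.971; splitting on feature 0
# ('no surfing') gives infoGain ≈ 0.171, while feature 1 ('flippers') gives
# infoGain ≈ 0.420, so chooseBestFeatureToSplit returns 1.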
def majorityCnt(classList):
    """
    :param
        classList: class labels, here 'yes' or 'no'
    :return:
        sortedClassCount[0][0]: the class label with the largest count (majority vote)
    """
    classCount = {}
    for v in classList:
        if v not in classCount:
            classCount[v] = 0
        classCount[v] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def createTree(data, labels):
    """
    :param
        data: the given data set
        labels: descriptions of the features
    :return:
        myTree: final tree for making decisions
    """
    classList = [sample[-1] for sample in data]           # class labels, here ['yes', 'no']
    if classList.count(classList[0]) == len(classList):   # case 1: all samples share the same class
        return classList[0]
    if len(data[0]) == 1:                                 # case 2: no features left, so take a majority vote
        return majorityCnt(classList)
    bestFeature = chooseBestFeatureToSplit(data)          # choose the best feature
    bestFeatureLabel = labels[bestFeature]                # description of the best feature
    myTree = {bestFeatureLabel: {}}                       # dict to save myTree
    featVals = [sample[bestFeature] for sample in data]   # all values of the best feature
    uniqueVals = set(featVals)
    for v in uniqueVals:                                  # for each unique value, here [0, 1]
        subLabels = labels[:]                             # copy so the caller's labels stay intact
        del subLabels[bestFeature]                        # drop the used label to stay aligned with the split data
        myTree[bestFeatureLabel][v] = createTree(splitDataSet(data, bestFeature, v), subLabels)
    return myTree
def classify(inputTree, featLabels, testSample):
    """
    :param
        inputTree: the decision tree built by createTree
        featLabels: descriptions of the features, used to look up feature indices
        testSample: feature values of the sample to classify
    :return:
        label: the predicted class label, here 'yes' or 'no'
    """
    firstStr = list(inputTree.keys())[0]      # root node of the tree: a feature name (e.g. 'flippers')
    secondDict = inputTree[firstStr]          # the subtrees, one per feature value
    featIndex = featLabels.index(firstStr)    # index of the root feature in the feature descriptions
    key = testSample[featIndex]               # the test sample's value for that feature (the key in the dict)
    featureVal = secondDict[key]              # either a subtree or a class value
    if isinstance(featureVal, dict):          # still a subtree: cannot decide yet, keep descending
        label = classify(featureVal, featLabels, testSample)
    else:
        label = featureVal                    # a leaf: the sample belongs to 'label', in ['yes', 'no']
    return label
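# Example: with the tree built from the data above, classify(myTree, labels, [1, 0])
# descends the 'flippers' branch with value 0 and returns 'no'.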
# def storeTree(inputTree, filename):
#     import pickle
#     with open(filename, 'wb') as fw:
#         pickle.dump(inputTree, fw)
#
# def grabTree(filename):
#     import pickle
#     with open(filename, 'rb') as fr:
#         return pickle.load(fr)
def test():
    data, labels = createData()
    myTree = createTree(data, labels)
    print(myTree)
    print(classify(myTree, labels, [1, 1]))

if __name__ == '__main__':
    test()
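Tracing the algorithm by hand on the five samples: the base entropy is about 0.971, and splitting on 'flippers' yields an information gain of about 0.420 versus about 0.171 for 'no surfing', so 'flippers' ends up at the root. Running test() should therefore print something like

{'flippers': {0: 'no', 1: {'no surfing': {0: 'no', 1: 'yes'}}}}
yes

that is, the tree first checks 'flippers', then 'no surfing', and the test sample [1, 1] is classified as 'yes'.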
1: Programming habits improve only through more coding, more thinking, and more summarizing. The important thing deserves saying three times: code, code, and code again.
2: Go from simple to complex. To solve real problems better and faster, you have to keep digging deeper, for example into the many algorithms related to decision trees: C4.5, CART, GBDT, XGBoost, RF, and so on.
At the end of last year I wanted to install PyCharm on a new machine, but none of the activation codes worked; since it still ran on the old server, I kept putting it off. Recently the old server had to be shut down for a while because the machine room was overheating, so I finally installed PyCharm on the new machine. It really is comfortable to use.
Students can register with a school email address and try PyCharm Professional free for a year, so I set up the free student edition and will use it for a year first.
Embarrassingly, I had forgotten my school email password, and recovering it looked like a hassle; luckily the WeChat campus service could receive the mail, and after a round of confirmation I got the free student edition.