Machine Learning in Action (3) Decision Trees with ID3: Building the Tree and Simple Classification

Note: the main reference is 机器学习实战之决策树 (the decision-tree write-up of Machine Learning in Action), which gives a full implementation with detailed comments; it is a good resource for studying Machine Learning in Action and I recommend it. For the sake of programming practice and to consolidate the algorithm, I rewrote the code in my own style, which also helped me think it through. To make later review faster I captured what I consider the key content as screenshots, but I still recommend reading the original link. I left some of my own thoughts in the code along the way, and discussion and feedback are welcome.

There is also a more detailed reference worth a look: ID3算法实现, which works through every step of the calculation in detail.

Related Concepts

The two images below are two pages from a PPT I put together when I first studied decision trees; they mainly cover ID3 and related concepts. I will not paste the material on C4.5, CART, GBDT, RandomForest, and so on. The principle behind ID3 is simple and well covered elsewhere; the focus here is on the implementation.

[PPT slide 1: ID3 and related concepts]


[PPT slide 2: ID3 and related concepts]
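
As a quick recap of what the slides cover (the standard ID3 definitions, written out here for convenience): the entropy of a data set $D$ with class proportions $p_i$, and the information gain of splitting on a feature $A$, are

$$H(D) = -\sum_{i} p_i \log_2 p_i, \qquad \mathrm{Gain}(D, A) = H(D) - \sum_{v \in \mathrm{values}(A)} \frac{|D_v|}{|D|}\, H(D_v)$$

where $D_v$ is the subset of $D$ whose value of $A$ is $v$. At each node, ID3 greedily splits on the feature with the largest information gain; this is exactly what chooseBestFeatureToSpilt in the code below computes.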

ID3 Algorithm Implementation

The main reference for the code below was linked above; see the original post for a detailed analysis. This is about the simplest example possible: no pruning, no handling of continuous features, and no handling of incomplete data. It simply builds a few samples by hand and implements ID3 tree construction and classification.

#  __author__ = 'czx'
# coding=utf-8
"""
Description:
    ID3 algorithm demo: build a decision tree on a toy fish data set and classify with it.
"""
from math import log
import operator                             # used by majorityCnt for sorting class counts

def createData():
    """
    :return:
        data: including feature values and class values
        labels: description of features
    """
    data = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [1, 0, 'no'], [0, 1, 'no']]
    labels = ['no surfing', 'flippers']
    return data, labels

def calcShannonEnt( data ):
    """
    :param
        data: given data set
    :return:
        shannonent: shannon entropy of the given data set
    """
    num = len(data)                             # data set size
    labelcount = {}                             # count number of each class ['yes','no']
    for sample in data:
        templabel = sample[-1]                  # class label of sample in data ['yes','no']
        if templabel not in labelcount.keys():  # add to dict
            labelcount[templabel] = 0           # initial value is 0
        labelcount[templabel] += 1              # count
    shannonent = 0.0                            # initial shannon entropy
    for key in labelcount:                      # for all classes
        prob = float(labelcount[key])/num       # Pi = Ni/N
        shannonent -= prob * log(prob, 2)       # Ent = -sum(Pi*log2(Pi)); here the classes are ['yes','no']
    return shannonent                           # shannon entropy of the given data
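
# Quick sanity check (my own numbers, not from the original post): the toy data from
# createData() has 2 'yes' and 3 'no' labels, so
# calcShannonEnt(data) = -(0.4*log(0.4,2) + 0.6*log(0.6,2)) ~= 0.971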

def spiltDataSet(data,index,value):
    """
    :param
        data: the given data set
        index: index of selected feature
        value: the selected value to spilt the data set
    :return:
        resData: the subset of samples whose value at the given index equals the given value, with that feature column removed
    """
    resData = []
    for sample in data:                           # for all samples in data set
        # note: ID3 as implemented here only handles discrete feature values
        if sample[index] == value:                # sample matches the selected value
            spiltSample = sample[:index]          # features before the split feature
            spiltSample.extend(sample[index+1:])  # features after it (the split column is dropped)
            resData.append(spiltSample)
    return resData
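
# Example (my own illustration): splitting on feature 0 with value 1 keeps the matching
# samples and drops that column:
# spiltDataSet([[1,1,'yes'],[1,0,'no'],[0,1,'no']], 0, 1) -> [[1,'yes'],[0,'no']]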

def chooseBestFeatureToSpilt(data):
    """
    :param
        data: the given data set
    :return:
        bestFeature: the index of best feature which has best info gain
    """
    num = len(data[0])-1                               # number of features (the last column is the class label)
    baseEntropy = calcShannonEnt(data)                 # entropy of the current (unsplit) data set
    bestInfoGain = 0.0
    bestFeature = 0

    for i in range(num):                               # all features index [0,1,2,...,n-2]
        featList = [sample[i] for sample in data]      # all values of feature i across the data set
        uniqueVals = set(featList)                     # Remove Duplicates
        newEntropy = 0.0
        for v in uniqueVals:                           # all values in features i
            subData = spiltDataSet(data,i,v)           # to spilt the data set with (index=i) and (value=v)
            prob = len(subData)/float(len(data))    
            newEntropy += prob*calcShannonEnt(subData) 
        infoGain = baseEntropy - newEntropy
        if infoGain>bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
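
# Sanity check (my own calculation, not from the original post): on the toy data,
# gain('no surfing') = 0.971 - 0.800 = 0.171 and gain('flippers') = 0.971 - 0.551 = 0.420,
# so chooseBestFeatureToSpilt(data) returns 1 for this particular five-sample set.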

def majorityCnt(classList):
    """
    :param
        classList: list of class labels, here 'yes' or 'no'
    :return:
        sortedClassCount[0][0]: the class label that occurs most often
    """
    classCount = {}
    for v in classList:
        if v not in classCount.keys():
            classCount[v] = 0
        classCount[v] += 1
    # sorting and returning must happen after the loop finishes, not inside it
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
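
# Example (my own illustration): majorityCnt(['yes','no','no']) -> 'no'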

def createTree(data,labels):
    """
    :param
        data: the given data set
        labels:  description of features
    :return:
        myTree: final tree for making decision
    """
    classList = [sample[-1] for sample in data]           # class labels of all samples, here 'yes'/'no'
    if classList.count(classList[0]) == len(classList):   # situation 1: all samples share the same class
        return classList[0]
    if len(data[0]) == 1:                                 # situation 2: no features left, only the class column remains
        return majorityCnt(classList)
    
    bestFeature = chooseBestFeatureToSpilt(data)          # choose best feature
    bestFeatureLabel = labels[bestFeature]                # description of the best feature, e.g. 'flippers'
    
    myTree = {bestFeatureLabel:{}}                        # dict to save myTree
    featVals = [sample[bestFeature] for sample in data]   # get all feature values from the best feature
    uniqueVals = set(featVals)
    subLabels = labels[:bestFeature] + labels[bestFeature+1:] # drop the used feature's label (on a copy) so labels stay aligned with the reduced data
    for v in uniqueVals:                                  # for each unique value of the best feature, here [0,1]
        myTree[bestFeatureLabel][v] = createTree(spiltDataSet(data,bestFeature,v),subLabels)
    return myTree

def classify(inputTree,featLabels,testSample):
    """
    :param
        inputTree: decision tree built by createTree (nested dicts)
        featLabels: feature descriptions, used to map a node label back to a feature index
        testSample: feature values of the sample to classify
    :return:
        label: predicted class label, here 'yes' or 'no'
    """
    firstStr = inputTree.keys()[0]    # root node of the tree : [0] means feature (eg:'flippers')
    secondDict = inputTree[firstStr]  # other sub trees in one dict 

    featIndex = featLabels.index(firstStr)  # map the root feature's label back to its column index in a sample

    key = testSample[featIndex]             # value of the best feature of test sample ( here is the key in dict )
    featureVal = secondDict[key]            # either a subtree (dict) or a leaf class label
    if isinstance(featureVal,dict):         # still a subtree: descend recursively
        label = classify(featureVal, featLabels, testSample)  
    else:
        label = featureVal                  # it belongs to 'label' , label in ['yes','no']
    return label
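
# Example with a tiny hand-built tree (my own illustration, not from the original post):
# classify({'flippers': {0: 'no', 1: 'yes'}}, ['no surfing', 'flippers'], [0, 1]) -> 'yes'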

# def storeTree(inputTree, filename):
#     import pickle
#     with open(filename, 'wb') as fw:
#         pickle.dump(inputTree, fw)
# 
# def grabTree(filename):
#     import pickle
#     fr = open(filename,'rb')
#     return pickle.load(fr)

def test():
    data, labels = createData()
    myTree = createTree(data,labels)
    print myTree
    print classify(myTree,labels,[1,1])

if __name__ == '__main__':
    test()
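
For reference, with the toy data above, test() should print something like the following (my own run-through, not output copied from the original post; with this particular five-sample set 'flippers' has the higher information gain, so it ends up at the root):

{'flippers': {0: 'no', 1: {'no surfing': {0: 'no', 1: 'yes'}}}}
yes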

Summary

1: Programming habits only improve through more coding, more thinking, and more summarizing. The important thing deserves saying three times: code, code, and code again.

2: Go from simple to complex. To solve real problems better and faster, you have to keep digging deeper, for example into the many decision-tree-related algorithms: C4.5, CART, GBDT, XGBoost, RF, and so on.

PyCharm Registration with a Student Email

At the end of last year I wanted to install PyCharm on a new machine, but none of the activation codes worked. Since it still ran on the old server I put it off; recently the old server had to be shut down for a while because the machine room was overheating, so I finally had to install PyCharm on the new machine. It really is a pleasant IDE to work with.

Students can register with a school email address and use PyCharm Professional free for a year, so I got the free student license; that covers me for a year.

[Screenshot 3: PyCharm student license registration]

Having forgotten my school email password was awkward, and account recovery turned out to be a hassle. Fortunately the WeChat campus service could receive the mail, and after a round of confirmations I had the free student license.
