【机器学习实战】决策树

算法思路

     在构造决策树时,第一个需要解决的问题就是,如何确定出哪个特征在划分数据分类是起决定性作用,或者说使用哪个特征分类能实现最好的分类效果。这样,为了找到决定性的特征,划分得到最好的结果,我们就需要评估每个特征。当找到最优特征后,依此特征,数据集就被划分为几个数据子集,这些数据自己会分布在该决策点的所有分支中。此时,如果某个分支下的数据属于同一类型,则该分支下的数据分类已经完成,无需进行下一步的数据集分类;如果分支下的数据子集内数据不属于同一类型,那么就要重复划分该数据集的过程,按照划分原始数据集相同的原则,确定出该数据子集中的最优特征,继续对数据子集进行分类,直到所有的特征已经遍历完成,或者所有叶结点分支下的数据具有相同的分类。

计算信息熵

     划分数据集的大原则是:将无序的数据变得更加有序。在划分数据集前后信息发生的变化称为信息增益,如果我们知道如何计算信息增益,就可以计算每个特征值划分数据集获得的信息增益,而获取信息增益最高的特征就是最好的特征。

     
信息熵计算公式.png

     代码实现时,p代表类标签出现的概率。
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet: #the the number of unique elements and their occurance
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob,2) #log base 2
    return shannonEnt

划分数据集

     实现功能:选择符合条件的项(去掉用于分类的该特征)输入参数:待划分的数据集、划分数据集的特征和返回值。注意:在划分数据集函数中,传递的参数dataSet列表的引用,在函数内部对该列表对象进行修改,会导致列表内容发生改变,于是,为了消除影响,应该重新建立一个列表对象。

def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]     #chop out axis used for splitting
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

选择最好的数据集划分方式

     首先计算整个数据集的信息熵,分别比较利用不同的特征划分数据集信息熵大小的增益情况。假设划分成A,B子集,那么划分后的数据集信息熵是子集A,B信息熵的加权平均和,那么信息增益最大的特征即为划分的最佳特征。

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1      #the last column is used for the labels
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1
    for i in range(numFeatures):        #iterate over all the features
        featList = [example[i] for example in dataSet]#create a list of all the examples of this feature
        uniqueVals = set(featList)       #get a set of unique values
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)     
        infoGain = baseEntropy - newEntropy     #calculate the info gain; ie reduction in entropy
        if (infoGain > bestInfoGain):       #compare this to the best gain so far
            bestInfoGain = infoGain         #if better than current best, set to best
            bestFeature = i
    return bestFeature                      #returns an integer

构建决策树

     构建决策树的工作原理为:首先得到原始数据集,然后基于最好的属性划分数据集,由于特征值可能多于两个,因此可能存在大于两个分支的数据集划分。第一次划分之后,数据将向下传递到树分支的下一个结点,在该结点上,我们可以再次划分数据。因此,我们可以采用递归的方法处理数据集,完成决策树构造。
     递归的条件是:程序遍历完所有划分数据集的属性,或者每个分之下的所有实例都具有相同的分类。如果所有的实例具有相同的分类,则得到一个叶子结点或者终止块。
     当然,我们可能会遇到,当遍历完所有的特征属性,但是某个或多个分支下实例类标签仍然不唯一,此时,我们需要确定出如何定义该叶子结点,在这种情况下,通过会采取多数表决的原则选取分支下实例中类标签种类最多的分类作为该叶子结点的分类。表决和构造决策树的代码分别如下:

def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList): 
        return classList[0]#stop splitting when all of the classes are equal
    if len(dataSet[0]) == 1: #stop splitting when there are no more features in dataSet
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]       #copy all of labels, so trees don't mess up existing labels
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    return myTree                            

你可能感兴趣的:(【机器学习实战】决策树)