# Building a decision tree — from "Machine Learning in Action", Python 3.5

# -*- coding: utf-8 -*-
"""
Created on Tue Jan 30 09:48:53 2018
Email: [email protected]
@author: DidiLv
Python version: 3.5
"""

from math import *
import operator

def createDataSet():
    """Return the toy fish dataset and its feature names.

    Each record is [no-surfacing?, flippers?, class-label].
    """
    samples = [
        [1, 1, "yes"],
        [1, 1, "yes"],
        [1, 0, "no"],
        [0, 1, "no"],
        [0, 1, "no"],
    ]
    featureNames = ["no surfacing", "flippers"]
    return samples, featureNames



def calShannonEnt(dataSet):
    """Compute the Shannon entropy of the class labels in dataSet.

    The class label is taken from the last element of each record.
    Returns 0.0 for a perfectly pure dataset, higher for mixed classes.
    """
    total = len(dataSet)
    # Tally how many records carry each class label.
    tally = {}
    for record in dataSet:
        label = record[-1]
        tally[label] = tally.get(label, 0) + 1
    # Entropy = -sum(p * log2(p)) over the label distribution.
    entropy = 0.0
    for count in tally.values():
        p = float(count) / total
        entropy -= p * log(p, 2)
    return entropy

def splitDataSet(dataSet, axis, value):
    """Select the records whose feature at index `axis` equals `value`.

    The matched feature column is removed from each returned record, so the
    result has one fewer feature than the input. The input is not modified.
    """
    return [
        record[:axis] + record[axis + 1:]
        for record in dataSet
        if record[axis] == value
    ]

# important 1: 
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the highest information gain.

    Returns -1 when no split improves on the base entropy. The last column
    of each record is the class label and is never considered a feature.
    """
    featureCount = len(dataSet[0]) - 1
    baseEntropy = calShannonEnt(dataSet)  # entropy before any split
    bestGain, bestIndex = 0.0, -1
    for idx in range(featureCount):
        # Split on each distinct value of this feature and accumulate the
        # weighted entropy of the resulting partitions.
        distinctValues = {record[idx] for record in dataSet}
        weightedEntropy = 0.0
        for val in distinctValues:
            partition = splitDataSet(dataSet, idx, val)
            weight = float(len(partition)) / len(dataSet)
            weightedEntropy += weight * calShannonEnt(partition)
        gain = baseEntropy - weightedEntropy
        if gain > bestGain:
            bestGain, bestIndex = gain, idx
    return bestIndex
        
def majorityCnt(classList):
    """Return the most common class label in classList (majority vote).

    classList: non-empty list of class labels.
    Ties are broken by whichever label sorts first in the stable sort.

    Bug fix: the original iterated over the freshly created (empty)
    classCount dict instead of classList, so the tally was always empty
    and sortedClassCount[0][0] raised IndexError.
    """
    classCount = {}
    for vote in classList:  # fixed: tally the labels, not the empty dict
        classCount[vote] = classCount.get(vote, 0) + 1
    # Sort label->count pairs by count, descending; the head is the winner.
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree as nested dicts.

    dataSet: list of records, each ending in a class label.
    labels:  feature names aligned with the feature columns; NOT mutated
             (the original del(labels[i]) destructively consumed the
             caller's list — fixed by working on a copy).
    Returns a class label (leaf) or {featureName: {featureValue: subtree}}.
    """
    classList = [example[-1] for example in dataSet]
    # Stop 1: every record carries the same class -> pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop 2: only the label column remains -> majority-vote leaf.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # Split on the feature with the highest information gain.
    bestFeatIndex = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeatIndex]
    # Remaining labels for the recursion, with the used feature removed
    # from a copy so the caller's list stays intact.
    remainingLabels = labels[:bestFeatIndex] + labels[bestFeatIndex + 1:]
    myTree = {bestFeatLabel: {}}
    for value in set(example[bestFeatIndex] for example in dataSet):
        subset = splitDataSet(dataSet, bestFeatIndex, value)
        # Pass a fresh copy per branch so sibling branches cannot
        # interfere with each other's label lists.
        myTree[bestFeatLabel][value] = createTree(subset, remainingLabels[:])
    return myTree
        
        


def main():
    """Demonstrate each stage of the decision-tree pipeline on the toy dataset."""
    dataSet, featureLabels = createDataSet()

    print("Project1: calShannonEnt: -->")
    print("ShannonEntropy =", calShannonEnt(dataSet))

    print("Project2: splitDataSet: -->")
    print(splitDataSet(dataSet, 0, 1))

    print("Project3: chooseBestFeatureToSplit: -->")
    print("The best feature for myDat is: %d" % chooseBestFeatureToSplit(dataSet))

    print("Project4: createTree: -->")
    print("My_Tree is: ", createTree(dataSet, featureLabels))


if __name__ == "__main__":
    main()

# Tags: algorithms, machine learning, data mining