# 代码人生的记忆 (Memories of a Coding Life) --- 2018-07-09

# -*- coding:utf-8 -*-

from math import log

import operator

import matplotlib.pyplot as plt

def calc(dataset):
    """Return the Shannon entropy (base 2) of the class labels in ``dataset``.

    The class label of a row is its last element (``row[-1]``).

    Args:
        dataset: Sized iterable of rows; each row is an indexable sequence
            whose last element is a hashable class label.

    Returns:
        The entropy ``-sum(p * log2(p))`` over the label frequencies as a
        float. An empty ``dataset`` yields ``0.0``.
    """
    # Use the argument's own size: the row count must match the data that is
    # actually being measured, not some module-level dataset.
    numentries = len(dataset)

    labelscounts = {}
    for featVec in dataset:
        currentlabel = featVec[-1]
        labelscounts[currentlabel] = labelscounts.get(currentlabel, 0) + 1

    shannon = 0.0
    for count in labelscounts.values():
        # float() keeps the division true division on Python 2 as well.
        prob = float(count) / numentries
        shannon -= prob * log(prob, 2)

    return shannon

def creatDataSet():
    """Build the toy dataset used by the demos in this file.

    Returns:
        A ``(rows, feature_names)`` tuple. Each row holds two feature values
        followed by a ``'yes'``/``'no'`` class label; ``feature_names`` names
        the two feature columns. Fresh lists are created on every call.
    """
    rows = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    feature_names = ['no surfacing', 'flippers']
    return rows, feature_names

# Demo: build the toy dataset and print the value calc() returns for it.
dataSet, labels = creatDataSet()

shannon = calc(dataSet)

print(shannon)

# Split the dataset on a given feature

def splitdataset(dataSet, axis, value):
    """Select the rows whose column ``axis`` equals ``value``, dropping that column.

    Args:
        dataSet: Iterable of list rows.
        axis: Index of the column to test and remove.
        value: Value the column must equal for a row to be kept.

    Returns:
        A new list of new row lists; the input rows are not modified.
    """
    return [
        row[:axis] + row[axis + 1:]
        for row in dataSet
        if row[axis] == value
    ]

# Demo: rebuild the toy dataset and print the rows splitdataset() keeps
# for column 0 == 1 and then for column 1 == 1.
dataSet, labels = creatDataSet()

retdataset = splitdataset(dataSet, 0, 1)

print(retdataset)

retdataset1 = splitdataset(dataSet, 1, 1)

print(retdataset1)

# Choose the best feature to split the dataset on

def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature column with the largest positive gain.

    For every column except the last one, the rows are partitioned by that
    column's distinct values with ``splitdataset``; the size-weighted sum of
    ``calc`` over the partitions is subtracted from ``calc(dataSet)``.

    Args:
        dataSet: Non-empty list of rows; the last element of each row is
            excluded from the candidate columns.

    Returns:
        The index of the column with the largest such difference, or ``-1``
        when no column yields a difference strictly greater than zero.
    """
    feature_count = len(dataSet[0]) - 1
    base_entropy = calc(dataSet)
    total_rows = float(len(dataSet))

    best_gain, best_index = 0.0, -1
    for index in range(feature_count):
        weighted_entropy = 0.0
        for value in set([row[index] for row in dataSet]):
            subset = splitdataset(dataSet, index, value)
            weighted_entropy += (len(subset) / total_rows) * calc(subset)

        gain = base_entropy - weighted_entropy
        # Strict comparison: on ties the earliest column wins.
        if gain > best_gain:
            best_gain, best_index = gain, index

    return best_index

# Demo: print the column index chooseBestFeatureToSplit() returns for the
# toy dataset.
dataSet, labels = creatDataSet()

a = chooseBestFeatureToSplit(dataSet)

print(a)

def maj(classlist):
    """Return the most frequent class label in ``classlist`` (majority vote).

    Args:
        classlist: Iterable of hashable class labels.

    Returns:
        The label with the highest count.

    Raises:
        IndexError: If ``classlist`` is empty.
    """
    classcount = {}
    for vote in classlist:
        classcount[vote] = classcount.get(vote, 0) + 1

    # dict.items() exists on both Python 2 and Python 3, whereas iteritems()
    # is Python 2 only. sorted() is ascending by default; reverse=True makes
    # it descending so the highest count comes first.
    sortedclasscount = sorted(
        classcount.items(), key=operator.itemgetter(1), reverse=True
    )
    return sortedclasscount[0][0]

# Build the decision tree

def createtree(dataSet, labels):
    """Recursively build a decision tree as nested dicts.

    Args:
        dataSet: Non-empty list of rows; the last element of each row is the
            class label, the preceding elements are feature values.
        labels: Names of the feature columns, in column order. This sequence
            is not modified.

    Returns:
        A class label when recursion stops (all rows share one label, no
        feature columns remain, or no feature gives a positive gain);
        otherwise ``{feature_name: {feature_value: subtree_or_label, ...}}``.
    """
    classlist = [example[-1] for example in dataSet]

    # All rows agree on the class: this branch is a leaf.
    if classlist.count(classlist[0]) == len(classlist):
        return classlist[0]

    # Only the class column is left: fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return maj(classlist)

    bestF = chooseBestFeatureToSplit(dataSet)

    # -1 means no feature improves on the current entropy. Using it as an
    # index would select the class-label column, so vote instead.
    if bestF < 0:
        return maj(classlist)

    bestFlabel = labels[bestF]
    myTree = {bestFlabel: {}}

    # Work on a copy so the caller's label list is left untouched.
    sublabels = list(labels)
    del sublabels[bestF]

    for value in set([example[bestF] for example in dataSet]):
        myTree[bestFlabel][value] = createtree(
            splitdataset(dataSet, bestF, value), sublabels
        )

    return myTree

# Demo: rebuild the toy dataset (labels is rebound to a fresh list) and
# print the tree createtree() returns for it.
dataSet, labels = creatDataSet()

c = createtree(dataSet, labels)

print(c)

# Box style that createPlot passes to plotNode (as nodeType) for the
# decision-node example.
decisionNode = dict(boxstyle='sawtooth', fc='0.8')

# Box style that createPlot passes to plotNode (as nodeType) for the
# leaf-node example.
leafNode = dict(boxstyle='round4', fc='0.8')

# Arrow style that plotNode hands to annotate() as arrowprops.
arrow_args = dict(arrowstyle='<-')

def plotNode(nodeText, centerPt, parentPt, nodeType):
    """Annotate the shared axes with a boxed text node and an arrow.

    The axes are read from ``createPlot.ax1``, so ``createPlot`` must have
    set that attribute before this function is called.

    Args:
        nodeText: Text shown inside the box.
        centerPt: ``(x, y)`` of the text box, in axes-fraction coordinates.
        parentPt: ``(x, y)`` passed as the annotated point ``xy``, in
            axes-fraction coordinates.
        nodeType: Dict passed as the ``bbox`` of the text box.
    """
    annotate = createPlot.ax1.annotate
    options = dict(
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va='center',
        ha='center',
        bbox=nodeType,
        arrowprops=arrow_args,
    )
    annotate(nodeText, **options)

def createPlot():
    """Draw a two-node demo figure (one decision node, one leaf node) and show it.

    Clears figure 1, creates a frameless subplot, stores it on the function
    object as ``createPlot.ax1`` for ``plotNode`` to use, draws both example
    nodes and calls ``plt.show()``.
    """
    figure = plt.figure(1, facecolor="white")
    figure.clf()

    # plotNode reads the axes from this function attribute.
    createPlot.ax1 = plt.subplot(111, frameon=False)

    demo_nodes = (
        (U'a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode),
        (U'a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode),
    )
    for text, center, parent, box_style in demo_nodes:
        plotNode(text, center, parent, box_style)

    plt.show()

# Demo: draw the two example nodes; createPlot() ends by calling plt.show().
createPlot()

def getnumleafs(myTree):
    """Count the leaf nodes of a nested-dict tree.

    Args:
        myTree: Dict whose first key maps to a dict of children. A child that
            is itself a dict is a subtree of the same shape; any other child
            is counted as one leaf.

    Returns:
        The number of leaves below ``myTree``.

    Raises:
        IndexError: If ``myTree`` is empty.
    """
    # list(d)[0] works on Python 2 and 3; d.keys()[0] fails on Python 3
    # because keys() returns a non-indexable view there.
    firststr = list(myTree)[0]

    numleafs = 0
    for child in myTree[firststr].values():
        if isinstance(child, dict):
            numleafs += getnumleafs(child)
        else:
            numleafs += 1

    return numleafs

def gettreedepth(myTree):
    """Return the depth of a nested-dict tree, counted in decision levels.

    Args:
        myTree: Dict whose first key maps to a dict of children. A child that
            is itself a dict is a subtree of the same shape; any other child
            is a leaf.

    Returns:
        ``1`` when every child is a leaf, plus one for each further level of
        nested subtrees along the deepest path; ``0`` when there are no
        children.

    Raises:
        IndexError: If ``myTree`` is empty.
    """
    # list(d)[0] works on Python 2 and 3; d.keys()[0] fails on Python 3
    # because keys() returns a non-indexable view there.
    firststr = list(myTree)[0]

    maxdepth = 0
    for child in myTree[firststr].values():
        if isinstance(child, dict):
            thisdepth = 1 + gettreedepth(child)
        else:
            thisdepth = 1
        maxdepth = max(maxdepth, thisdepth)

    return maxdepth

# You may also be interested in: (代码人生的记忆---2018-07-09)