from collections import Counter
from math import log
import operator


def calcShannonEnt(dataSet):
    """Return the Shannon entropy (in bits) of the class labels in *dataSet*.

    Each row of *dataSet* is a sequence whose last element is the class
    label.  An empty dataset yields 0.0.
    """
    numEntries = len(dataSet)
    labelCounts = Counter(featVec[-1] for featVec in dataSet)
    shannonEnt = 0.0
    for count in labelCounts.values():
        # float() keeps the division exact-ish under Python 2 as well.
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def createDataSet():
    """Return a small toy dataset and the names of its two feature columns.

    Rows are ``[no_surfacing, flippers, class_label]``.
    """
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def splitDataSet(dataSet, axis, value):
    """Return the rows whose column *axis* equals *value*, with that column removed.

    New row objects are built, so *dataSet* and its rows are not modified.
    """
    return [featVec[:axis] + featVec[axis + 1:]
            for featVec in dataSet
            if featVec[axis] == value]


def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the largest information gain.

    Every column except the last (the class label) is considered.  Returns
    -1 when no feature yields a strictly positive gain.  Raises IndexError
    if *dataSet* is empty.
    """
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        uniqueVals = set(example[i] for example in dataSet)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """Return the most frequent label in *classList*.

    Ties go to the label encountered first (``Counter.most_common``
    ordering).  Raises IndexError if *classList* is empty.
    """
    return Counter(classList).most_common(1)[0][0]


def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree as nested dicts.

    *dataSet* rows hold feature values followed by the class label;
    *labels* names the feature columns in order.  Returns either a class
    label (leaf) or ``{feature_name: {feature_value: subtree, ...}}``.
    The caller's *labels* list is not modified.
    """
    classList = [example[-1] for example in dataSet]
    # Every row has the same class: pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column is left: nothing more to split on.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # No feature reduces entropy; index -1 would select the class column
    # itself, so stop and emit a majority-vote leaf instead.
    if bestFeat == -1:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Build the children's label list without mutating the caller's list.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    for value in set(example[bestFeat] for example in dataSet):
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), list(subLabels))
    return myTree