from math import log
import operator
def createDataSet():
    """Return the bundled toy dataset together with its feature names.

    Each row holds four integer-coded feature values followed by the class
    label ('yes' or 'no') in the last position. The four feature columns are
    named, in order, by the returned name list. What the individual integer
    codes (0/1/2) stand for is not defined here.

    A fresh set of lists is built on every call, so callers may mutate the
    result freely.

    Returns:
        (rows, names): a list of 15 rows and a list of 4 feature names.
    """
    feature_names = ['F1-AGE', 'F2-WORK', 'F3-HOME', 'F4-LOAN']
    samples = [
        # F1-AGE, F2-WORK, F3-HOME, F4-LOAN, class
        [0, 0, 0, 0, 'no'],
        [0, 0, 0, 1, 'no'],
        [0, 1, 0, 1, 'yes'],
        [0, 1, 1, 0, 'yes'],
        [0, 0, 0, 0, 'no'],
        [1, 0, 0, 0, 'no'],
        [1, 0, 0, 1, 'no'],
        [1, 1, 1, 1, 'yes'],
        [1, 0, 1, 2, 'yes'],
        [1, 0, 1, 2, 'yes'],
        [2, 0, 1, 2, 'yes'],
        [2, 0, 1, 1, 'yes'],
        [2, 1, 0, 1, 'yes'],
        [2, 1, 0, 2, 'yes'],
        [2, 0, 0, 0, 'no'],
    ]
    return samples, feature_names
def creatTree(dataset, labels, featLabels):
    """Recursively build an ID3-style decision tree.

    Args:
        dataset: list of rows; each row holds feature values followed by the
            class label in the last position.
        labels: feature names, one per feature column of ``dataset``. This
            list is NOT modified; a reduced copy is passed to the subtrees.
        featLabels: output list. The name of every feature chosen for a split
            is appended to it in the order nodes are created (depth-first),
            so a name may appear more than once if it is used in several
            branches.

    Returns:
        Either a class label (leaf) or a nested dict of the form
        ``{feature_name: {feature_value: subtree, ...}}``.
    """
    classList = [example[-1] for example in dataset]
    # Stop: every row carries the same class, so this node is a pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop: len(dataset[0]) is the number of COLUMNS in a row, not the number
    # of rows. When only the class column is left, all features have been
    # used up, so fall back to a majority vote.
    if len(dataset[0]) == 1:
        return majorityCnt(classList)

    bestFeat = chooseBeatFeatureToSplit(dataset)
    # chooseBeatFeatureToSplit returns -1 when no feature yields a positive
    # information gain. Indexing with -1 would silently "split" on the class
    # column and produce a meaningless tree, so emit a majority leaf instead.
    if bestFeat == -1:
        return majorityCnt(classList)

    bestFeatLabel = labels[bestFeat]
    featLabels.append(bestFeatLabel)
    myTree = {bestFeatLabel: {}}
    # Names for the remaining columns, built without mutating the caller's
    # list (splitDataSet drops the same column from every row).
    sublabels = labels[:bestFeat] + labels[bestFeat + 1:]
    uniqueVals = set([example[bestFeat] for example in dataset])
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = creatTree(
            splitDataSet(dataset, bestFeat, value), sublabels, featLabels)
    return myTree
def majorityCnt(classList):
    """Return the most frequent class label in ``classList``.

    Ties are resolved in favour of the label that occurs first in
    ``classList``: counts are kept in insertion order and ``sorted`` is
    stable even with ``reverse=True``.

    Raises:
        IndexError: if ``classList`` is empty.
    """
    tally = {}
    for label in classList:
        tally[label] = tally.get(label, 0) + 1
    ranked = sorted(tally.items(), key=operator.itemgetter(1), reverse=True)
    winner, _count = ranked[0]
    return winner
def chooseBeatFeatureToSplit(dataset):
    """Pick the feature column with the highest information gain.

    Information gain is the entropy of the class column (last position)
    minus the size-weighted entropy of the subsets produced by splitting on
    each distinct value of the feature.

    Args:
        dataset: list of rows; feature values first, class label last.

    Returns:
        Index of the best feature column. When several features tie, the
        lowest index wins. Returns -1 if no feature has a strictly positive
        gain.
    """
    feature_count = len(dataset[0]) - 1
    # Entropy of the class column before any split.
    base_entropy = calcShannonEnt(dataset)
    total = float(len(dataset))
    best_gain, best_index = 0, -1
    for col in range(feature_count):
        conditional = 0
        for v in set(row[col] for row in dataset):
            part = splitDataSet(dataset, col, v)
            conditional += len(part) / total * calcShannonEnt(part)
        gain = base_entropy - conditional
        if gain > best_gain:
            best_gain, best_index = gain, col
    return best_index
def splitDataSet(dataset, axis, val):
    """Select the rows whose column ``axis`` equals ``val``, minus that column.

    The input rows are not modified; each returned row is a new list built
    from slices of the original.

    Args:
        dataset: list of row lists.
        axis: index of the column to test and remove.
        val: value the column must equal for a row to be kept.

    Returns:
        A new list of reduced rows, in the original row order.
    """
    kept = []
    for row in dataset:
        if not row[axis] == val:
            continue
        trimmed = row[:axis]
        trimmed.extend(row[axis + 1:])
        kept.append(trimmed)
    return kept
def calcShannonEnt(dataset):
    """Return the base-2 Shannon entropy of the class labels in ``dataset``.

    The class label is taken from the last position of every row. For an
    empty dataset no terms are summed and the integer 0 is returned.
    """
    total = len(dataset)
    counts = {}
    for row in dataset:
        label = row[-1]
        counts[label] = counts.get(label, 0) + 1
    entropy = 0
    for n in counts.values():
        p = float(n) / total
        entropy -= p * log(p, 2)
    return entropy
if __name__ == '__main__':
    # Build the decision tree for the bundled toy dataset and report it;
    # without the prints the script finishes with no observable output.
    dataSet, labels = createDataSet()
    featLabels = []
    myTree = creatTree(dataSet, labels, featLabels)
    print(myTree)
    print(featLabels)