import math
from collections import Counter
def createDataSet():
    """
    Build the hard-coded sample table used by this script.

    Every row holds four feature values followed by the class value in the
    last position; the feature names are returned in column order.

    :return: tuple ``(dataset, labels)`` where ``dataset`` is a list of
        15 five-element lists and ``labels`` is a list of 4 feature names
    """
    rows = (
        ('青年', '否', '否', '一般', '否'),
        ('青年', '否', '否', '好', '否'),
        ('青年', '是', '否', '好', '是'),
        ('青年', '是', '是', '一般', '是'),
        ('青年', '否', '否', '一般', '否'),
        ('中年', '否', '否', '一般', '否'),
        ('中年', '否', '否', '好', '否'),
        ('中年', '是', '是', '好', '是'),
        ('中年', '否', '是', '非常好', '是'),
        ('中年', '否', '是', '非常好', '是'),
        ('老年', '否', '是', '非常好', '是'),
        ('老年', '否', '是', '好', '是'),
        ('老年', '是', '否', '好', '是'),
        ('老年', '是', '否', '非常好', '是'),
        ('老年', '否', '否', '一般', '否'),
    )
    featureNames = ['年龄', '有工作', '有自己的房子', '信贷情况']
    # Hand out fresh, mutable lists on every call so callers may modify them.
    return [list(row) for row in rows], featureNames
def calcEntropy(dataset):
    """
    Compute the base-2 Shannon entropy of the class column of ``dataset``.

    The class of a row is its last element.

    :param dataset: sequence of rows; only ``row[-1]`` is read
    :return: the entropy as a float, or the integer ``0`` when ``dataset``
        is empty
    """
    total = len(dataset)

    # Tally how many rows carry each class value (kept in first-seen order).
    tally = {}
    for row in dataset:
        cls = row[-1]
        tally[cls] = tally.get(cls, 0) + 1

    result = 0
    for count in tally.values():
        share = float(count) / total
        result -= share * math.log(share, 2)
    return result
def splitDataSet(dataset, axis, value):
    """
    Select the rows whose column ``axis`` equals ``value`` and drop that column.

    The input rows are left untouched; each returned row is a new list.

    :param dataset: sequence of list rows
    :param axis: index of the column to test and remove
    :param value: value the column must equal for a row to be kept
    :return: list of the matching rows, each without column ``axis``
    """
    return [
        row[:axis] + row[axis + 1:]
        for row in dataset
        if row[axis] == value
    ]
def chooseBestFeature(dataset):
    """
    Pick the feature column whose split yields the largest information gain.

    Information gain is the entropy of ``dataset`` minus the size-weighted
    entropy of the subsets produced by splitting on a column. The last
    column of every row is the class and is never considered a feature.

    :param dataset: non-empty sequence of rows (features..., class)
    :return: index of the best feature column; ``0`` when no column has a
        strictly positive gain
    """
    featureCount = len(dataset[0]) - 1
    total = float(len(dataset))
    parentEntropy = calcEntropy(dataset)

    winner, winnerGain = 0, 0
    for column in range(featureCount):
        # Weighted entropy remaining after splitting on this column.
        conditionalEntropy = 0
        for value in {row[column] for row in dataset}:
            subset = splitDataSet(dataset, column, value)
            conditionalEntropy += len(subset) / total * calcEntropy(subset)

        gain = parentEntropy - conditionalEntropy
        # Strict comparison: on ties the earliest column is kept.
        if gain > winnerGain:
            winner, winnerGain = column, gain
    return winner
def _majorityClass(classList):
    """Return the most frequent value in ``classList`` (first-seen wins ties)."""
    return Counter(classList).most_common(1)[0][0]


def createTree(dataset, labels):
    """
    Recursively build a decision tree as nested dictionaries.

    An internal node has the form ``{featureName: {featureValue: subtree}}``;
    a leaf is a single class value.

    :param dataset: non-empty sequence of list rows (features..., class)
    :param labels: feature names, one per feature column of ``dataset``;
        this list is not modified
    :return: the class value when the node is a leaf, otherwise a nested dict
    :raises IndexError: if ``dataset`` is empty
    """
    classList = [example[-1] for example in dataset]

    # All rows agree on the class: nothing left to split.
    if classList.count(classList[0]) == len(classList):
        return classList[0]

    # Only the class column remains but the rows still disagree: fall back
    # to a majority vote so that every leaf is a single class value.
    if len(dataset[0]) == 1:
        return _majorityClass(classList)

    bestFeat = chooseBestFeature(dataset)
    bestFeatLabel = labels[bestFeat]
    # Build a reduced copy instead of deleting in place, so the caller's
    # label list survives the call unchanged.
    remainingLabels = labels[:bestFeat] + labels[bestFeat + 1:]

    myTree = {bestFeatLabel: {}}
    # dict.fromkeys de-duplicates while keeping first-seen order, which makes
    # the branch order (and therefore printed output) reproducible.
    for value in dict.fromkeys(example[bestFeat] for example in dataset):
        subset = splitDataSet(dataset, bestFeat, value)
        myTree[bestFeatLabel][value] = createTree(subset, remainingLabels)
    return myTree
if __name__ == '__main__':
    # Script entry point: build the tree for the bundled sample data and
    # print its nested-dict representation.
    samples, featureNames = createDataSet()
    tree = createTree(samples, featureNames)
    print(tree)