# 一、Python 代码
'''
Author: Vici__
date: 2020/5/18
'''
import math
import random
import numpy as np
from math import log
import operator
# 1. Build the dataset.
def get_dataset():
    """Return the toy loan dataset used by this script.

    Returns:
        A 4-tuple ``(dataset, labels, class_label, test_dataset)``:

        * ``dataset`` -- 15 training rows; each row holds four feature values
          followed by the class value ('是' / '否').
        * ``labels`` -- names of the four feature columns, in column order
          (年龄 age, 婚否 married, 车否 car, 身高 height).
        * ``class_label`` -- name of the class column (贷款 loan).
        * ``test_dataset`` -- 3 unlabeled rows containing only feature values.

        Fresh lists are built on every call, so callers may mutate them.
    """
    feature_names = ['年龄', '婚否', '车否', '身高']
    target_name = '贷款'
    training_rows = [
        ['青年', '否', '否', '中', '否'],
        ['青年', '否', '否', '高', '否'],
        ['青年', '是', '否', '高', '是'],
        ['青年', '是', '是', '中', '是'],
        ['青年', '否', '否', '中', '否'],
        ['中年', '否', '否', '中', '否'],
        ['中年', '否', '否', '高', '否'],
        ['中年', '是', '是', '高', '是'],
        ['中年', '否', '是', '很高', '是'],
        ['中年', '否', '是', '很高', '是'],
        ['老年', '否', '是', '很高', '是'],
        ['老年', '否', '是', '高', '是'],
        ['老年', '是', '否', '高', '是'],
        ['老年', '是', '否', '很高', '是'],
        ['老年', '否', '否', '中', '否'],
    ]
    query_rows = [
        ['青年', '否', '否', '中'],
        ['中年', '否', '否', '高'],
        ['老年', '是', '否', '高'],
    ]
    return training_rows, feature_names, target_name, query_rows
# 2. Compute the Shannon entropy.
def calc_shannonent(dataset):
    """Return the Shannon entropy, in bits, of the class column of *dataset*.

    The class value of each row is its last element. Class frequencies are
    tallied in first-seen order, and ``-sum(p * log2(p))`` is accumulated
    over them. An empty dataset yields ``0.0``.
    """
    total = len(dataset)
    frequencies = {}
    for row in dataset:
        cls = row[-1]
        frequencies[cls] = frequencies.get(cls, 0) + 1

    entropy = 0.0
    for count in frequencies.values():
        probability = float(count / total)
        entropy -= probability * log(probability, 2)
    return entropy
# 3. Split the dataset.
def split_dataset(dataset, axis, value):
    """Return the rows of *dataset* whose column *axis* equals *value*.

    Each returned row is a new list with column *axis* removed. The input
    rows are not modified.
    """
    return [
        row[:axis] + row[axis + 1:]
        for row in dataset
        if row[axis] == value
    ]
# 4. Choose the best feature by information gain ratio.
def choose_best_feature(dataset):
    """Return the index of the feature column with the highest gain ratio.

    Every column except the last is a feature; the last holds the class.
    For each feature, the information gain of splitting on it (base entropy
    minus the weighted entropy of the resulting subsets) is divided by its
    split information (the entropy of the feature's own value distribution).

    Args:
        dataset: list of rows; each row is a list of feature values followed
            by the class value.

    Returns:
        The index of the feature with the largest strictly positive gain
        ratio. Returns ``-1`` when no feature qualifies: an empty dataset, no
        feature columns left, or no feature that both splits the rows and
        reduces entropy. Callers should treat ``-1`` as "no useful split".
    """
    if not dataset:
        return -1
    num_features = len(dataset[0]) - 1
    base_entropy = calc_shannonent(dataset)
    total = len(dataset)
    best_ratio = 0.0
    best_feature_id = -1
    for i in range(num_features):
        unique_values = {row[i] for row in dataset}
        if len(unique_values) < 2:
            # A single-valued feature does not split the rows. Its split
            # information is 0, so the gain ratio would be 0/0 (this
            # previously raised ZeroDivisionError).
            continue
        cond_entropy = 0.0
        split_info = 0.0
        for value in unique_values:
            subset = split_dataset(dataset, i, value)
            prob = len(subset) / total
            cond_entropy += prob * calc_shannonent(subset)
            split_info -= prob * log(prob, 2)
        # With >= 2 distinct values every prob is in (0, 1), so split_info > 0.
        gain_ratio = (base_entropy - cond_entropy) / split_info
        if gain_ratio > best_ratio:
            best_ratio = gain_ratio
            best_feature_id = i
    return best_feature_id
# 5. Build the decision tree (main algorithm).
def _majority_class(class_list):
    """Return the most frequent value in *class_list*; ties go to the first seen."""
    counts = {}
    for cls in class_list:
        counts[cls] = counts.get(cls, 0) + 1
    # max() returns the first maximal key, and dicts iterate in insertion order.
    return max(counts, key=counts.get)


def createTree(dataset, labels):
    """Recursively build a decision tree from *dataset*.

    Args:
        dataset: non-empty list of rows; each row is a list of feature values
            followed by the class value.
        labels: feature names, one per feature column of *dataset*, in
            column order. The list is not modified.

    Returns:
        A leaf (a class value) or a nested dict of the form
        ``{feature_label: {feature_value: subtree, ...}}``.

        A leaf is produced when all rows share one class. It is also produced,
        by majority vote, when no feature gives a useful split (e.g. the
        features are exhausted).
    """
    class_list = [row[-1] for row in dataset]
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]

    best_feature_id = choose_best_feature(dataset)
    if best_feature_id < 0:
        # No usable split: indexing with -1 would pick the class column.
        return _majority_class(class_list)

    best_feature_label = labels[best_feature_id]
    # Build a new label list rather than deleting from `labels`. The same
    # list is handed to every sibling subtree, so in-place deletion inside
    # one branch would corrupt the labels seen by the next.
    sub_labels = labels[:best_feature_id] + labels[best_feature_id + 1:]
    my_tree = {best_feature_label: {}}
    for value in {row[best_feature_id] for row in dataset}:
        subset = split_dataset(dataset, best_feature_id, value)
        my_tree[best_feature_label][value] = createTree(subset, sub_labels)
    return my_tree
# 6. Classify a sample by walking the decision tree.
def get_class(node, data, step=0):
    """Return the class that decision tree *node* predicts for *data*.

    Args:
        node: a leaf (any non-dict value, normally a class string) or a dict
            of the form ``{feature_label: {feature_value: subtree, ...}}``.
        data: dict mapping feature labels to this sample's values.
        step: recursion depth; kept for API compatibility and otherwise unused.

    Returns:
        The leaf value reached. Returns ``None`` if the sample has a feature
        value that the tree has no branch for.

    Raises:
        KeyError: if *data* lacks a feature label the tree tests on.
    """
    if not isinstance(node, dict):
        return node
    for feature, branches in node.items():
        value = data[feature]
        if value in branches:
            res = get_class(branches[value], data, step + 1)
            if res is not None:
                return res
    return None
def predict(test_dataset, my_tree, labels):
    """Print each row of *test_dataset* followed by the tree's prediction.

    Each row is turned into a ``{label: value}`` dict by pairing ``labels[j]``
    with the row's j-th value. The dict is classified with ``get_class``.
    Nothing is returned.
    """
    print("使用决策树预测:")
    for row in test_dataset:
        sample = {labels[idx]: cell for idx, cell in enumerate(row)}
        print(row, get_class(my_tree, sample))
# 7. Print the tree.
def show_result(node, class_label, step=0):
    """Print decision tree *node* as an outline, one line per branch and leaf.

    A branch prints as ``feature_label:feature_value`` and a leaf as
    ``class_label:class_value``. Each line is prefixed by one ``-`` per level
    of depth (*step*).

    Leaves may be any non-dict value, and all parts are formatted with
    ``str``. This means numeric labels or values no longer raise
    ``TypeError``.
    """
    indent = '-' * step
    if not isinstance(node, dict):
        print('{}{}:{}'.format(indent, class_label, node))
        return
    for feature, branches in node.items():
        for value, subtree in branches.items():
            print('{}{}:{}'.format(indent, feature, value))
            show_result(subtree, class_label, step + 1)
# Script body: build the tree from the training rows, print it, then classify
# the test rows. createTree() may alter the label list it receives, so it gets
# a copy and the original `labels` is reused for prediction.
dataset, labels, class_label, test_dataset = get_dataset()
tmp_labels = list(labels)
my_tree = createTree(dataset, tmp_labels)

print("打印决策树:")
show_result(my_tree, class_label)

print("---------------------------------")
predict(test_dataset, my_tree, labels)
# 二、测试
# 要求:
#
# 结果:
