使用axis的目的是去掉最优索引以及与value做比较分类子集,使用value的目的是分类子集
每一棵树都是用最优索引以及value进行分类得到的子集
传入索引和value的目的就是分解数据集
背景:决策树算法用于条件分类,在多种情景模式下产生不同的结果,针对情景的一种预测
方法:采用信息增益进行决策
信息增益公式:整个数据集的类别-plog(p,2)求和 - 每种情景下不同类别的-plog(p,2)求和乘以这种情景的这种类别的权重
决策树构建方式:树根是全部数组,树干是利用信息增益,最大信息增益的索引值,和这个索引这的全部属性并利用这个索引值和每个属性分离数组,利用数据构建枝干
主要函数:求香浓熵,求最大信息增益,构建决策树,分离数组
代码如下:
import operator
import math
# a = {'a':0,'b':3,'c':1}
# b = sorted(a.items(),key=operator.itemgetter(1),reverse=True)
# print(b[0][0])
# 1 计算香浓熵
# 2 某一条件计算香浓条件熵
# 3 某一条件计算信息增益
# 4 利用信息增益创建决策树
# 数据集
dataset=[[0, 0, 1, 0, 'no'],
[0, 0, 1, 1, 'no'],
[0, 1, 1, 1, 'yes'],
[0, 1, 0, 0, 'yes'],
[0, 0, 1, 0, 'no'],
[1, 0, 1, 0, 'no'],
[1, 0, 1, 1, 'no'],
[1, 1, 0, 1, 'yes'],
[1, 0, 0, 2, 'yes'],
[1, 0, 0, 2, 'yes'],
[2, 0, 0, 2, 'yes'],
[2, 0, 0, 1, 'yes'],
[2, 1, 1, 1, 'yes'],
[2, 1, 1, 2, 'yes'],
[2, 0, 1, 0, 'no']]
# 计算香浓熵,香浓熵分两次遍历,第一次遍历获取每种类别数量,第二次遍历,获取概率
def computeExp(dataset):
# 输出香浓熵
result = 0
# 获取类别数量
num_class = list(set([data[-1] for data in dataset]))
len_dict = {}
for data in dataset:
if data[-1] not in len_dict.keys():
len_dict[data[-1]] = 0
len_dict[data[-1]] += 1
for i in range(len(num_class)):
port = len_dict[num_class[i]]/len(dataset)
result -= port*math.log(port,2)
return result
def splitList(dataset,axis,label):
data_dict = []
for data in dataset:
if data[axis] == label:
data_dict.append(data[:axis] + data[axis + 1:])
return data_dict
# 对输入维度计算信息增益
def increment(dataset,axis):
num_label = list(set([data[axis] for data in dataset]))
base_exp = computeExp(dataset)
# 香浓条件熵
data_exp = 0
for label in num_label:
data_dict = splitList(dataset,axis,label)
port = len(data_dict)/len(dataset)
data_exp += port*computeExp(data_dict)
return base_exp - data_exp
def getBestFeatrue(dataset):
# 获取类别数目
num_featrue = len(dataset[0][:-1])
# 最大信息增益
max_exp = 0
# 最大信息增益对应的索引值
max_index = -1
for i in range(num_featrue):
new_exp = increment(dataset, i) if increment(dataset, i) > max_exp else max_exp
if new_exp != max_exp:
max_index = i
max_exp = new_exp
return max_index
def getResult(classes):
compare_dict = {}
for label in classes:
if label not in compare_dict.keys():
compare_dict[label] = 0
compare_dict[label] += 1
return sorted(compare_dict.items(), key=operator.itemgetter(1), reverse=True)[0][0]
def createTree(dataset):
bestfeatrue = getBestFeatrue(dataset)
if len(list(set([data[-1] for data in dataset])))==1:
return dataset[0][-1]
if(len(dataset[0])==1):
return getResult(dataset)
mytree = {bestfeatrue:{}}
num_class = list(set([data[bestfeatrue] for data in dataset]))
for label in num_class:
if label not in mytree[bestfeatrue].keys():
mytree[bestfeatrue][label] = {}
new_dataset = splitList(dataset,bestfeatrue,label)
mytree[bestfeatrue][label] = createTree(new_dataset)
return mytree
print(createTree(dataset))