算法(二)决策树

使用axis的目的是去掉最优索引以及与value做比较分类子集,使用value的目的是分类子集

每一棵树都是用最优索引以及value进行分类得到的子集

传入索引和value的目的就是分解数据集

背景:决策树算法用于条件分类,在多种情景模式下产生不同的结果,针对情景的一种预测
方法:采用信息增益进行决策
信息增益公式:整个数据集的类别-plog(p,2)求和 - 每种情景下不同类别的-plog(p,2)求和乘以这种情景的这种类别的权重
决策树构建方式:树根是全部数组,树干是利用信息增益,最大信息增益的索引值,和这个索引这的全部属性并利用这个索引值和每个属性分离数组,利用数据构建枝干
主要函数:求香浓熵,求最大信息增益,构建决策树,分离数组

代码如下:

import operator
import math
# a = {'a':0,'b':3,'c':1}
# b = sorted(a.items(),key=operator.itemgetter(1),reverse=True)
# print(b[0][0])
# 1 计算香浓熵
# 2 某一条件计算香浓条件熵
# 3 某一条件计算信息增益
# 4 利用信息增益创建决策树

# 数据集
dataset=[[0, 0, 1, 0, 'no'],
        [0, 0, 1, 1, 'no'],
        [0, 1, 1, 1, 'yes'],
        [0, 1, 0, 0, 'yes'],
        [0, 0, 1, 0, 'no'],
        [1, 0, 1, 0, 'no'],
        [1, 0, 1, 1, 'no'],
        [1, 1, 0, 1, 'yes'],
        [1, 0, 0, 2, 'yes'],
        [1, 0, 0, 2, 'yes'],
        [2, 0, 0, 2, 'yes'],
        [2, 0, 0, 1, 'yes'],
        [2, 1, 1, 1, 'yes'],
        [2, 1, 1, 2, 'yes'],
        [2, 0, 1, 0, 'no']]

# 计算香浓熵,香浓熵分两次遍历,第一次遍历获取每种类别数量,第二次遍历,获取概率
def computeExp(dataset):
    # 输出香浓熵
    result = 0
    # 获取类别数量
    num_class = list(set([data[-1] for data in dataset]))
    len_dict = {}
    for data in dataset:
        if data[-1] not in len_dict.keys():
            len_dict[data[-1]] = 0
        len_dict[data[-1]] += 1
    for i in range(len(num_class)):
        port = len_dict[num_class[i]]/len(dataset)
        result -= port*math.log(port,2)
    return result

def splitList(dataset,axis,label):
    data_dict = []
    for data in dataset:
        if data[axis] == label:
            data_dict.append(data[:axis] + data[axis + 1:])
    return data_dict

# 对输入维度计算信息增益
def increment(dataset,axis):
    num_label = list(set([data[axis] for data in dataset]))
    base_exp = computeExp(dataset)
    # 香浓条件熵
    data_exp = 0
    for label in num_label:
        data_dict = splitList(dataset,axis,label)
        port = len(data_dict)/len(dataset)
        data_exp += port*computeExp(data_dict)
    return base_exp - data_exp

def getBestFeatrue(dataset):
    # 获取类别数目
    num_featrue = len(dataset[0][:-1])
    # 最大信息增益
    max_exp = 0
    # 最大信息增益对应的索引值
    max_index = -1

    for i in range(num_featrue):
        new_exp = increment(dataset, i) if increment(dataset, i) > max_exp else max_exp
        if new_exp != max_exp:
            max_index = i
            max_exp = new_exp
    return max_index

def getResult(classes):
    compare_dict = {}
    for label in classes:
        if label not in compare_dict.keys():
            compare_dict[label] = 0
        compare_dict[label] += 1
    return sorted(compare_dict.items(), key=operator.itemgetter(1), reverse=True)[0][0]

def createTree(dataset):
    bestfeatrue = getBestFeatrue(dataset)
    if len(list(set([data[-1] for data in dataset])))==1:
        return dataset[0][-1]
    if(len(dataset[0])==1):
        return getResult(dataset)
    mytree = {bestfeatrue:{}}
    num_class = list(set([data[bestfeatrue] for data in dataset]))
    for label in num_class:
        if label not in mytree[bestfeatrue].keys():
            mytree[bestfeatrue][label] = {}
        new_dataset = splitList(dataset,bestfeatrue,label)
        mytree[bestfeatrue][label] = createTree(new_dataset)
    return mytree

print(createTree(dataset))

你可能感兴趣的:(全部覆盖)