决策树 学习

特点

优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。
缺点:可能会产生过度匹配问题。
适用数据类型:数值型和标称型。

代码示例

# -*- coding: utf-8 -*-
# __author__ = 'wangbowen'

from math import log
import operator


def calcShannonEnt(data_set):
    """Compute the Shannon entropy of a data set.

    The class label is assumed to be the last element of each record.
    Returns 0.0 for an empty data set.
    """
    total = len(data_set)

    # Tally how many records carry each class label.
    counts = {}
    for record in data_set:
        label = record[-1]
        counts[label] = counts.get(label, 0) + 1

    # H = -sum(p * log2(p)) over all observed labels.
    entropy = 0.0
    for freq in counts.values():
        prob = float(freq) / total
        entropy -= prob * log(prob, 2)

    return entropy


def createDataSet():
    """Return a tiny toy data set and its feature names.

    Each record is [feature0, feature1, class_label]; the two feature
    names correspond to the first two columns.
    """
    samples = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    feature_names = ['no surfacing', 'flippers']
    return samples, feature_names


def splitDataSet(data_set, axis, value):
    """Return the records whose feature at index `axis` equals `value`,
    with that feature column removed.

    A brand-new list is built so the caller's data_set is left
    untouched (Python passes lists by reference, and mutating the
    argument would leak out of this function).
    """
    subset = []
    for record in data_set:
        if record[axis] != value:
            continue
        # Keep everything except the matched column.
        subset.append(record[:axis] + record[axis + 1:])
    return subset


def chooseBestSplit(data_set):
    """Return the index of the feature whose split gives the largest
    information gain, or -1 when no split improves on the base entropy.
    """
    feature_count = len(data_set[0]) - 1  # last column holds the label
    base_entropy = calcShannonEnt(data_set)  # entropy before any split
    total = float(len(data_set))

    best_feature = -1
    best_gain = 0.0
    for feature in range(feature_count):
        # Weighted entropy of the partitions induced by this feature.
        split_entropy = 0.0
        for value in set(record[feature] for record in data_set):
            subset = splitDataSet(data_set, feature, value)
            weight = len(subset) / total
            split_entropy += weight * calcShannonEnt(subset)

        # Information gain = reduction in entropy from the split.
        gain = base_entropy - split_entropy
        if gain > best_gain:
            best_gain, best_feature = gain, feature

    return best_feature


def mayorityCnt(class_list):
    """Return the class label that occurs most often in class_list.

    Ties between equally frequent labels are broken arbitrarily
    (by whichever label the sort places first).
    """
    class_count = {}
    for vote in class_list:
        # Bug fix: the original initialized new labels to 0 and only
        # incremented on repeats, undercounting every label by one.
        class_count[vote] = class_count.get(vote, 0) + 1

    # .items() instead of the Python-2-only .iteritems().
    sorted_class_count = sorted(class_count.items(),
                                key=operator.itemgetter(1),
                                reverse=True)
    return sorted_class_count[0][0]


def createTree(data_set, labels):
    """Recursively build an ID3 decision tree as nested dicts.

    data_set: list of records, each ending with its class label.
    labels:   feature names aligned with the feature columns.
    Returns either a class label (leaf) or {feature_name: {value: subtree}}.

    Fixes over the original: the Python-2-only debug ``print`` statement
    is removed, and the caller's ``labels`` list is no longer mutated
    (the original ``del(labels[best_feat])`` clobbered it).
    """
    class_list = [record[-1] for record in data_set]

    # All records share one class: this branch is a pure leaf.
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]

    # Only the label column remains: fall back to a majority vote.
    if len(data_set[0]) == 1:
        return mayorityCnt(class_list)

    best_feat = chooseBestSplit(data_set)
    best_label = labels[best_feat]

    # Build the reduced label list on a copy so the caller's list
    # survives intact.
    sub_labels = labels[:best_feat] + labels[best_feat + 1:]

    my_tree = {best_label: {}}
    for value in set(record[best_feat] for record in data_set):
        my_tree[best_label][value] = createTree(
            splitDataSet(data_set, best_feat, value), sub_labels)

    return my_tree


if __name__ == '__main__':
    # Demo: build a tree from the toy data set and show it.
    data_set, labels = createDataSet()
    tree = createTree(data_set, labels)
    # Parenthesized call form works under both Python 2 and Python 3,
    # unlike the original Python-2-only ``print tree`` statement.
    print(tree)

你可能感兴趣的:(数据挖掘)