Implementing a Decision Tree by Hand (Part 1)

Building the Decision Tree

1. Overall Approach

Helper Functions

  • Split the data into set_1 and set_2 according to some feature --> divide_set
  • Count the class labels in set_1 and set_2 respectively --> unique_counts
  • Compute the entropy from those counts --> entropy

Computation Flow

  • Loop over the data's columns and pick the feature with the largest information gain (see the formula below)
  • Split the dataset on that feature, then recurse on set_1 and set_2
  • Stop when the gain is 0 or the subset to be examined further is empty
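
The gain used here is the standard information gain: the entropy before the split minus the weighted entropies of the two subsets, where p is the fraction of rows that fall into set_1:

gain = entropy(rows) - p * entropy(set_1) - (1 - p) * entropy(set_2)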

2. Python Implementation

Main Code

def build_tree(rows, scoref=entropy):
    # Base case: no rows left to split
    if len(rows) == 0:
        return DecisionNode()

    current_score = scoref(rows)      # score before splitting
    best_gain = 0.0
    best_criteria = None
    best_sets = None

    column_count = len(rows[0]) - 1    # number of features (last column is the label)
    for col in range(column_count):
        # Collect the distinct values appearing in this column
        column_values = {}
        for row in rows:
            column_values[row[col]] = 1
        # Try splitting on each distinct value
        for value in column_values.keys():
            set_1, set_2 = divide_set(rows, col, value)

            p = float(len(set_1)) / len(rows)
            gain = current_score - p * scoref(set_1) - (1 - p) * scoref(set_2)
            if gain > best_gain and len(set_1) > 0 and len(set_2) > 0:
                best_gain = gain
                best_criteria = (col, value)
                best_sets = (set_1, set_2)
    # Create the sub-branches
    if best_gain > 0:
        # Not a leaf: recurse on both subsets; results stays None,
        # the test is best_criteria = (col, value)
        true_branch = build_tree(best_sets[0], scoref)
        false_branch = build_tree(best_sets[1], scoref)
        return DecisionNode(col=best_criteria[0], value=best_criteria[1], tb=true_branch, fb=false_branch)
    else:
        # No split improves the score: return a leaf with the class counts
        return DecisionNode(results=unique_counts(rows))

The DecisionNode Class

class DecisionNode:
    def __init__(
            self,
            col=-1,
            value=None,
            results=None,
            tb=None,
            fb=None
    ):
        self.col = col              # column index of the criterion to test
        self.value = value          # value the criterion compares against
        self.results = results      # class counts at a leaf; None for internal nodes
        self.tb = tb                # branch for rows where the criterion is true
        self.fb = fb                # branch for rows where the criterion is false
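
The two flavors of node, in one hypothetical snippet: a leaf stores only results, while an internal node stores the test and its two branches:

leaf = DecisionNode(results={'Basic': 3})                # leaf node
inner = DecisionNode(col=3, value=21, tb=leaf, fb=leaf)  # internal node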

divide_set: Splitting the Data

def divide_set(rows, column, value):
    # Binary split on value: set_1 holds the rows where the test is true,
    # set_2 the rows where it is false
    if isinstance(value, (int, float)):
        # Numeric feature: split by threshold
        split_function = lambda row: row[column] >= value
    else:
        # Categorical feature: split by equality
        split_function = lambda row: row[column] == value

    set_1 = [row for row in rows if split_function(row)]
    set_2 = [row for row in rows if not split_function(row)]

    return set_1, set_2
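
A quick sanity check on a tiny hypothetical dataset, showing both split modes:

rows = [['google', 23], ['digg', 18], ['google', 12]]
s1, s2 = divide_set(rows, 0, 'google')   # categorical: s1 holds the two 'google' rows
s1, s2 = divide_set(rows, 1, 18)         # numeric: s1 holds rows with column 1 >= 18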

unique_counts: Counting the Class Labels

def unique_counts(rows):
    results = {}
    for row in rows:
        r = row[-1]    # class label: 'None', 'Basic', or 'Premium'
        if r not in results:
            results[r] = 0
        results[r] += 1
    return results
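
Applied to three hypothetical rows:

counts = unique_counts([['google', 'Premium'], ['digg', 'Basic'], ['google', 'Premium']])
# counts == {'Premium': 2, 'Basic': 1}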

entropy: Computing the Entropy

from math import log2

def entropy(rows):
    # Shannon entropy of the class-label distribution in rows
    results = unique_counts(rows)
    ent = 0.0
    for r in results.keys():
        p = float(results[r]) / len(rows)
        ent -= p * log2(p)
    return ent
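
Two sanity checks with hypothetical rows: a pure set has entropy 0, and an even two-class split has entropy 1:

print(entropy([['a', 'Basic'], ['b', 'Basic']]))     # 0.0
print(entropy([['a', 'Basic'], ['b', 'Premium']]))   # 1.0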

3. Running a Test

Test Data

my_data = [['slashdot', 'USA', 'yes', 18, 'None'],
           ['google', 'France', 'yes', 23, 'Premium'],
           ['digg', 'USA', 'yes', 24, 'Basic'],
           ['kiwitobes', 'France', 'yes', 23, 'Basic'],
           ['google', 'UK', 'no', 21, 'Premium'],
           ['(direct)', 'New Zealand', 'no', 12, 'None'],
           ['(direct)', 'UK', 'no', 21, 'Basic'],
           ['google', 'USA', 'no', 24, 'Premium'],
           ['slashdot', 'France', 'yes', 19, 'None'],
           ['digg', 'USA', 'no', 18, 'None'],
           ['google', 'UK', 'no', 18, 'None'],
           ['kiwitobes', 'UK', 'no', 19, 'None'],
           ['digg', 'New Zealand', 'yes', 12, 'Basic'],
           ['google', 'UK', 'yes', 18, 'Basic'],
           ['kiwitobes', 'France', 'yes', 19, 'Basic']]

Displaying the Result

def print_tree(tree, indent=''):
    # At a leaf, results holds the class counts; for internal nodes it is None
    if tree.results is not None:
        print(str(tree.results))
    else:
        # Print the split criterion
        print(str(tree.col) + ':' + str(tree.value) + "?")
        # Print the two branches
        print(indent + "T->", end='')
        print_tree(tree.tb, indent+'  ')
        print(indent + "F->", end='')
        print_tree(tree.fb, indent+'  ')
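
The output below can be reproduced by building the tree from the test data and printing it:

tree = build_tree(my_data)
print_tree(tree)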

Output

3:21?
T->0:google?
  T->{'Premium': 3}
  F->{'Basic': 3}
F->2:yes?
  T->0:slashdot?
    T->{'None': 2}
    F->{'Basic': 3}
  F->{'None': 4}

It can be verified that the decision tree achieves 100% accuracy on the training set; a quick check is sketched below.
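
A minimal verification sketch. classify is not part of the code above, just a hypothetical helper that walks an observation down the tree:

def classify(observation, tree):
    # At a leaf, return the majority class (the leaves here are pure anyway)
    if tree.results is not None:
        return max(tree.results, key=tree.results.get)
    value = observation[tree.col]
    if isinstance(value, (int, float)):
        branch = tree.tb if value >= tree.value else tree.fb
    else:
        branch = tree.tb if value == tree.value else tree.fb
    return classify(observation, branch)

correct = sum(classify(row[:-1], tree) == row[-1] for row in my_data)
print(correct / len(my_data))   # 1.0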

