吴恩达471机器学习入门课程2第4周——决策树

决策树

    • 1 导包
    • 2 问题描述
    • 3 one-hot 编码数据集
    • 4决策树
      • 4.1 计算熵
      • 4.2 划分数据集
    • 5 构建树

从头开始实现决策树,并将其应用于对蘑菇是可食用还是有毒的分类任务。

1 导包

import numpy as np
import matplotlib.pyplot as plt
from public_tests import *
%matplotlib inline

2 问题描述

假设您正在创办一家种植和销售野生蘑菇的公司。 由于并非所有蘑菇都可以食用,因此您希望能够根据其物理属性来判断给定的蘑菇是可食用的还是有毒的 您有一些可用于此任务的现有数据。 你能用这些数据来帮助你确定哪些蘑菇可以安全销售吗? 注意:使用的数据集仅用于说明目的。它并不意味着作为识别食用蘑菇的指南。

3 one-hot 编码数据集

Brown Cap Tapering Stalk Shape Solitary Edible
1 1 1 1
1 0 1 1
1 0 0 0
1 0 0 0
1 1 1 1
0 1 1 0
0 0 0 0
1 0 1 1
0 1 0 1
1 0 0 0
因此,
  • X_train 包含每个样本的三个特征

    • 帽子颜色(值为 1 表示棕色帽子,值为 0 表示红色帽子)
    • 茎形缩小(值为 1 表示“锥形茎形”,值为 0 表示“扩大”茎形)
    • 单独(值为 1 表示“是”,值为 0 表示“否”)
  • y_train 是蘑菇是否可食用

    • y = 1 表示可食用
    • y = 0 表示有毒
X_train = np.array([[1,1,1],[1,0,1],[1,0,0],[1,0,0],[1,1,1],[0,1,1],[0,0,0],[1,0,1],[0,1,0],[1,0,0]])
y_train = np.array([1,1,0,0,1,0,0,1,1,0])
X_train[:5]
array([[1, 1, 1],
       [1, 0, 1],
       [1, 0, 0],
       [1, 0, 0],
       [1, 1, 1]])

4决策树

在这个实践实验中,您将基于提供的数据集构建一个决策树。

  • 回想一下构建决策树的步骤:

    • 从根节点开始使用所有的样本
    • 计算基于所有可能特征分割时的信息增益,并选择具有最高信息增益的特征
    • 根据所选特征对数据集进行分割,并创建树的左右分支
    • 持续重复分裂过程直到满足停止标准
  • 在本实验中,您将实现以下函数,以便使用具有最高信息增益的特征将节点分成左右两个分支:

    • 计算节点上的熵
    • 基于给定特征在一个节点上将数据集分为左右两个分支
    • 计算在给定特征上分裂时的信息增益
    • 选择最大化信息增益的特征
  • 然后我们将使用您实现的辅助函数通过重复分裂过程来构建决策树,直到满足停止标准为止。

    • 对于本实验,我们选择的停止标准是设置最大深度为2。

4.1 计算熵

首先,您需要编写一个名为compute_entropy的帮助函数,用于计算节点上的熵(杂质度量)。

  • 函数接受一个numpy数组(y),该数组指示该节点中的示例是否可食用(1)或有毒(0)

请完成下面的compute_entropy()函数以:

  • 计算 p 1 p_1 p1,它是可食用示例(即在y中具有值= 1)的比例
  • 然后计算熵

H ( p 1 ) = − p 1 log 2 ( p 1 ) − ( 1 − p 1 ) log 2 ( 1 − p 1 ) H(p_1) = -p_1 \text{log}_2(p_1) - (1- p_1) \text{log}_2(1- p_1) H(p1)=p1log2(p1)(1p1)log2(1p1)

  • 注意
    • 对数使用基数为 2 2 2
    • 为了实现方便, 0 log 2 ( 0 ) = 0 0\text{log}_2(0) = 0 0log2(0)=0。也就是说,如果p_1 = 0 或者 p_1 = 1,则将熵设为0
    • 确保检查节点上的数据不是空的(即len(y) != 0),如果为空则返回0
def compute_entropy(y):
    entropy = 0
    if len(y) != 0:
        p1 = len(y[y==1])/len(y)
        if p1 != 0 and p1 !=1:
            entropy = -p1 * np.log2(p1)-(1-p1)*np.log2(1-p1)
        else:
            entropy = 0
    return entropy

4.2 划分数据集

下一步,您将编写一个名为 split_dataset 的帮助函数,该函数获取节点处的数据和要拆分的特征,并将其拆分为左右分支。稍后在实验室中,您将实现计算拆分效果如何的代码。

  • 该函数获取训练数据、该节点上数据点的索引列表以及要拆分的特征。
  • 它将数据拆分并返回左右分支的索引子集。
  • 例如,假设我们从根节点开始(因此 node_indices = [0,1,2,3,4,5,6,7,8,9]),并选择拆分特征为 0,即示例是否具有棕色帽子。
    • 然后函数的输出是,left_indices = [0,1,2,3,4,7,9]right_indices = [5,6,8]
索引 棕色帽子 收缩柄形状 独立 可食用
0 1 1 1 1
1 1 0 1 1
2 1 0 0 0
3 1 0 0 0
4 1 1 1 1
5 0 1 1 0
6 0 0 0 0
7 1 0 1 1
8 0 1 0 1
9 1 0 0 0
def split_dataset(X,node_indices,feature):
    Lnode_indices = []
    Rnote_indices = []
    for i in node_indices:
        if X[i][feature] ==1:
            Lnode_indices.append(i)
        else:
            Rnote_indices.append(i)
    return Lnode_indices,Rnote_indices

接下来,您将编写一个名为 information_gain 的函数,该函数接受训练数据、节点上的索引和要拆分的特征,并返回从拆分中获得的信息增益。

请完成下面所示的 compute_information_gain() 函数以计算

信息增益 = H ( p 1 node ) − ( w left H ( p 1 left ) + w right H ( p 1 right ) ) \text{信息增益} = H(p_1^\text{node})- (w^{\text{left}}H(p_1^\text{left}) + w^{\text{right}}H(p_1^\text{right})) 信息增益=H(p1node)(wleftH(p1left)+wrightH(p1right))

其中:

  • H ( p 1 node ) H(p_1^\text{node}) H(p1node) 是节点的熵
  • H ( p 1 left ) H(p_1^\text{left}) H(p1left) H ( p 1 right ) H(p_1^\text{right}) H(p1right) 是拆分后左、右分支的熵
  • w left w^{\text{left}} wleft w right w^{\text{right}} wright 分别是左、右分支中例子的比例

注意:

  • 您可以使用上面实现的 compute_entropy() 函数来计算熵
  • 我们提供了一些起始代码,该代码使用您之前实现的 split_dataset() 函数拆分数据集
def compute_information_gain(X,y,node_indices,feature):
    L_indices,R_indices = split_dataset(X,node_indices,feature)

    X_node, y_node = X[node_indices],y[node_indices]
    X_left, y_left = X[L_indices], y[L_indices]
    X_right, y_right = X[R_indices], y[R_indices]

    node_entropy = compute_entropy(y_node)
    L_entropy = compute_entropy(y_left)
    R_entropy = compute_entropy(y_right)

    left_w = len(X_left)/len(X_node)
    rigth_w = len(X_right)/len(X_node)

    w_entropy = left_w*L_entropy +rigth_w*R_entropy
    information_gain = node_entropy - w_entropy
    return information_gain

现在,让我们编写一个函数来获取最佳划分特征,方法是计算每个特征的信息增益,就像上面做的那样,并返回给出最大信息增益的特征。

请完成下面所示的 get_best_split() 函数。

  • 函数接受训练数据以及节点处数据点的索引
  • 函数的输出是提供最大信息增益的特征
    • 您可以使用 compute_information_gain() 函数遍历特征并为每个特征计算信息量
def get_best_split(X,y,node_indices):
    num = X.shape[1]
    best_feature = -1
    max_info_gain = 0
    for feature in range(num):
        info_gain = compute_information_gain(X,y,node_indices,feature)
        if info_gain > max_info_gain:
            max_info_gain = info_gain
            best_feature = feature
    return best_feature

5 构建树

在本节中,我们使用您在上面实现的函数来生成决策树,方法是依次选择要拆分的最佳特征,直到达到停止条件(最大深度为 2)。

tree = []
root_indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
def build_tree_recursive(X, y, node_indices, branch_name, max_depth, current_depth):
    if current_depth == max_depth:
        formatting = " "*current_depth + "-"*current_depth
        print(formatting, "%s leaf node with indices" % branch_name, node_indices)
        return


    best_feature = get_best_split(X, y, node_indices)
    tree.append((current_depth, branch_name, best_feature, node_indices))

    formatting = "-"*current_depth
    print("%s Depth %d, %s: Split on feature: %d" % (formatting, current_depth, branch_name, best_feature))

    # Split the dataset at the best feature
    left_indices, right_indices = split_dataset(X, node_indices, best_feature)

    # continue splitting the left and the right child. Increment current depth
    build_tree_recursive(X, y, left_indices, "Left", max_depth, current_depth+1)
    build_tree_recursive(X, y, right_indices, "Right", max_depth, current_depth+1)
build_tree_recursive(X_train, y_train, root_indices, "Root", max_depth=2, current_depth=0)
 Depth 0, Root: Split on feature: 2
- Depth 1, Left: Split on feature: 0
  -- Left leaf node with indices [0, 1, 4, 7]
  -- Right leaf node with indices [5]
- Depth 1, Right: Split on feature: 1
  -- Left leaf node with indices [8]
  -- Right leaf node with indices [2, 3, 6, 9]

你可能感兴趣的:(机器学习,决策树,机器学习,python)