Decision Stump

1. What Is a Decision Stump?

A decision stump, also called a single-level decision tree, is the simplest kind of decision tree: it classifies samples by comparing a single feature against a given threshold.

[Figure 1: an example decision stump that splits on one feature at threshold 1.75]

In practical terms, a decision stump determines the final class from a single test on one attribute, even though the object being classified usually has several attributes. This property makes it well suited as a weak learner in ensemble learning: it performs at least a little better than random guessing, and it is cheap to compute. A minimal sketch of such a decision rule is given below.
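
As a rough sketch (the sample, the choice of feature, and the use of the threshold 1.75 from the figure are all made up for illustration), the entire decision rule of a stump is one comparison:

% A decision stump reduces to a single threshold test on one feature.
x = [1.0, 2.1];                   % hypothetical sample with two features
predicted = double(x(2) > 1.75);  % label 1 if feature 2 exceeds the threshold, otherwise 0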

2. Key Questions

The fundamental goal: choose a suitable decision stump (weak learner) so that the classification accuracy is as high as possible.

How do we choose a suitable decision stump? Three things have to be decided:

  1. Which attribute, out of all available attributes, should the stump (weak learner) test?
  2. What threshold should it use (1.75 in the figure above)?
  3. Should values below the threshold be labeled 1 (yes), or values above it? Below the threshold is the usual convention. A small sketch of these three choices follows this list.
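
As a minimal sketch, the three choices can be written down as a struct (the field names match the classifier struct returned by decision_stump in the code below; the values here are hypothetical and correspond to the figure):

% Hypothetical stump matching the figure above
classifier.dim         = 1;      % which feature to test (assumed here)
classifier.thresh_val  = 1.75;   % threshold taken from the figure
classifier.thresh_ineq = 'l';    % 'l': samples at or below the threshold get label 0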

3. MATLAB Implementation


The main script test loads the data, calls the decision-stump trainer, and then prints the best decision stump together with its misclassification probability.

First, let's look at how to find the best decision stump on the whole data set, i.e. the stump that minimizes the weighted misclassification error.

Key line: err = sum(weight .* (predi_labels ~= label));
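
In words, every sample whose predicted label differs from its true label contributes its weight to the error. A quick hand check with the five uniformly weighted samples used later (each weight is 1/5, so a single misclassification costs 0.2):

weight       = repmat(1/5, 5, 1);              % uniform weights, 0.2 each
label        = [1; 1; 0; 0; 1];                % true labels from loadData below
predi_labels = [0; 1; 0; 0; 1];                % first sample misclassified
err = sum(weight .* (predi_labels ~= label));  % err = 0.2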

function [classifier, min_error, best_labels] = decision_stump(data, weight, label)
% decision_stump  Find the best decision stump on the weighted data set.
% data: num_row*num_col; weight: num_row*1 (initially all 1/num_row); label: num_row*1 (values 0 or 1)
% classifier has fields dim (which feature gives the highest accuracy),
% thresh_val (the best threshold on that feature) and
% thresh_ineq ('l' or 'g': whether samples below or above the threshold get label 0)
num_row = size(data, 1);
num_col = size(data, 2);
% number of threshold steps to try per feature
max_iter = 10;
% initialize the search
min_error = Inf;
best_labels = ones(num_row, 1);
classifier.dim = 0;
classifier.thresh_val = 0;
classifier.thresh_ineq = '';
for i = 1:num_col
    cur_thresh_val = min(data(:, i));
    step_size = (max(data(:, i)) - min(data(:, i))) / max_iter;
    for j = 1:max_iter
        for k = ['l', 'g']
            thresh_val = cur_thresh_val + (j - 1) * step_size;
            predi_labels = decision(data, i, thresh_val, k);
            % weighted misclassification error of this candidate stump
            err = sum(weight .* (predi_labels ~= label));
            fprintf("iter %d dim %d, threshVal %.2f, thresh ineqal: %s, the weighted error is %.3f\n", j, i, thresh_val, k, err);
            if err < min_error
                % remember the best stump found so far
                min_error = err;
                best_labels = predi_labels;
                classifier.dim = i;
                classifier.thresh_val = thresh_val;
                classifier.thresh_ineq = k;
            end
        end
    end
end
end

% Classify samples into 1/0 according to the stump's threshold
function predi_labels = decision(data, i_col, thresh_val, thresh_ineq)
% data: num_row*num_col; i_col: index of the feature column to test; thresh_val: the stump's threshold;
% thresh_ineq: 'l' or 'g', i.e. which side of the threshold is labeled 0
% predi_labels: num_row*1, the labels predicted by the stump
num_row = size(data, 1);
% initialize predi_labels to all ones
predi_labels = ones(num_row, 1);
if thresh_ineq == 'l'
    % 'l': samples at or below the threshold are labeled 0
    predi_labels(data(:, i_col) <= thresh_val) = 0;
elseif thresh_ineq == 'g'
    % 'g': samples above the threshold are labeled 0
    predi_labels(data(:, i_col) > thresh_val) = 0;
end
end
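
As a quick sanity check (a sketch only: decision is a local function in decision_stump's file, so to call it on its own like this it would have to be saved as its own function file), the helper can be exercised on the toy data used below:

data = [1, 2.1; 1.5, 1.6; 1.3, 1; 1, 1; 2, 1];
% 'l' with threshold 1.3 on feature 1: samples with data(:,1) <= 1.3 get label 0
predi = decision(data, 1, 1.3, 'l');   % returns [0; 1; 0; 0; 1]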

Let's test it:

clc; clear; clearvars;
% load the data [data, label] ([X, Y])
[data, label] = loadData();
% initialize the relative weight of each data point: a num_row*1 vector whose entries are all 1/num_row
num_row = size(data, 1);
num_col = size(data, 2);
weight = repmat(1 / num_row, num_row, 1);
[classifier, min_error, best_labels] = decision_stump(data, weight, label);
% print the parameters of the best decision stump
fprintf("dim %d, threshVal %.2f, thresh ineqal: %s, the weighted error is %.3f\n", ...
    classifier.dim, classifier.thresh_val, classifier.thresh_ineq, min_error);
disp(best_labels);


function [data, label] = loadData()
% data: 5*2, label: 5*1, labels are 0 or 1
data = [1, 2.1; 1.5, 1.6; 1.3, 1; 1, 1; 2, 1];
label = [1; 1; 0; 0; 1];
end

Running the script prints:

iter 1 dim 1, threshVal 1.00, thresh ineqal: l, the weighted error is 0.400
iter 1 dim 1, threshVal 1.00, thresh ineqal: g, the weighted error is 0.600
iter 2 dim 1, threshVal 1.10, thresh ineqal: l, the weighted error is 0.400
iter 2 dim 1, threshVal 1.10, thresh ineqal: g, the weighted error is 0.600
iter 3 dim 1, threshVal 1.20, thresh ineqal: l, the weighted error is 0.400
iter 3 dim 1, threshVal 1.20, thresh ineqal: g, the weighted error is 0.600
iter 4 dim 1, threshVal 1.30, thresh ineqal: l, the weighted error is 0.200
iter 4 dim 1, threshVal 1.30, thresh ineqal: g, the weighted error is 0.800
iter 5 dim 1, threshVal 1.40, thresh ineqal: l, the weighted error is 0.200
iter 5 dim 1, threshVal 1.40, thresh ineqal: g, the weighted error is 0.800
iter 6 dim 1, threshVal 1.50, thresh ineqal: l, the weighted error is 0.400
iter 6 dim 1, threshVal 1.50, thresh ineqal: g, the weighted error is 0.600
iter 7 dim 1, threshVal 1.60, thresh ineqal: l, the weighted error is 0.400
iter 7 dim 1, threshVal 1.60, thresh ineqal: g, the weighted error is 0.600
iter 8 dim 1, threshVal 1.70, thresh ineqal: l, the weighted error is 0.400
iter 8 dim 1, threshVal 1.70, thresh ineqal: g, the weighted error is 0.600
iter 9 dim 1, threshVal 1.80, thresh ineqal: l, the weighted error is 0.400
iter 9 dim 1, threshVal 1.80, thresh ineqal: g, the weighted error is 0.600
iter 10 dim 1, threshVal 1.90, thresh ineqal: l, the weighted error is 0.400
iter 10 dim 1, threshVal 1.90, thresh ineqal: g, the weighted error is 0.600
iter 1 dim 2, threshVal 1.00, thresh ineqal: l, the weighted error is 0.200
iter 1 dim 2, threshVal 1.00, thresh ineqal: g, the weighted error is 0.800
iter 2 dim 2, threshVal 1.11, thresh ineqal: l, the weighted error is 0.200
iter 2 dim 2, threshVal 1.11, thresh ineqal: g, the weighted error is 0.800
iter 3 dim 2, threshVal 1.22, thresh ineqal: l, the weighted error is 0.200
iter 3 dim 2, threshVal 1.22, thresh ineqal: g, the weighted error is 0.800
iter 4 dim 2, threshVal 1.33, thresh ineqal: l, the weighted error is 0.200
iter 4 dim 2, threshVal 1.33, thresh ineqal: g, the weighted error is 0.800
iter 5 dim 2, threshVal 1.44, thresh ineqal: l, the weighted error is 0.200
iter 5 dim 2, threshVal 1.44, thresh ineqal: g, the weighted error is 0.800
iter 6 dim 2, threshVal 1.55, thresh ineqal: l, the weighted error is 0.200
iter 6 dim 2, threshVal 1.55, thresh ineqal: g, the weighted error is 0.800
iter 7 dim 2, threshVal 1.66, thresh ineqal: l, the weighted error is 0.400
iter 7 dim 2, threshVal 1.66, thresh ineqal: g, the weighted error is 0.600
iter 8 dim 2, threshVal 1.77, thresh ineqal: l, the weighted error is 0.400
iter 8 dim 2, threshVal 1.77, thresh ineqal: g, the weighted error is 0.600
iter 9 dim 2, threshVal 1.88, thresh ineqal: l, the weighted error is 0.400
iter 9 dim 2, threshVal 1.88, thresh ineqal: g, the weighted error is 0.600
iter 10 dim 2, threshVal 1.99, thresh ineqal: l, the weighted error is 0.400
iter 10 dim 2, threshVal 1.99, thresh ineqal: g, the weighted error is 0.600
dim 1, threshVal 1.30, thresh ineqal: l, the weighted error is 0.200
     0
     1
     0
     0
     1

As we can see, the best decision stump chooses the first feature with threshold 1.30; its misclassification error is 0.2, and the first sample is the only one classified incorrectly.
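
As a final sketch (the new sample is made up for illustration, and again decision would need to be callable from outside its file), the learned stump can be applied to fresh data:

new_data = [1.8, 1.2];   % hypothetical new sample
pred = decision(new_data, classifier.dim, classifier.thresh_val, classifier.thresh_ineq);
% dim 1, threshold 1.30, 'l': 1.8 > 1.30, so pred = 1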
