matlab决策树 id3算法实现多叉树树形图显示

matlab 决策树 id3算法实现多叉树树形图显示

受一位科大同学之托,这是科大《机器学习》这门课的课程作业之一,我利用暑假在家的时间完成了这个MATLAB版本。难免有不足之处,还望多多海涵。网上关于MATLAB问题的解答不是很多,这一点颇感遗憾;自己琢磨了挺久,大家共同努力吧!
感觉对你有帮助的可以去下载一下我的代码,谢谢。
https://download.csdn.net/download/justsolow/11214000
这也是科大《模式识别》课程的大作业:用Python编写、带GUI界面的K-means聚类。
https://download.csdn.net/download/justsolow/11530044
这是本文的完整代码下载链接,感谢支持。

借鉴:https://www.cnblogs.com/Kermit-Li/p/4503427.html
数据集
https://blog.csdn.net/lfdanding/article/details/50753239
在该作者代码的基础上做了一定的改进,实现了多叉树。

这是实际运行的结果:
matlab决策树 id3算法实现多叉树树形图显示_第1张图片
main.m主函数

clc;
clear all;   % also removes global variables, e.g. those used by print_tree/visit
close all;


%% Data preprocessing: encode the string table as a numeric matrix
disp('正在进行数据预处理...');
% NOTE: the third output of id3_preprocess is the vector of active-attribute
% flags (all ones), despite being stored here under the name "attributes".
[matrix,attributes_label,attributes] =  id3_preprocess();

%% Build the ID3 decision tree with the recursive builder decissiontree()
disp('数据预处理完成,正在进行构造树...');
tree = decissiontree(matrix,attributes_label,attributes);
%% Print the tree (breadth-first) and draw it
[nodeids,nodevalues] = print_tree(tree);
tree_plot(nodeids,nodevalues);

disp('ID3算法构建决策树完成!');

decissiontree.m:主要的执行函数。
其中注释掉的部分是另一种实现方法:先用cell数组存放子树,再转换成结构体。

function [ tree ] = decissiontree(train_data,labels,activeAttributes)
%DECISSIONTREE Recursively build a multi-way ID3 decision tree.
%
% Inputs:
%   train_data       - numeric matrix, one example per row. Columns
%                      1..numel(activeAttributes) hold the encoded attribute
%                      values and column numel(activeAttributes)+1 holds the
%                      label. Labels are assumed to be 0/1 and attribute
%                      values non-negative integers (as produced by
%                      id3_preprocess) -- TODO confirm for other datasets.
%   labels           - cell array of names; labels{k} names attribute column k.
%   activeAttributes - vector of flags, 1 = attribute still available.
%
% Output:
%   tree - struct with fields
%            value    - attribute name (internal node) or 'true'/'false' (leaf)
%            children - 'null' for a leaf, otherwise a struct array where
%                       children(v+1) is the subtree for encoded value v.

disp('original data');
disp(train_data);

if isempty(train_data)
    error('必须提供数据!');
end

numberAttributes = length(activeAttributes);
numberExamples = size(train_data, 1);

% With 0/1 labels, the label-column sum is the number of positive examples.
lastColumnSum = sum(train_data(:, numberAttributes + 1));

% Pure node: every label is 1 -> 'true', every label is 0 -> 'false'.
if lastColumnSum == numberExamples
    tree = make_leaf('true');
    return
end
if lastColumnSum == 0
    tree = make_leaf('false');
    return
end

% No attribute left to split on: predict the majority label.
if sum(activeAttributes) == 0
    tree = majority_leaf(lastColumnSum, numberExamples);
    return
end

bestfeats = choose_bestfeat(train_data);
disp(['bestfeat:',num2str(bestfeats)]);
% choose_bestfeat returns 0 when no attribute has positive information gain
% (e.g. identical attribute values with conflicting labels). labels{0} is an
% invalid index, so fall back to the majority label instead of crashing.
if bestfeats == 0
    tree = majority_leaf(lastColumnSum, numberExamples);
    return
end
tree.value = labels{bestfeats};
disp(['bestfeature:',num2str(bestfeats)]);
activeAttributes(bestfeats) = 0;

% children(v+1) is the branch for encoded value v. Loop up to the largest
% value present rather than up to (number of distinct values - 1): for a
% subset whose values are {0,2}, the latter never visits value 2 and builds
% an empty-subset leaf for value 1 instead.
featvalue = unique(train_data(:,bestfeats));
for i = 0:max(featvalue)
    example = train_data(train_data(:,bestfeats) == i,:);
    if isempty(example)
        % Value absent from this subset: use this node's majority label.
        tree.children(i+1) = majority_leaf(lastColumnSum, numberExamples);
    else
        % Recurse on the examples that take value i.
        % Alternative implementation: collect subtrees in a cell array
        % (tree.children{i+1} = decissiontree(...)) and then convert it with
        % cell2struct(tree.children,{'children'},1).
        tree.children(i+1) = decissiontree(example,labels,activeAttributes);
        disp('--------------------------------------------');
    end
end
end

function leaf = make_leaf(value)
% Leaf node predicting VALUE ('true' or 'false'); children holds the 'null' marker.
leaf = struct('value', value, 'children', 'null');
end

function leaf = majority_leaf(positiveCount, total)
% Leaf predicting 'true' when at least half of TOTAL examples are positive.
if positiveCount >= total / 2
    leaf = make_leaf('true');
else
    leaf = make_leaf('false');
end
end


数据处理

function [ matrix,attributes,activeAttributes ] = id3_preprocess(  )
%ID3_PREPROCESS Load the built-in weather dataset and encode it as numbers.
%
% Outputs:
%   matrix           - numeric matrix with one row per example; every cell of
%                      the string table is encoded by trans2onezero, and the
%                      last column is the label ("是否出门": no-0 / yes-1).
%   attributes       - 1x5 cell of column headers (4 attributes + the label).
%   activeAttributes - 1x4 row of ones, one flag per attribute column.
%
% Encoding applied by trans2onezero:
%   sunny-0, overcast-1, rain-2 | cool-0, midd-1, hot-2 | high-0, nomal-1 |
%   week-0, strong-1 | no-0, yes-1
% The table spells mild/normal/weak as "midd"/"nomal"/"week"; these strings
% must stay in sync with trans2onezero.

% Earlier dataset (weather / weekend / promotion -> sales), kept for reference:
% txt = {  '序号'    '天气'    '是否周末'    '是否有促销'    '销量'
%         ''        '坏'      '是'          '是'            '高'  
%         ''        '坏'      '是'          '是'            '高'  
%         ''        '坏'      '是'          '是'            '高'  
%         ''        '坏'      '否'          '是'            '高'  
%         ''        '坏'      '是'          '是'            '高'  
%         ''        '坏'      '否'          '是'            '高'  
%         ''        '坏'      '是'          '否'            '高'  
%         ''        '好'      '是'          '是'            '高'  
%         ''        '好'      '是'          '否'            '高'  
%         ''        '好'      '是'          '是'            '高'  
%         ''        '好'      '是'          '是'            '高'  
%         ''        '好'      '是'          '是'            '高'  
%         ''        '好'      '是'          '是'            '高'  
%         ''        '坏'      '是'          '是'            '低'  
%         ''        '好'      '否'          '是'            '高'  
%         ''        '好'      '否'          '是'            '高'  
%         ''        '好'      '否'          '是'            '高'  
%         ''        '好'      '否'          '是'            '高'  
%         ''        '好'      '否'          '否'            '高'  
%         ''        '坏'      '否'          '否'            '低'  
%         ''        '坏'      '否'          '是'            '低'  
%         ''        '坏'      '否'          '是'            '低'  
%         ''        '坏'      '否'          '是'            '低'  
%         ''        '坏'      '否'          '否'            '低'  
%         ''        '坏'      '是'          '否'            '低'  
%         ''        '好'      '否'          '是'            '低'  
%         ''        '好'      '否'          '是'            '低'  
%         ''        '坏'      '否'          '否'            '低'  
%         ''        '坏'      '否'          '否'            '低'  
%         ''        '好'      '否'          '否'            '低'  
%         ''        '坏'      '是'          '否'            '低'  
%         ''        '好'      '否'          '是'            '低'  
%         ''        '好'      '否'          '否'            '低'  
%         ''        '好'      '否'          '否'            '低'  }
txt = {     '天气','温度','湿度','风速','是否出门' 
             'sunny','hot','high','week','no';
             'sunny','hot','high','strong','no';
             'overcast','hot','high','week','yes';
             'rain','midd','high','week','yes';
             'rain','cool','nomal','week','yes';
             'rain','cool','nomal','strong','no';
             'overcast','cool','nomal','strong','yes';
             'sunny','midd','high','week','no';
             'sunny','cool','nomal','week','yes';
             'rain','midd','nomal','week','yes';
             'sunny','midd','nomal','strong','yes';
             'overcast','midd','high','strong','yes';
             'overcast','hot','nomal','week','yes';
             'rain','midd','high','strong','no'};

% Row 1 holds the headers; every other row is one example.
% (For a table whose first column is a row index, as in the earlier dataset,
% use txt(1,2:end) and txt(2:end,2:end) instead.)
attributes = txt(1,:);
activeAttributes = ones(1, numel(attributes) - 1);
data = txt(2:end,:);

% Encode the whole table in one call; cellfun keeps the table's shape.
matrix = cellfun(@trans2onezero, data);

end
% Encoding: sunny-0, overcast-1, rain-2; cool-0, midd-1, hot-2; high-0, nomal-1;
% week-0, strong-1; no-0, yes-1
function flag = trans2onezero(data)
%TRANS2ONEZERO Map one string cell of the weather table to its numeric code.
%   data - char vector taken from the table (matched case-sensitively)
%   flag - 0, 1 or 2 (double)
% NOTE(review): any string not listed below (a typo, or a correctly spelled
% 'weak'/'normal'/'mild') silently becomes 1.
% (An earlier Chinese sales dataset used '坏'/'否'/'低' -> 0 instead.)
switch data
    case {'sunny','high','week','no','cool'}
        flag = 0;
    case {'rain','hot'}
        flag = 2;
    otherwise
        flag = 1;
end
end

取出最佳属性列

function [best_feature] = choose_bestfeat(data)
%CHOOSE_BESTFEAT Pick the attribute column with the largest information gain.
%   data         - numeric matrix; the last column holds the class labels and
%                  every other column is a candidate attribute.
%   best_feature - index of the column with the highest positive gain (the
%                  first such column wins ties); 0 when no column has gain > 0.

nRows = size(data, 1);
nFeatures = size(data, 2) - 1;
parentEntropy = calc_entropy(data);

best_gain = 0;
best_feature = 0;

for col = 1:nFeatures
    % Conditional entropy: entropy of each value's subset, weighted by the
    % fraction of rows taking that value.
    condEntropy = 0;
    for v = unique(data(:,col))'
        part = splitData(data, col, v);
        condEntropy = condEntropy + (size(part,1) ./ nRows) * calc_entropy(part);
    end
    gain = parentEntropy - condEntropy;   % information gain of this column
    if gain > best_gain
        best_gain = gain;
        best_feature = col;
    end
end
end

function [subSet] = splitData(data, j, value)
%SPLITDATA Rows of DATA whose j-th attribute equals VALUE, with column j removed.
%   data   - numeric training matrix
%   j      - index of the attribute column to split on
%   value  - attribute value whose rows are kept
%   subSet - matching rows in their original order, without column j

% Select all matching rows in one logical-indexing pass; deleting the
% non-matching rows one at a time re-copies the matrix on every deletion.
subSet = data(data(:,j) == value, :);
subSet(:,j) = [];
end

信息熵

function [entropy] = calc_entropy(train_data)
%CALC_ENTROPY Shannon entropy (in bits) of the class-label column.
%   train_data - numeric matrix whose LAST column holds the class labels
%   entropy    - sum over distinct labels of -p*log2(p), where p is the
%                fraction of rows carrying that label

nRows = size(train_data, 1);
classColumn = train_data(:, end);

% Number of rows carrying each distinct label (labels in sorted order).
classes = unique(classColumn);
counts = arrayfun(@(c) sum(classColumn == c), classes);

p = counts ./ nRows;
entropy = sum(-p .* log2(p));
end

队列的入队与出队操作,用于逐个取出结构体数组中的元素

function [ newqueue ] = queue_push( queue,item )
%QUEUE_PUSH Append ITEM to the end of QUEUE (horizontal concatenation).
%   When ITEM is itself an array (e.g. a struct array of children), all of
%   its elements are enqueued in order.
newqueue = horzcat(queue, item);
end

function [ item,newqueue ] = queue_pop( queue )
%QUEUE_POP Remove and return the first element of an array used as a FIFO queue.
%   item     - queue(1); [] when the queue is empty
%   newqueue - queue without its first element; the unchanged (empty) queue
%              when it was already empty

if isempty(queue)
    disp('队列为空,不能访问!');
    % Assign both outputs so callers receive [] instead of an
    % "output argument not assigned" error.
    item = [];
    newqueue = queue;
    return;
end

item = queue(1);          % front of the queue
newqueue = queue(2:end);  % everything behind it
end

function [ length_ ] = queue_curr_size( queue )
%QUEUE_CURR_SIZE Number of elements waiting in QUEUE.
%   Same as length(queue): 0 when empty, otherwise the largest dimension.
if isempty(queue)
    length_ = 0;
else
    length_ = max(size(queue));
end
end


画图函数

function [nodeids_,nodevalue_] = print_tree(tree)
%PRINT_TREE Walk TREE breadth-first, print each node, and return the
% parent-pointer vector used for drawing.
%
% Outputs:
%   nodeids_   - nodeids_(k) is the number of node k's parent (0 for the root)
%   nodevalue_ - 1xN cell array; nodevalue_{k} is node k's value string
%
% Nodes are numbered by visit(), which shares the globals below.
global nodeid nodeids nodevalue;
% Reset the whole shared state: setting only nodeids(1) would keep parent
% entries left over from a previously printed (larger) tree.
nodeids = 0;
nodeid = 0;
nodevalue = {};

if isempty(tree)
    disp('空树!');
    % Return the (empty) state instead of leaving the outputs unassigned.
    nodeids_ = nodeids;
    nodevalue_ = nodevalue;
    return;
end

queue = queue_push([],tree);
while ~isempty(queue)
    [node,queue] = queue_pop(queue);
    % Nodes still waiting in the queue are numbered before this node's
    % children; visit() uses that count to compute the children's numbers.
    visit(node,queue_curr_size(queue));
    if ~strcmp(node.children,'null')
        % Enqueue all children (a struct array) at once.
        queue = queue_push(queue,node.children);
    end
end

%% Node relations for treelayout/tree_plot
nodeids_ = nodeids;
nodevalue_ = nodevalue;
end

function visit(node,length_)
%VISIT Give NODE the next breadth-first number and record it in the globals.
%   length_ - number of nodes queued ahead of NODE's children; NODE's k-th
%             child receives number nodeid+length_+k.
% Leaves only record their value. Internal nodes also register themselves as
% the parent of each child in nodeids and print one line per child.
global nodeid nodeids nodevalue;

nodeid = nodeid + 1;

if strcmp(node.children,'null')
    fprintf('叶子节点,node: %d\t,属性值: %s\n', ...
        nodeid, node.value);
    nodevalue{1,nodeid} = node.value;
    return;
end

numChildren = length(node.children);
for k = 1:numChildren
    childId = nodeid + length_ + k;
    nodeids(childId) = nodeid;
    fprintf('node: %d\t属性值: %s\t,子树为节点:node%d', ...
        nodeid, node.value, childId);
    fprintf('\n');
end
if numChildren > 0
    nodevalue{1,nodeid} = node.value;
end
end

function flag = isleaf(node)
%ISLEAF 1 if NODE is a leaf (its children field is the string 'null'), else 0.
%   Returns a double, not a logical.
flag = double(strcmp(node.children,'null'));
end

function tree_plot( p ,nodevalues)
%TREE_PLOT Draw a tree from its parent-pointer vector and label every node.
% Adapted from MATLAB's treeplot.
%   p          - parent vector: p(k) is the parent of node k, 0 for the root
%   nodevalues - 1xN cell array of label strings, one per node (drawn only
%                when the tree has fewer than 500 nodes)

[xs, ys, height] = treelayout(p);

% NaN-separated line segments from every non-root node to its parent.
kids = find(p ~= 0);
parents = p(kids);
segX = [xs(kids); xs(parents); NaN(size(kids))];
segY = [ys(kids); ys(parents); NaN(size(kids))];
segX = segX(:);
segY = segY(:);

if length(p) >= 500
    % Large trees: draw the edges only, without markers or labels.
    plot(segX, segY, 'r-');
else
    hold on;
    plot(xs, ys, 'ro', segX, segY, 'r-');
    for k = 1:length(xs)
        text(xs(k) + 0.01, ys(k), nodevalues{1,k});
    end
    hold off;
end

xlabel(['height = ' int2str(height)]);
axis([0 1 0 1]);

end

你可能感兴趣的:(机器学习)