结果为：
%% CART regression tree demo
% Start from a clean workspace and an empty command window.
clear('all');
clc();

% Load the data matrix; the helper functions below treat its last column as
% the regression target. Alternative dataset:
% dataSet = load('ex00.txt');
dataSet = load('ex0.txt');

% Optional visualisation of the data (uncomment one pair):
% plot(dataSet(:,1), dataSet(:,2), '.');
% axis([-0.2, 1.2, -1.0, 2.0]);
% plot(dataSet(:,2), dataSet(:,3), '.');
% axis([-0.2, 1.2, -1.0, 5.0]);

% Grow the tree: tolS = 1 (minimum error reduction per split),
% tolN = 4 (minimum rows on each side of a split).
createTree(dataSet, 1, 4);
function [ retTree ] = createTree( dataSet,tolS,tolN )
%CREATETREE Recursively grow a CART regression tree, printing each split.
%   retTree = createTree(dataSet, tolS, tolN) grows a tree on dataSet.
%   At every node it prints the chosen feature and value, then recurses
%   into the "left:" and "right:" partitions.
%
%   Inputs:
%     dataSet - numeric matrix whose last column is the target.
%     tolS    - minimum error reduction required to accept a split.
%     tolN    - minimum number of rows allowed on each side of a split.
%
%   Output (optional; callers that ignore it behave as before):
%     retTree - for a leaf, the scalar leaf value. Otherwise a struct with
%               fields
%                 spInd - index of the split feature (column),
%                 spVal - split threshold,
%                 left  - subtree built from rows with feature > spVal,
%                 right - subtree built from rows with feature <= spVal.
[feat, val] = chooseBestSplit(dataSet, tolS, tolN);
disp(['feat:', num2str(feat)]);
disp(['value:', num2str(val)]);

if feat == 0
    % No acceptable split: val is the leaf prediction for this node.
    % Previously retTree was never assigned, so requesting the output errored.
    retTree = val;
    return;
end

[lSet, rSet] = binSplitDataSet(dataSet, feat, val);
retTree = struct('spInd', feat, 'spVal', val, 'left', [], 'right', []);

disp('left:');
retTree.left = createTree(lSet, tolS, tolN);
disp('right:');
retTree.right = createTree(rSet, tolS, tolN);
end
function [ Index, Value ] = chooseBestSplit( dataSet, tolS, tolN )
%CHOOSEBESTSPLIT Find the feature/threshold pair that best splits dataSet.
%   [Index, Value] = chooseBestSplit(dataSet, tolS, tolN) tries every
%   unique value of every feature column as a threshold and keeps the
%   split with the smallest total squared error (regErr of both halves).
%
%   Inputs:
%     dataSet - numeric matrix; all columns but the last are candidate
%               features, the last column is the regression target.
%     tolS    - minimum reduction in error a split must achieve.
%     tolN    - minimum number of rows allowed on each side of a split.
%
%   Outputs:
%     Index - column of the chosen feature, or 0 when the node should be a
%             leaf (no split is made).
%     Value - split threshold (rows with feature > Value form the first
%             partition of binSplitDataSet), or the leaf value (mean of the
%             target column) when Index == 0.
numCols = size(dataSet, 2);

% All target values identical: splitting cannot reduce the error.
if length(unique(dataSet(:, numCols))) == 1
    Index = 0;
    Value = regLeaf(dataSet);
    return;
end

S = regErr(dataSet);   % error of the unsplit node
bestS = inf;
bestIndex = 0;
bestValue = 0;

% Exhaustive search over every feature column and each of its unique values.
for j = 1:(numCols - 1)
    candidates = unique(dataSet(:, j));
    for i = 1:length(candidates)
        threshold = candidates(i);
        [mat0, mat1] = binSplitDataSet(dataSet, j, threshold);
        % Reject splits leaving fewer than tolN rows on either side.
        if size(mat0, 1) < tolN || size(mat1, 1) < tolN
            continue;
        end
        newS = regErr(mat0) + regErr(mat1);
        if newS < bestS
            bestS = newS;
            bestIndex = j;
            bestValue = threshold;
        end
    end
end

% Make a leaf when no split passed the tolN check, or when the improvement
% is below tolS. The explicit bestIndex test matters when S is NaN
% (NaN - inf < tolS is false). Without it, the code would go on to split
% on column 0 and error.
% The chosen split already satisfied tolN inside the loop, so no re-check
% is needed here.
if bestIndex == 0 || (S - bestS) < tolS
    Index = 0;
    Value = regLeaf(dataSet);
    return;
end

Index = bestIndex;
Value = bestValue;
end
%% Split the dataset into two parts
function [ dataSet_1, dataSet_2 ] = binSplitDataSet( dataSet, feature, value )
%BINSPLITDATASET Partition the rows of dataSet on one feature column.
%   [dataSet_1, dataSet_2] = binSplitDataSet(dataSet, feature, value)
%   returns the rows where dataSet(:, feature) > value in dataSet_1 and all
%   remaining rows (including NaN feature values) in dataSet_2. Row order is
%   preserved.
%   If one side would be empty, that output is [] (0x0) and the other output
%   is dataSet unchanged.
%
%   This replaces a file that defined this function twice, which MATLAB
%   rejects. It also replaces an element-by-element growing loop with
%   logical indexing.
mask = dataSet(:, feature) > value;

if ~any(mask)
    dataSet_1 = [];
    dataSet_2 = dataSet;
elseif all(mask)
    dataSet_1 = dataSet;
    dataSet_2 = [];
else
    dataSet_1 = dataSet(mask, :);
    dataSet_2 = dataSet(~mask, :);
end
end
function [ err ] = regErr( dataSet )
%REGERR Total squared error of the target (last) column about its mean.
%   err = regErr(dataSet) returns var(y) * (numRows - 1), where y is the
%   last column of dataSet. With var's default N-1 normalisation this is
%   the sum of squared deviations of y from its mean. A single-row input
%   yields 0.
%   The output is named err so that the builtin error() is not shadowed.
numRows = size(dataSet, 1);
target = dataSet(:, end);
err = var(target) * (numRows - 1);
end
function [ leaf ] = regLeaf( dataSet )
%REGLEAF Leaf prediction: the mean of the last column of dataSet.
%   Only the last column is read. The function therefore gives the same
%   value for the full data matrix and for just its target column.
leaf = mean(dataSet(:, end));
end