Pocket.m
%% Pocket算法
%X,Y--训练数据集
%px,py--预测数据
function [] = pocket(X, Y, px, py)
if abs(Y) ~= 1
error('invalid y,y should be 1 or -1');
return;
end
[numParam, numSample] = size(X); %样本数量
itCount = 0; %迭代次数
minNumError = numSample;
W=zeros(numParam+1, 1); %假设threhold=0
while itCount<100
%找到所有错误点个数
numError = 0;
errorPos = -1;
for i=1:numSample
x = [X(:, i)',1]';%x0=1
if sign(W'*x) ~= Y(i)
numError=numError+1;
errorPos = i;
end
end
fprintf('%d, numError:%d,minNumError=%d\n', itCount, numError, minNumError);
%更新最小错误点个数
if minNumError > numError
minNumError = numError;
end
if minNumError == 0
break;
end
%用最后一个错误点更新W
W = W + Y(errorPos)*[X(:,errorPos)', 1]';
itCount = itCount+1;
end
fprintf('Pocket result W:');
W'
predict_y=sign(W'*[px',1]');
fprintf('Pocket predict result:yn=%d, predictyn=%d\n', py, predict_y);
PLA.m:PLA与Pocket算法只有在while循环部分不一样:
while itCount<100
hasError=0;
%找到第一个错误点个数
for i=1:numSample
x = [X(:, i)',1]';%x0=1
if sign(W'*x) ~= Y(i)
%更新W
W = W + Y(i)*[X(:,i)', 1]';
hasError=1;
end
end
if hasError==0
break;
end
itCount = itCount+1;
end
fprintf('PLA itCount:%d,result W:', itCount);
main.m
%读取测试数据,最后1条作为测试数据,前面的作为训练数据
%Sample =xlsread('Cryotherapy.xlsx', 'CrayoDataset', 'A2:G91');
Sample=load('data_banknote_authentication.txt');
[row, column] = size(Sample);
for i=1:row
if Sample(i, column) == 0
Sample(i, column)=1;
end
end
%行表示数据特征个数,列表示样本个数
trainX=Sample(1:row-1,1:column-1)';
trainY=Sample(1:row-1,column)';
predictX=Sample(row, 1:column-1)';
predictY=Sample(row,column);
%测试Pocket方法
pocket(trainX, trainY, predictX,predictY);
%测试PLA方法
PLA(trainX, trainY, predictX,predictY);
测试数据:http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
说明:鸢尾花数据集分为3类,一类与其他两类线性可分,可将第一类替换为1,第二三类替换为-1,再做训练。