使用matlab制作感知机对鸢尾花数据集进行分类

在由鸢尾花数据集组成的四维空间中,使用若干个超平面将不同的数据两两隔开,如隔开0/1类的感知机命名为Y01,若为0类则输出尽可能接近0.1,若为1类则输出尽可能接近0.9。运行程序后分类正确率为90%
其中iris_training.mat链接:
https://pan.baidu.com/s/1U1FghBjTvjeOqY0QTDCFSw?pwd=w14x
提取码:w14x

%%单层感知机算法

%线性分类
%建立三个超平面划分种类0/1,0/2,1/2
%设定误差为 s = 0.03 ,学习率为 a = 0.05
%训练过程为选取一个样本如种类0,将其投入0/1,0/2两个感知机中进行训练,将阈值同样作为一个参数进行训练
%激活函数选用sigmoid函数
%选取每种各30个样本进行训练,剩下样本进行测试
hold on
load("iris_training.mat")
Sepal_Length = iristraining(:,1);                  % 花萼长度
Sepal_Width = iristraining(:,2);                   % 花萼宽度
Petal_Length = iristraining(:,3);                  % 花瓣长度
Petal_Width = iristraining(:,4);                   % 花瓣宽度​
scatter3(Sepal_Length,Sepal_Width,Petal_Length,40,Petal_Width,'filled')    % draw the scatter plot
ax = gca;
ax.XDir = 'reverse';
view(-31,14)
xlabel('花萼长度')
ylabel('花萼宽度')
zlabel('花瓣长度')

%初始化参数
w1 = [0.5 0.5 0.5 0.5 0.5];
w2 = [0.5 0.5 0.5 0.5 0.5];
w3 = [0.5 0.5 0.5 0.5 0.5];

%设定正误判定标准
d_correct = 0.9;
d_wrong = 0.1;

%设定误差、学习率
s = 0.03;
a = 0.02;

%训练一轮
err = 1;
gen = 0;
while err>s
    gen = gen + 1
    err = 0;
    for i = 1:30
        Y01 = logsig(w1(1)*Sepal_Length(i) + w1(2)*Sepal_Width(i) + w1(3)*Petal_Length(i) + w1(4)*Petal_Width(i) - w1(5));
        Y02 = logsig(w2(1)*Sepal_Length(i) + w2(2)*Sepal_Width(i) + w2(3)*Petal_Length(i) + w2(4)*Petal_Width(i) - w2(5));
        w1 = w1 - a*(Y01-d_wrong)*(Y01*(1-Y01))*[Sepal_Length(i) Sepal_Width(i) Petal_Length(i) Petal_Width(i) -1];
        w2 = w2 - a*(Y02-d_wrong)*(Y02*(1-Y02))*[Sepal_Length(i) Sepal_Width(i) Petal_Length(i) Petal_Width(i) -1];
        err = err + max([(Y01-d_wrong)^2 (Y02-d_wrong)^2]);
    end
    for i = 41:70
        Y01 = logsig(w1(1)*Sepal_Length(i) + w1(2)*Sepal_Width(i) + w1(3)*Petal_Length(i) + w1(4)*Petal_Width(i) - w1(5));
        Y12 = logsig(w3(1)*Sepal_Length(i) + w3(2)*Sepal_Width(i) + w3(3)*Petal_Length(i) + w3(4)*Petal_Width(i) - w3(5));
        w1 = w1 - a*(Y01-d_correct)*(Y01*(1-Y01))*[Sepal_Length(i) Sepal_Width(i) Petal_Length(i) Petal_Width(i) -1];
        w3 = w3 - a*(Y12-d_wrong)*(Y12*(1-Y12))*[Sepal_Length(i) Sepal_Width(i) Petal_Length(i) Petal_Width(i) -1];
        err = err + max([(Y01-d_correct)^2 (Y12-d_wrong)^2]);
    end
    for i = 81:110
        Y02 = logsig(w2(1)*Sepal_Length(i) + w2(2)*Sepal_Width(i) + w2(3)*Petal_Length(i) + w2(4)*Petal_Width(i) - w2(5));
        Y12 = logsig(w3(1)*Sepal_Length(i) + w3(2)*Sepal_Width(i) + w3(3)*Petal_Length(i) + w3(4)*Petal_Width(i) - w3(5));
        w2 = w2 - a*(Y02-d_correct)*(Y02*(1-Y02))*[Sepal_Length(i) Sepal_Width(i) Petal_Length(i) Petal_Width(i) -1];
        w3 = w3 - a*(Y12-d_correct)*(Y12*(1-Y12))*[Sepal_Length(i) Sepal_Width(i) Petal_Length(i) Petal_Width(i) -1];
        err = err + max([(Y02-d_correct)^2 (Y12-d_correct)^2]);
    end
    err = err/90
end

correct = 0; 
for i = 31:40
    Y01 = logsig(w1(1)*Sepal_Length(i) + w1(2)*Sepal_Width(i) + w1(3)*Petal_Length(i) + w1(4)*Petal_Width(i) - w1(5));
    Y02 = logsig(w2(1)*Sepal_Length(i) + w2(2)*Sepal_Width(i) + w2(3)*Petal_Length(i) + w2(4)*Petal_Width(i) - w2(5));
    Y12 = logsig(w3(1)*Sepal_Length(i) + w3(2)*Sepal_Width(i) + w3(3)*Petal_Length(i) + w3(4)*Petal_Width(i) - w3(5));
    if Y01<=0.5 && Y02<=0.5
        correct = correct + 1;
    end
end
correct
for i = 71:80
    Y01 = logsig(w1(1)*Sepal_Length(i) + w1(2)*Sepal_Width(i) + w1(3)*Petal_Length(i) + w1(4)*Petal_Width(i) - w1(5));
    Y02 = logsig(w2(1)*Sepal_Length(i) + w2(2)*Sepal_Width(i) + w2(3)*Petal_Length(i) + w2(4)*Petal_Width(i) - w2(5));
    Y12 = logsig(w3(1)*Sepal_Length(i) + w3(2)*Sepal_Width(i) + w3(3)*Petal_Length(i) + w3(4)*Petal_Width(i) - w3(5));
    if Y01>=0.5 && Y12<=0.5
        correct = correct + 1;
    end  
end
correct
for i =111:120
    Y01 = logsig(w1(1)*Sepal_Length(i) + w1(2)*Sepal_Width(i) + w1(3)*Petal_Length(i) + w1(4)*Petal_Width(i) - w1(5));
    Y02 = logsig(w2(1)*Sepal_Length(i) + w2(2)*Sepal_Width(i) + w2(3)*Petal_Length(i) + w2(4)*Petal_Width(i) - w2(5));
    Y12 = logsig(w3(1)*Sepal_Length(i) + w3(2)*Sepal_Width(i) + w3(3)*Petal_Length(i) + w3(4)*Petal_Width(i) - w3(5));
    if Y02>=0.5 && Y12>=0.5
        correct = correct + 1;
    end   
end
correct
fprintf("正确率为:%d\n",correct/30)

你可能感兴趣的:(模式识别相关,matlab,分类,机器学习)