% Confusion matrix: outPut(i,k) counts test samples of true class i that the
% classifier assigned to class k. zeros(3) rebuilds it from scratch, so a
% stale outPut left in the workspace by an earlier run cannot leak into it.
outPut = zeros(3);
% Test samples collected per assigned class (one sample per column).
class1 = [];
class2 = [];
class3 = [];
% Load the training and test sample sets from their MAT-files. Each set is
% expected to be a 2 x N matrix of 2-D samples, one sample per column.
load('train1');
load('train2');
load('train3');
load('test1');
load('test2');
load('test3');
% Alternative: generate the 2-D Gaussian samples directly instead of loading.
% Training samples
% train1=mvnrnd([1 1],[4 0;0 5],100)'; %2 X N
% train2=mvnrnd([7 2],[7 0;0 4],100)';
% train3=mvnrnd([2 7],[2 0;0 4],100)';
% Test samples
% test1=mvnrnd([1 1],[4 0;0 5],100)'; %2 X N
% test2=mvnrnd([7 2],[7 0;0 4],100)';
% test3=mvnrnd([2 7],[2 0;0 4],100)';
%---------------------------------------------------%
% Prior probabilities P(k), estimated as each class's share of the training
% samples. Samples are columns (2 x N), so the per-class count is size(.,2);
% length() returns the largest dimension and would miscount a class that
% holds fewer samples than there are feature dimensions.
nTrain = [size(train1,2), size(train2,2), size(train3,2)];
P = nTrain / sum(nTrain);
%--------------------------------------------------------%
% Parameters of the quadratic discriminant for each Gaussian class k:
%   g_k(x) = x'*Wk*x + wk'*x + wk0, where
%   Wk  = -1/2*inv(Sigma_k)
%   wk  = inv(Sigma_k)*mu_k
%   wk0 = -1/2*mu_k'*inv(Sigma_k)*mu_k - 1/2*log(det(Sigma_k)) + log(P(k))
% cov() treats rows as observations, hence the transposes.
Sigma1 = cov(train1');
Sigma2 = cov(train2');
Sigma3 = cov(train3');
% Invert each covariance matrix once and reuse it below.
invSigma1 = inv(Sigma1);
invSigma2 = inv(Sigma2);
invSigma3 = inv(Sigma3);
% Class means as 2x1 column vectors (average over the sample columns).
Ave1 = mean(train1, 2);
Ave2 = mean(train2, 2);
Ave3 = mean(train3, 2);
W1 = -1/2*invSigma1;
W2 = -1/2*invSigma2;
W3 = -1/2*invSigma3;
w1 = invSigma1*Ave1;
w2 = invSigma2*Ave2;
w3 = invSigma3*Ave3;
w10 = -1/2*Ave1'*invSigma1*Ave1 - 1/2*log(det(Sigma1)) + log(P(1));
w20 = -1/2*Ave2'*invSigma2*Ave2 - 1/2*log(det(Sigma2)) + log(P(2));
w30 = -1/2*Ave3'*invSigma3*Ave3 - 1/2*log(det(Sigma3)) + log(P(3));
%-----------------------------------------------------------%
% Classify every test sample with the three quadratic discriminants
% g_k(x) = x'*Wk*x + wk'*x + wk0 and assign it to the class with the largest
% score. outPut(i,k) counts test samples of true class i assigned to class k
% (off-diagonal entries are misclassifications); class1/class2/class3 collect
% the samples assigned to each class, one sample per column.
testSets = {test1, test2, test3};
Wq = {W1, W2, W3};
wl = {w1, w2, w3};
w0 = [w10, w20, w30];
assigned = {class1, class2, class3};
for i = 1:3
    X = testSets{i};
    % Use the actual number of test samples rather than assuming 100.
    for j = 1:size(X, 2)
        x = X(:, j);
        g = zeros(1, 3);
        for k = 1:3
            g(k) = x'*Wq{k}*x + wl{k}'*x + w0(k);
        end
        % max returns the first maximum, so ties are broken identically for
        % every true class (lowest class index wins) instead of favouring
        % the sample's own class.
        [~, best] = max(g);
        outPut(i, best) = outPut(i, best) + 1;
        assigned{best} = [assigned{best}, x];
    end
end
class1 = assigned{1};
class2 = assigned{2};
class3 = assigned{3};
outPut
%---------------------------------------------------%
% Plot the sample distributions: training samples, test samples, and the
% test samples grouped by the class the discriminant assigned them to.
subplot(3,1,1)
plot(train1(1,:),train1(2,:),'go','LineWidth',2), hold on
plot(train2(1,:),train2(2,:),'b+','LineWidth',2)
plot(train3(1,:),train3(2,:),'r.','LineWidth',2)
title('训练样本分布情况')
legend('训练样本1','训练样本2','训练样本3')
subplot(3,1,2)
plot(test1(1,:),test1(2,:),'go','LineWidth',2), hold on
plot(test2(1,:),test2(2,:),'b+','LineWidth',2)
plot(test3(1,:),test3(2,:),'r.','LineWidth',2)
title('测试样本分布情况')
legend('测试样本1','测试样本2','测试样本3')
subplot(3,1,3)
classified = {class1, class2, class3};
markers = {'go', 'b+', 'r.'};
for k = 1:3
    pts = classified{k};
    % A class that received no test samples is still [] (0x0), and
    % pts(1,:) would error on it. Plot a NaN placeholder instead: it draws
    % nothing but keeps one series per class so the legend stays aligned.
    if isempty(pts)
        pts = nan(2, 1);
    end
    plot(pts(1,:), pts(2,:), markers{k}, 'LineWidth', 2), hold on
end
title('测试样本分类后分布情况')
% Series are grouped by the class each test sample was assigned to, not by
% its true class, so label them accordingly.
legend('分类为第1类','分类为第2类','分类为第3类')