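There are many well-written blog posts online that explain linear classification and SVMs. I'm not very good at explaining it myself, so here is just a link: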
http://blog.csdn.net/mm_bit/article/details/46988925
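Training an SVM classifier amounts to solving a quadratic programming (QP) problem. In MATLAB this is done with the quadprog function, whose usage is described in the official MATLAB documentation: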
http://cn.mathworks.com/help/optim/ug/quadprog.html
If you would rather not read the English documentation, you can read someone else's blog post instead:
http://blog.csdn.net/jbb0523/article/details/50598641
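For reference, this is the standard soft-margin SVM dual written in quadprog's minimization form, which is what svmTrain.m below sets up. Minimizing ½αᵀHα − Σαᵢ is the same as maximizing Σαᵢ − ½ΣᵢΣⱼ αᵢαⱼ yᵢ yⱼ K(xᵢ,xⱼ). The inputs map as H, f = −1, Aeq = y (the label row), beq = 0, lb = 0, ub = C. The formula for b holds for a margin support vector (0 < αⱼ < C); svmTest.m averages it over all support vectors.

```latex
\begin{aligned}
&\min_{\alpha\in\mathbb{R}^n}\ \tfrac12\,\alpha^{\top}H\alpha+\mathbf f^{\top}\alpha,
  && H_{ij}=y_i y_j K(x_i,x_j),\quad \mathbf f=-\mathbf 1,\\
&\text{s.t.}\ \ y^{\top}\alpha=0,\qquad 0\le\alpha_i\le C, && i=1,\dots,n,\\
&\hat y(x)=\operatorname{sign}\Big(\sum_{i\in SV}\alpha_i y_i K(x_i,x)+b\Big),
  && b=y_j-\sum_{i\in SV}\alpha_i y_i K(x_i,x_j).
\end{aligned}
```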
Once you know how to use quadprog, writing an SVM classifier is mostly straightforward. Here is the source code:
svmTrain.m
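function [ svm ] = svmTrain( trainData,trainLabel,kertype,C )
    % trainData: d-by-n matrix (one sample per column); trainLabel: 1-by-n row of +1/-1
    options = optimset('Display','off');   % suppress solver output
    n = length(trainLabel);
    % Dual problem: min 1/2*a'*H*a + f'*a, with H(i,j) = y_i*y_j*K(x_i,x_j)
    H = (trainLabel'*trainLabel).*kernel(trainData,trainData,kertype);
    H = (H+H')/2;            % symmetrize to guard against rounding asymmetry
    f = -ones(n,1);
    A = [];
    b = [];
    Aeq = trainLabel;        % equality constraint: sum_i y_i*a_i = 0
    beq = 0;
    lb = zeros(n,1);         % box constraint: 0 <= a_i <= C
    ub = C*ones(n,1);
    % The default 'interior-point-convex' algorithm ignores x0, so pass []
    a = quadprog(H,f,A,b,Aeq,beq,lb,ub,[],options);
    epsilon = 1e-8;
    sv_label = find(abs(a)>epsilon);   % support vectors: samples with nonzero multipliers
    svm.a = a(sv_label);
    svm.Xsv = trainData(:,sv_label);
    svm.Ysv = trainLabel(sv_label);
    svm.svnum = length(sv_label);
end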
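To make the QP inputs concrete, here is a minimal pure-Python sketch (standard library only, an illustration rather than a port of svmTrain.m) that builds the same H, f, Aeq, beq, lb and ub and applies the same support-vector threshold. The helper names are hypothetical, samples are stored as rows instead of MATLAB's columns, and no solver is included: the multipliers a = [0.5, 0.5] are the hand-derived optimum of a two-point toy problem.

```python
# Hypothetical helpers mirroring svmTrain.m; samples are rows here (columns in MATLAB).

def linear_kernel(u, v):
    """Inner product <u, v>, the 'linear' case of kernel.m."""
    return sum(ui * vi for ui, vi in zip(u, v))


def build_dual_qp(X, y, C, kernel=linear_kernel):
    """Return H, f, Aeq, beq, lb, ub for min 1/2 a'Ha + f'a."""
    n = len(y)
    H = [[y[i] * y[j] * kernel(X[i], X[j]) for j in range(n)] for i in range(n)]
    f = [-1.0] * n                 # linear term: -sum(a)
    Aeq = [[float(v) for v in y]]  # single equality row: sum_i y_i a_i = 0
    beq = [0.0]
    lb = [0.0] * n                 # 0 <= a_i <= C
    ub = [float(C)] * n
    return H, f, Aeq, beq, lb, ub


def dual_objective(H, f, a):
    """Value of 1/2 a'Ha + f'a, the quantity quadprog minimizes."""
    n = len(a)
    quad = sum(a[i] * H[i][j] * a[j] for i in range(n) for j in range(n))
    return 0.5 * quad + sum(fi * ai for fi, ai in zip(f, a))


def select_support_vectors(a, X, y, eps=1e-8):
    """Keep samples whose multiplier exceeds eps, like find(abs(a)>epsilon)."""
    idx = [i for i, ai in enumerate(a) if abs(ai) > eps]
    return {
        "a": [a[i] for i in idx],
        "Xsv": [X[i] for i in idx],
        "Ysv": [y[i] for i in idx],
        "svnum": len(idx),
    }


# Toy problem: one point per class; the dual optimum is a1 = a2 = 0.5
X = [[0.0, 0.0], [2.0, 0.0]]
y = [1, -1]
H, f, Aeq, beq, lb, ub = build_dual_qp(X, y, C=10)
print(dual_objective(H, f, [0.5, 0.5]))   # -0.5
```

Any other feasible choice such as a = [0.4, 0.4] gives a larger objective value, which is what quadprog's minimization exploits.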
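kernel.m (updated)
function K = kernel( X,Y,type )
    % X: d-by-n, Y: d-by-m (one sample per column); K: n-by-m kernel matrix
    switch type
    case 'linear'
        K = X'*Y;
    case 'rbf'
        delta = 5;                 % RBF width
        delta = delta*delta;
        XX = sum(X'.*X',2);        % n-by-1 squared norms of the columns of X
        YY = sum(Y'.*Y',2);        % m-by-1 squared norms of the columns of Y
        XY = X'*Y;                 % n-by-m matrix of inner products
        % ||x-y||^2 = ||x||^2 + ||y||^2 - 2*x'*y; abs guards against tiny negative rounding errors
        K = abs(repmat(XX,[1 size(YY,1)])+repmat(YY',[size(XX,1) 1])-2*XY);
        K = exp(-K./delta);        % K(x,y) = exp(-||x-y||^2/delta^2)
    otherwise
        error('Unknown kernel type: %s', type);
    end
end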
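The 'rbf' branch avoids an explicit double loop by expanding ||x−y||² = ||x||² + ||y||² − 2xᵀy. A minimal Python sketch of the same computation (the function name is hypothetical, samples are rows) shows that the expansion gives the expected Gaussian values. Note that kernel.m divides by delta² rather than the 2σ² found in many textbooks, so its delta is not the same as σ.

```python
import math


def rbf_kernel_matrix(X, Y, delta=5.0):
    """K[i][j] = exp(-||X[i]-Y[j]||^2 / delta^2), the scaling used in kernel.m."""
    d2 = delta * delta
    XX = [sum(v * v for v in x) for x in X]   # squared norms of X's samples
    YY = [sum(v * v for v in y) for y in Y]   # squared norms of Y's samples
    K = []
    for i, x in enumerate(X):
        row = []
        for j, y in enumerate(Y):
            xy = sum(p * q for p, q in zip(x, y))
            # ||x-y||^2 via the expansion; abs() mirrors kernel.m's rounding guard
            dist2 = abs(XX[i] + YY[j] - 2.0 * xy)
            row.append(math.exp(-dist2 / d2))
        K.append(row)
    return K


K = rbf_kernel_matrix([[0.0, 0.0], [3.0, 4.0]], [[0.0, 0.0], [3.0, 4.0]])
print(K[0][1])   # exp(-25/25) = exp(-1), about 0.3679
```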
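svmTest.m
function result = svmTest(svm, Xt, Yt, kertype)
    % Bias: for each support vector x_j, b_j = y_j - sum_i a_i*y_i*K(x_i,x_j)
    temp = (svm.a'.*svm.Ysv)*kernel(svm.Xsv,svm.Xsv,kertype);
    % b is averaged over all support vectors; strictly, only margin SVs (0 < a_j < C) satisfy y_j*f(x_j) = 1
    b = mean(svm.Ysv-temp);
    w = (svm.a'.*svm.Ysv)*kernel(svm.Xsv,Xt,kertype);
    result.score = w + b;              % decision values
    Y = sign(w+b);                     % predicted labels f(x)
    result.Y = Y;
    result.accuracy = mean(Y==Yt);     % fraction of correctly classified test points
end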
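The prediction step can be checked by hand on the same two-point toy problem used above (a = [0.5, 0.5], support vectors (0,0) with label +1 and (2,0) with label −1). This Python sketch, with hypothetical names and samples as rows, follows svmTest.m: it averages the per-support-vector bias, adds it to the kernel expansion, takes a MATLAB-style sign, and computes accuracy as the fraction of matching labels.

```python
def linear_kernel(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))


def sign(v):
    """MATLAB-style sign: +1, -1, or 0 for exactly 0."""
    return (v > 0) - (v < 0)


def svm_test(a, Xsv, Ysv, Xt, Yt, kernel=linear_kernel):
    coef = [ai * yi for ai, yi in zip(a, Ysv)]          # a_i * y_i
    # b_j = y_j - sum_i a_i y_i K(x_i, x_j) for every support vector, then averaged
    b_terms = [yj - sum(c * kernel(xi, xj) for c, xi in zip(coef, Xsv))
               for xj, yj in zip(Xsv, Ysv)]
    b = sum(b_terms) / len(b_terms)
    score = [sum(c * kernel(xi, x) for c, xi in zip(coef, Xsv)) + b for x in Xt]
    Y = [sign(s) for s in score]
    accuracy = sum(p == t for p, t in zip(Y, Yt)) / len(Yt)
    return {"score": score, "Y": Y, "accuracy": accuracy, "b": b}


result = svm_test([0.5, 0.5], [[0.0, 0.0], [2.0, 0.0]], [1, -1],
                  Xt=[[0.0, 0.0], [2.0, 0.0], [3.0, 0.0], [0.5, 0.0]],
                  Yt=[1, -1, -1, -1])
print(result["Y"], result["accuracy"])   # [1, -1, -1, 1] 0.75
```

Here the boundary is the line x = 1, so the last point (0.5, 0) is deliberately given a wrong label to show how accuracy drops to 3/4.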
test.m
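%------------ Main script ----------------
clc;
clear;
C = 10;              % penalty parameter C (upper bound on each multiplier)
kertype = 'linear';  % linear kernel
%①------ Data preparation
n = 30;
%rng(6);             % fix the random seed for reproducible data (optional)
x1 = randn(2,n);     % 2-by-n matrix of standard normal samples
y1 = ones(1,n);      % 1-by-n row of +1 labels
x2 = 4+randn(2,n);   % 2-by-n normal samples with mean 4; to try the Gaussian kernel use x2 = 3+randn(2,n) and kertype = 'rbf'
y2 = -ones(1,n);     % 1-by-n row of -1 labels
figure;              % open a new figure window
plot(x1(1,:),x1(2,:),'bs',x2(1,:),x2(2,:),'k+');   % plot the two clusters
axis([-3 8 -3 8]);   % set the axis limits
hold on;             % keep drawing into the same figure
%②------------- Training
X = [x1,x2];         % training samples: d-by-2n matrix (d = 2 features, one sample per column)
Y = [y1,y2];         % training labels: 1-by-2n row of +1/-1
svm = svmTrain(X,Y,kertype,C);
plot(svm.Xsv(1,:),svm.Xsv(2,:),'ro');   % circle the support vectors
%③------------- Testing
[x1,x2] = meshgrid(-2:0.05:7,-2:0.05:7);   % x1 and x2 are both 181-by-181 matrices
[rows,cols] = size(x1);
nt = rows*cols;
Xt = [reshape(x1,1,nt);reshape(x2,1,nt)];
% reshape(x1,1,nt) turns x1 into a 1-by-(181*181) row, so Xt is a 2-by-(181*181) matrix
% reshape changes the number of rows, columns and dimensions of an array
Yt = ones(1,nt);     % placeholder labels; only used by svmTest to compute accuracy
result = svmTest(svm, Xt, Yt, kertype);
%④-------------- Plot the decision boundary as a contour
Yd = reshape(result.Y,rows,cols);
contour(x1,x2,Yd,[0,0],'ShowText','on');   % the 0-level contour is the decision boundary
title('SVM classification result');
xlabel('X axis');
ylabel('Y axis');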
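The grid sizes in the comments and the round trip through reshape can be checked with a little arithmetic. MATLAB's start:step:stop has (stop − start)/step + 1 points, and reshape reads and writes in column-major order, so reshape(x1,1,nt) followed by reshape(result.Y,rows,cols) puts every prediction back at its own grid location. The Python helpers below are hypothetical stand-ins for those two MATLAB behaviours.

```python
def grid_count(start, stop, step):
    """Number of points in MATLAB's start:step:stop (stop - start a multiple of step)."""
    return round((stop - start) / step) + 1


def flatten_col_major(M):
    """Like reshape(M,1,numel(M)): read M column by column."""
    rows, cols = len(M), len(M[0])
    return [M[r][c] for c in range(cols) for r in range(rows)]


def reshape_col_major(flat, rows, cols):
    """Like reshape(flat,rows,cols): fill a rows-by-cols matrix column by column."""
    return [[flat[c * rows + r] for c in range(cols)] for r in range(rows)]


print(grid_count(-2, 7, 0.05))    # 181
M = [[1, 2, 3], [4, 5, 6]]
print(flatten_col_major(M))       # [1, 4, 2, 5, 3, 6]
```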
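One way to solve a multi-class linear problem is to split it into several two-class (binary) classifiers.
For example, if the training set has classes 1, 2, 3 and 4, train class 1 against class 2 to get classifier 12, class 1 against class 3 to get classifier 13, and so on for every pair, giving 4*(4-1)/2 = 6 classifiers. Each test point is then run through all 6 classifiers; if classifiers 12, 13 and 14 all return a positive result, the point belongs to class 1.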
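The pairing and the "win every comparison" rule described above generalize to any number of classes k. This Python sketch (hypothetical names) enumerates the k(k−1)/2 pairs with itertools.combinations and returns a class only if it wins every pairwise test it takes part in, otherwise −1, like the final else branch of the script below.

```python
from itertools import combinations


def ovo_pairs(classes):
    """All one-vs-one pairs (i, j); k classes give k*(k-1)/2 classifiers."""
    return list(combinations(classes, 2))


def unanimous_label(classes, pair_sign):
    """pair_sign[(i, j)] is +1 if classifier ij picks class i and -1 if it picks j.
    Return the class that wins all of its comparisons, or -1 if none does."""
    for c in classes:
        if all(pair_sign[(i, j)] == (1 if c == i else -1)
               for (i, j) in ovo_pairs(classes) if c in (i, j)):
            return c
    return -1


print(len(ovo_pairs([1, 2, 3, 4])))   # 6
signs = {(1, 2): 1, (1, 3): 1, (1, 4): -1, (2, 3): 1, (2, 4): -1, (3, 4): -1}
print(unanimous_label([1, 2, 3, 4], signs))   # 4
```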
Code:
multiLiner.m
clc;
clear;
C=10;
kertype='linear';
% Generate training data: four 2-D Gaussian clusters
n=30;
x1=randn(2,n);
x2=4+randn(2,n);
x3=randn(2,n);
x3=[x3(1,:)+4;x3(2,:)-4];
x4=randn(2,n);
x4=[x4(1,:)+8;x4(2,:)];
% Visualize the generated data
plot(x1(1,:),x1(2,:),'bs',x2(1,:),x2(2,:),'k+');
hold on;
plot(x3(1,:),x3(2,:),'r*',x4(1,:),x4(2,:),'y.');
axis([-3 11 -7 7]);
hold on;
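% Pair up the classes into training sets and train one model per pair (the first class of each pair is labeled +1, the second -1)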
trainData12=[x1,x2];
trainData13=[x1,x3];
trainData14=[x1,x4];
trainData23=[x2,x3];
trainData24=[x2,x4];
trainData34=[x3,x4];
trainLabel=[ones(1,n),-ones(1,n)];
svm_12=svmTrain(trainData12,trainLabel,kertype,C);
svm_13=svmTrain(trainData13,trainLabel,kertype,C);
svm_14=svmTrain(trainData14,trainLabel,kertype,C);
svm_23=svmTrain(trainData23,trainLabel,kertype,C);
svm_24=svmTrain(trainData24,trainLabel,kertype,C);
svm_34=svmTrain(trainData34,trainLabel,kertype,C);
% Generate test points on a grid
[x1,x2] = meshgrid(-2:0.05:10,-6:0.05:6);   % x1 and x2 are both 241-by-241 matrices
[rows,cols] = size(x1);
nt = rows*cols;
Xt = [reshape(x1,1,nt);reshape(x2,1,nt)];
% reshape(x1,1,nt) turns x1 into a 1-by-(241*241) row, so Xt is a 2-by-(241*241) matrix
% reshape changes the number of rows, columns and dimensions of an array
Yt = ones(1,nt);   % placeholder labels; only used by svmTest to compute accuracy
result12=svmTest(svm_12,Xt,Yt,kertype);
Yd = reshape(result12.Y,rows,cols);
contour(x1,x2,Yd,[0,0],'ShowText','on');   % draw the 1-vs-2 decision boundary
hold on;
result13=svmTest(svm_13,Xt,Yt,kertype);
Yd = reshape(result13.Y,rows,cols);
contour(x1,x2,Yd,[0,0],'ShowText','on');   % draw the 1-vs-3 decision boundary
hold on;
result14=svmTest(svm_14,Xt,Yt,kertype);
Yd = reshape(result14.Y,rows,cols);
contour(x1,x2,Yd,[0,0],'ShowText','on');   % draw the 1-vs-4 decision boundary
hold on;
result23=svmTest(svm_23,Xt,Yt,kertype);
Yd = reshape(result23.Y,rows,cols);
contour(x1,x2,Yd,[0,0],'ShowText','on');   % draw the 2-vs-3 decision boundary
hold on;
result24=svmTest(svm_24,Xt,Yt,kertype);
Yd = reshape(result24.Y,rows,cols);
contour(x1,x2,Yd,[0,0],'ShowText','on');   % draw the 2-vs-4 decision boundary
hold on;
result34=svmTest(svm_34,Xt,Yt,kertype);
Yd = reshape(result34.Y,rows,cols);
contour(x1,x2,Yd,[0,0],'ShowText','on');   % draw the 3-vs-4 decision boundary
% Determine which class a single sample point belongs to
Xt=[10;2];
Yt=1;
result12=svmTest(svm_12,Xt,Yt,kertype);
result13=svmTest(svm_13,Xt,Yt,kertype);
result14=svmTest(svm_14,Xt,Yt,kertype);
result23=svmTest(svm_23,Xt,Yt,kertype);
result24=svmTest(svm_24,Xt,Yt,kertype);
result34=svmTest(svm_34,Xt,Yt,kertype);
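if result12.Y==1 && result13.Y==1 && result14.Y==1
    testLabel = 1;
elseif result12.Y==-1 && result23.Y==1 && result24.Y==1
    testLabel = 2;
elseif result13.Y==-1 && result23.Y==-1 && result34.Y==1
    testLabel = 3;
elseif result14.Y==-1 && result24.Y==-1 && result34.Y==-1
    testLabel = 4;
else
    testLabel = -1;
    disp('No class won all of its pairwise tests: the test point is not assigned to any of the 4 classes');
end
fprintf('Predicted class: %d\n', testLabel);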
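A common alternative to the all-wins rule is majority voting: every pairwise classifier casts one vote and the class with the most votes is chosen. The Python sketch below (hypothetical names; ties are left unresolved and returned as −1) shows the idea. With 4 classes a unique winner must hold 3 of the 6 votes, so voting gives the same labels as the if/elseif chain above; it only starts to differ from 5 classes on, where a class can win with fewer than k−1 votes.

```python
from collections import Counter
from itertools import combinations


def ovo_vote(classes, pair_sign):
    """pair_sign[(i, j)]: +1 votes for i, -1 votes for j, 0 casts no vote."""
    votes = Counter({c: 0 for c in classes})
    for i, j in combinations(classes, 2):
        s = pair_sign[(i, j)]
        if s == 1:
            votes[i] += 1
        elif s == -1:
            votes[j] += 1
    best = max(votes.values())
    winners = [c for c in classes if votes[c] == best]
    label = winners[0] if len(winners) == 1 else -1   # -1 on a tie
    return label, dict(votes)


# 5 classes: class 1 loses to class 5 but still collects the most votes
five = {(1, 2): 1, (1, 3): 1, (1, 4): 1, (1, 5): -1, (2, 3): 1,
        (2, 4): 1, (2, 5): -1, (3, 4): 1, (3, 5): 1, (4, 5): 1}
print(ovo_vote([1, 2, 3, 4, 5], five))   # (1, {1: 3, 2: 2, 3: 2, 4: 1, 5: 2})
```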