本文介绍 hinge loss $E(w)$ 及其梯度 $\nabla E(w)$,并利用批量梯度下降法优化 hinge loss,从而实现 SVM 多分类。在手写数字数据库上的实验中,该方法能达到 87.040% 的识别正确率。
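1. 多分类 hinge loss 定义为 $E(w)=\frac{1}{n}\sum_{(x,y)\in T} L\big((w_1,\dots,w_k),(x,y)\big)+\lambda\sum_{j=1}^{k}\lVert w_j\rVert^2$(与下文 multi_hingeloss_cost.m 中 cost 的计算方式一致),其中 $T=\{(x_1,y_1),\dots,(x_n,y_n)\}$ 为训练集,$L\big((w_1,\dots,w_k),(x,y)\big)=\max\big(0,\ \max_{y'\neq y} w_{y'}^{T}x+1-w_{y}^{T}x\big)$。二分类 SVM 转化为多分类 SVM 的相关资料和公式推导可以参见其他文献。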
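2. 接下来介绍 $E(w)$ 的梯度计算。对单个样本 $(x,y)$,记 $\hat{y}=\arg\max_{y'\neq y} w_{y'}^{T}x$,则损失 $L$ 关于 $w_j$ 的(次)梯度分以下四种情况:
(a) 如果 $w_{y}^{T}x \ge w_{\hat{y}}^{T}x+1$,那么 $\partial L/\partial w_j = 0$(对所有 $j$);
(b) 如果 $w_{y}^{T}x < w_{\hat{y}}^{T}x+1$ 且 $j=y$,那么 $\partial L/\partial w_j = -x$;
(c) 如果 $w_{y}^{T}x < w_{\hat{y}}^{T}x+1$ 且 $j=\hat{y}$,那么 $\partial L/\partial w_j = x$;
(d) 如果 $w_{y}^{T}x < w_{\hat{y}}^{T}x+1$ 且 $j\neq y$、$j\neq\hat{y}$,那么 $\partial L/\partial w_j = 0$。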
Multiclass_svm.m
% 作者:何凌霄
% 中科院自动化所
% 2017年3月15
clear all
clc

%% STEP 0: Problem dimensions and optimiser settings
inputSize     = 28 * 28;  % length of one input vector (28x28 MNIST image)
numClasses    = 10;       % number of digit classes
lambda        = 1e-2;     % weight-decay coefficient
learning_rate = 0.1;      % step size passed to the trainer
iteration     = 400;      % number of training iterations passed to the trainer

%%======================================================================
%% STEP 1: Load data
% digits.mat is expected to provide train0..train9 and test0..test9.
load('digits.mat')

% Stack the per-digit training sets (digit 0 last) and put one sample per column.
images = [train1; train2; train3; train4; train5; ...
          train6; train7; train8; train9; train0]';

% 500 consecutive labels per class; the last group (train0) gets label 10.
labels = kron((1:10)', ones(500, 1));

% Shuffle samples and labels with the same permutation.
index  = randperm(500 * 10);
images = images(:, index);
labels = labels(index);
inputData = images;

%% STEP 2: Train multiclass svm
[cost, grad, svmOptTheta] = multisvmtrain(numClasses, inputSize, lambda, ...
                                          inputData, labels, iteration, learning_rate);

%% STEP 3: Test
% Test samples stay in class order because compute_confusion_matrix slices
% the predictions into consecutive per-class blocks.
images = [test1; test2; test3; test4; test5; ...
          test6; test7; test8; test9; test0]';
labels = kron((1:10)', ones(500, 1));
inputData = images;

svmModel = struct('optTheta',   reshape(svmOptTheta, numClasses, inputSize), ...
                  'inputSize',  inputSize, ...
                  'numClasses', numClasses);

% Prediction is implemented in Multi_SVMPredict.m.
pred = Multi_SVMPredict(svmModel, inputData);
acc  = mean(labels(:) == pred(:));

% Per-class sample counts and display names for the confusion matrix.
num_in_class = repmat(500, 1, 10);
for i = 1:10
    name_class{i} = num2str(i);
end
confusion_matrix = compute_confusion_matrix(pred, num_in_class, name_class);

% Show each class's weight vector as an image (one column per class).
figure; visualize(svmOptTheta');
fprintf('Accuracy: %0.3f%%\n', acc * 100);
multisvmtrain.m
% 作者:何凌霄
% 中科院自动化所
% 2017年3月15
function [lcost, grad, theta] = multisvmtrain(numClasses, inputSize, lambda, data, labels, iteration, learning_rate)
%MULTISVMTRAIN Train a linear multiclass SVM by batch gradient descent on the hinge loss.
%
%   Inputs
%     numClasses    - number of classes K
%     inputSize     - dimension D of one input vector
%     lambda        - weight-decay coefficient
%     data          - D x N matrix, one sample per column
%     labels        - N class indices in 1..numClasses
%     iteration     - number of gradient-descent steps
%     learning_rate - step size
%
%   Outputs
%     lcost - 1 x iteration vector, cost evaluated before each update
%     grad  - 1 x iteration vector, sum of all gradient entries at each step
%     theta - numClasses x inputSize weight matrix after the last update

% Small random initialisation, drawn as a vector and reshaped to K x D.
theta = 0.005 * randn(numClasses * inputSize, 1);
theta = reshape(theta, numClasses, inputSize);

numCases = size(data, 2);

% One-hot label matrix. The explicit K x N size keeps the row count correct
% even when the highest class index does not occur in LABELS.
groundTruth = full(sparse(labels(:), (1:numCases)', 1, numClasses, numCases));

% Preallocate so the outputs are defined even when iteration == 0.
lcost = zeros(1, iteration);
grad  = zeros(1, iteration);

for i = 1:iteration
    [Q, ~, cost] = multi_hingeloss_cost(theta, data, groundTruth, lambda);
    thetagrad = multi_hingeloss_grad(data, theta, Q, groundTruth, lambda, labels);
    theta = theta - learning_rate * thetagrad;
    lcost(i) = cost;
    grad(i) = sum(sum(thetagrad));
    fprintf('%d, %f\n', i, cost);
end
end
multi_hingeloss_cost.m
% 作者:何凌霄
% 中科院自动化所
% 2017年3月15
function [Q, X, cost] = multi_hingeloss_cost(theta, data, groundTruth,lambda)
%MULTI_HINGELOSS_COST Multiclass hinge-loss objective for a linear classifier.
%
%   Inputs
%     theta       - K x D weight matrix
%     data        - D x N matrix, one sample per column
%     groundTruth - K x N one-hot label matrix (exactly one 1 per column)
%     lambda      - weight-decay coefficient
%
%   Outputs
%     Q    - K x N scores with the true-class entry of each column set to -inf,
%            so that max(Q, [], 1) is the best score among the wrong classes
%     X    - K x N raw scores theta * data
%     cost - mean hinge loss over the N samples plus lambda * sum(theta(:).^2)

X = theta * data;
if ~isequal(size(groundTruth), size(X))
    error('multi_hingeloss_cost:sizeMismatch', ...
          'groundTruth must be %d x %d to match theta * data.', size(X, 1), size(X, 2));
end

isTrue = (groundTruth == 1);

% Mask the true class directly; multiplying scores by -inf would produce NaN
% for a score of exactly 0 and +inf for negative scores.
Q = X;
Q(isTrue) = -inf;

% Logical indexing walks the matrix column by column, so this yields one
% true-class score per sample in sample order, including scores equal to 0.
trueScore = reshape(X(isTrue), 1, []);

% Per-sample loss: max(0, 1 - trueScore + best wrong-class score).
t = max(0, 1 - trueScore + max(Q, [], 1));

% NOTE(review): multi_hingeloss_grad regularises with lambda * theta, which is
% the gradient of (lambda/2) * sum(theta(:).^2), not of the term used here.
cost = 1/size(data,2)*sum(t) + lambda*sum(theta(:).^2);
multi_hingeloss_grad.m
% 作者:何凌霄
% 中科院自动化所
% 2017年3月15
function [thetagrad] = multi_hingeloss_grad(data, theta, Q, groundTruth, lambda, labels)
%MULTI_HINGELOSS_GRAD (Sub)gradient of the multiclass hinge-loss objective.
%
%   Inputs
%     data        - D x N matrix, one sample per column
%     theta       - K x D weight matrix
%     Q           - K x N scores with each column's true-class entry masked
%                   out (as returned by multi_hingeloss_cost)
%     groundTruth - K x N one-hot label matrix; kept for interface
%                   compatibility, the true class is taken from LABELS
%     lambda      - weight-decay coefficient
%     labels      - N true class indices in 1..K
%
%   Output
%     thetagrad   - K x D gradient with respect to theta

X = theta*data;
[numClasses, numCases] = size(X);
cols = 1:numCases;

% Highest-scoring wrong class for every sample.
[~, q] = max(Q, [], 1);

% Linear indices of the true-class and rival-class score of each sample.
% Indexing (rather than masking with a product and deleting zeros) keeps
% scores that are exactly 0 and works for any number of classes.
trueIdx  = sub2ind([numClasses, numCases], reshape(labels, 1, []), cols);
rivalIdx = sub2ind([numClasses, numCases], q, cols);

% W(i) is true when sample i violates the margin, i.e. its loss is positive.
W = (X(trueIdx) - X(rivalIdx)) < 1;

% Each violating sample contributes -x to its true class row and +x to the
% rival class row; all other rows get no contribution.
Y = zeros(numClasses, numCases);
Y(trueIdx)  = -W;
Y(rivalIdx) = W;

% NOTE(review): lambda * theta is the gradient of (lambda/2) * sum(theta(:).^2);
% multi_hingeloss_cost reports lambda * sum(theta(:).^2).
thetagrad = 1/numCases*Y*data' + lambda * theta;
Multi_SVMPredict.m
% 作者:何凌霄
% 中科院自动化所
% 2017年3月15
function [pred] = Multi_SVMPredict(svmModel, data)
%MULTI_SVMPREDICT Predict class indices with a trained linear multiclass SVM.
%   svmModel.optTheta is the numClasses x inputSize weight matrix and DATA
%   holds one sample per column. PRED(i) is the row index of the largest
%   score in column i of optTheta * data.
scores = svmModel.optTheta * data;
[~, pred] = max(scores);
end
compute_confusion_matrix.m
function [confusion_matrix]=compute_confusion_matrix(predict_label,num_in_class,name_class)
%COMPUTE_CONFUSION_MATRIX Row-normalised confusion matrix for class-ordered predictions.
%
%   Inputs
%     predict_label - vector of predicted class indices; the samples must be
%                     ordered by true class (all of class 1, then class 2, ...)
%     num_in_class  - vector with the number of samples in each true class
%     name_class    - cell array of class names, used as axis tick labels
%
%   Output
%     confusion_matrix - num_class x num_class matrix; entry (ci, cj) is the
%                        fraction of true-class-ci samples predicted as cj
%
%   The matrix is also drawn with draw_cm.

num_in_class = reshape(num_in_class, 1, []);   % accept row or column input
num_class = numel(num_in_class);

% First and last position of each true class's block in predict_label.
block_end   = cumsum(num_in_class);
block_start = block_end - num_in_class + 1;

confusion_matrix = zeros(num_class, num_class);
for ci = 1:num_class
    block = predict_label(block_start(ci):block_end(ci));
    for cj = 1:num_class
        % nnz counts matches for row and column vectors alike.
        confusion_matrix(ci, cj) = nnz(block == cj) / num_in_class(ci);
    end
end

draw_cm(confusion_matrix,name_class,num_class);
end
draw_cm.m
function draw_cm(mat,tick,num_class)
%DRAW_CM Draw a confusion matrix as a grey-scale image with the value of each cell.
%   mat       - num_class x num_class matrix to display
%   tick      - cell array of labels used for both axes
%   num_class - number of classes (axis extent and tick positions)

% Inverted grey colormap: large values are drawn dark.
imagesc(1:num_class, 1:num_class, mat);
colormap(flipud(gray));

% One "%0.2f" label per cell, centred on that cell.
cellLabels = strtrim(cellstr(num2str(mat(:), '%0.2f')));
[colPos, rowPos] = meshgrid(1:num_class);
hText = text(colPos(:), rowPos(:), cellLabels(:), 'HorizontalAlignment', 'center');

% Cells above the midpoint of the colour range get [1 1 1] (white) text,
% the remaining cells get [0 0 0] (black) text.
ax = gca;
climMid = mean(get(ax, 'CLim'));
aboveMid = mat(:) > climMid;
set(hText, {'Color'}, num2cell(repmat(aboveMid, 1, 3), 2));

% Class names on both axes, x axis on top, one tick per class.
set(ax, 'xticklabel', tick, 'XAxisLocation', 'top');
set(ax, 'XTick', 1:num_class, 'YTick', 1:num_class);
set(ax, 'yticklabel', tick);
rotateXLabels(ax, 315 ); % rotate the x tick labels
visualize.m
function r=visualize(X, mm, s1, s2)
%VISUALIZE Tile the columns of X as images in one mosaic.
%   FROM RBMLIB http://code.google.com/p/matrbm/
%
%   X      - D x N matrix; every column is one flattened image
%   mm     - [lo hi] display range; defaults to the min and max of X
%   s1, s2 - sizes passed to reshape(column, s1, s2); when either is 0
%            (the default) and D is a perfect square, sqrt(D) is used for both
%
%   With exactly one output argument the mosaic is returned in R and nothing
%   is drawn; otherwise the mosaic is displayed with imagesc.

if nargin < 2
    mm = [min(X(:)) max(X(:))];
end
if nargin < 3
    s1 = 0;
end
if nargin < 4
    s2 = 0;
end

[D, N] = size(X);
side = sqrt(D);
shapeGiven = (s1 ~= 0) && (s2 ~= 0);

if side == floor(side) || shapeGiven
    % Columns can be reshaped into tiles (square by default).
    if ~shapeGiven
        s1 = side;
        s2 = side;
    end

    % num x num grid of tiles separated by one-pixel gaps; the canvas is
    % pre-filled with the upper display limit, which shows in the gaps.
    num = ceil(sqrt(N));
    a = mm(2) * ones(num*s2 + num - 1, num*s1 + num - 1);

    for k = 1:N
        % Tiles are placed down the first grid column, then the next column.
        gridRow = mod(k - 1, num);
        gridCol = floor((k - 1) / num);
        rowSpan = gridRow*(s2 + 1) + (1:s2);
        colSpan = gridCol*(s1 + 1) + (1:s1);
        a(rowSpan, colSpan) = reshape(X(:, k), s1, s2)';
    end
else
    % Columns cannot be reshaped into tiles: show the matrix itself.
    a = X;
end

% Return the image, or plot it.
if nargout == 1
    r = a;
else
    imagesc(a, [mm(1) mm(2)]);
    axis equal
    colormap gray
end
得到的识别率为87.040%,hinge loss可以和任何深度网络结合完成分类任务。
最后得到的混淆矩阵如下:
数据集见资源,如引用此代码,请注明出处。