使用 x = quadprog(H, f, A, b, Aeq, beq, lb, ub) 求解拉格朗日乘子 αi 和 αi′,并将其用于回归分析。
Training
function [alphas, alphas_dash, d] = svrTrain( X, y, K, epsilon, C )
% svrTrain  Train an epsilon-SVR by solving its dual problem with quadprog.
%
% Input
% -----
%
% X ... Data points, one sample per row.
% [ x_11, x_12;
% x_21, x_22;
% x_31, x_32;
% ... ]
%
% y ... Regression targets, one per sample (row or column vector).
% [ s_1; s_2; s_3; ... ]
%
% K ... Kernel, a function handle taking two row vectors.
% @(x, y) ...
%
% epsilon ... SVR parameter (weight of the alpha + alpha' term in the dual).
%
% C ... SVR parameter (upper bound of every multiplier).
%
% Output
% ------
%
% alphas ... Lagrange multipliers (n-by-1).
%
% alphas_dash... Lagrange multipliers (n-by-1).
%
% d ... Constant offset that svrProduce adds to the kernel expansion.

% Obtain the size of data
n = size(X, 1);
y = y(:);                       % accept row or column targets

% Gram matrix G(i,j) = K(x_i, x_j); each pair is evaluated exactly once.
G = zeros(n, n);
for i = 1:n
    for j = 1:n
        G(i, j) = K(X(i, :), X(j, :));
    end
end

% The unknown vector is interleaved: [a_1, a_1', a_2, a_2', ...].
% Entries of the same kind couple with +K, mixed entries with -K, which is
% exactly the Kronecker product of G with [1 -1; -1 1].
H = kron(G, [1 -1; -1 1]);
H = (H + H') / 2;               % remove round-off asymmetry before quadprog

% Linear term of the minimisation form (the dual maximisation negated).
f = zeros(2*n, 1);
f(1:2:end) = epsilon - y;
f(2:2:end) = epsilon + y;

% Equality constraint sum(a_i - a_i') = 0 and box constraints 0 <= a <= C.
Aeq = repmat([1 -1], 1, n);
beq = 0;
lb = zeros(2*n, 1);
ub = C * ones(2*n, 1);

alphalist = quadprog(H, f, [], [], Aeq, beq, lb, ub);
if isempty(alphalist)
    error('svrTrain:noSolution', 'quadprog did not return a solution.');
end

% Separate the interleaved solution into alphas and alphas'.
alphas = alphalist(1:2:end);
alphas_dash = alphalist(2:2:end);

% NOTE(review): the original text of this function is truncated at the
% point where d is computed. The code below uses the standard KKT
% conditions for multipliers strictly inside (0, C); confirm that this
% matches the intended behaviour.
g = G.' * (alphas - alphas_dash);   % kernel expansion at the training points
tol = 1e-6 * C;                     % solver output is never exactly 0 or C
isFree = (alphas > tol) & (alphas < C - tol);
isFreeDash = (alphas_dash > tol) & (alphas_dash < C - tol);
candidates = [ y(isFree) - epsilon - g(isFree); ...
               y(isFreeDash) + epsilon - g(isFreeDash) ];
if isempty(candidates)
    % No multiplier lies strictly inside the box: fall back to the mean
    % residual so that d is always defined.
    d = mean(y - g);
else
    % Averaging over all free multipliers is less sensitive to solver
    % tolerance than using a single one.
    d = mean(candidates);
end
end
Produce
function f = svrProduce( X, K, alphas, alphas_dash, d, x1, x2 )
% svrProduce  Evaluate a trained SVR on a 1-D domain or a 2-D grid.
%
% Input
% -----
%
% X ... Data points, one sample per row.
% [ x_11, x_12;
% x_21, x_22;
% x_31, x_32;
% ... ]
%
% K ... Kernel.
% @(x, y) ...
%
% alphas ... Lagrange multipliers (first column is used).
%
% alphas_dash... Lagrange multipliers (first column is used).
%
% d ... Constant offset added to every prediction.
%
% x1 ... Domain of x1, e.g. [-3 -2 -1 0 1 2 3] (row or column vector).
%
% x2 ... Domain of x2, e.g. [-3 -2 -1 0 1 2 3]; pass [] for 1-D data.
%
% Output
% ------
%
% f ... Approximated values of f(x).
%       1-D: numel(x1)-by-1 column, f(i) belongs to x1(i).
%       2-D: numel(x1)-by-numel(x2) matrix, f(i,j) belongs to
%       [x1(i), x2(j)], i.e. rows follow x1 and columns follow x2.

% Initialization
n1 = numel(x1);                 % numel handles row and column domains alike
n2 = numel(x2);
n = size(X, 1);

% The expansion coefficients do not depend on the query point.
beta = alphas(1:n, 1) - alphas_dash(1:n, 1);

if isempty(x2)
    % 1-D data: one prediction per entry of x1.
    f = zeros(n1, 1);
    for d1 = 1:n1
        z = x1(d1);
        acc = 0;
        for i = 1:n
            acc = acc + beta(i) * K(X(i, :), z);
        end
        f(d1) = acc + d;
    end
else
    % 2-D data: one prediction per grid node; preallocate the full grid.
    f = zeros(n1, n2);
    for d1 = 1:n1
        for d2 = 1:n2
            z = [x1(d1), x2(d2)];
            acc = 0;
            for i = 1:n
                acc = acc + beta(i) * K(X(i, :), z);
            end
            f(d1, d2) = acc + d;
        end
    end
end
end
1D data test
clear all;
close all;
clc;
% --------------------------------------------------
% Test instance 1-D data.
% --------------------------------------------------
% Noise-free target; used both to generate the samples and as the
% reference curve in the plot, so it is defined once.
target = @(x) (x-5).^2 + 2;
nSamples = 11;
X = linspace(-16, 16, nSamples)';
noiseFactor = 20;
% Targets are stored as a row vector: uniform noise in
% [-noiseFactor/2, noiseFactor/2] is added to the true function.
y = (target(X) + noiseFactor * (rand(size(X)) - 0.5))';
% Set up kernel.
K = @(x, y) (x*y'+1)^2;
% Set up parameters.
epsilon = 2;    % semicolon added so the value is not echoed to the console
C = 10;
% Call dual QP with kernel.
[alphas, alphas_dash, d] = svrTrain(X, y, K, epsilon, C);
% Calculate regression line / regression plane.
x1 = -16:0.05:16;
f = svrProduce(X, K, alphas, alphas_dash, d, x1, []);
% Plot results.
svrPlot(x1, [], target(x1), X, y, f);
% Annotate the figure with the experiment settings.
text(5, 450, 'K(x,y) = (x*y+1)^2', 'FontSize', 14, 'FontWeight', 'bold');
text(5, 400, ['Noise factor: ' num2str(noiseFactor)], 'FontSize', 14, 'FontWeight', 'bold');
text(5, 350, ['nSamples: ' num2str(nSamples)], 'FontSize', 14, 'FontWeight', 'bold');
text(5, 300, ['Sampling interval: ' num2str(X(1)) ' - ' num2str(X(end))], 'FontSize', 14, 'FontWeight', 'bold');
% Save the figure as a 150 dpi PNG.
print('-dpng', '1D.png', '-r150');
2D data test
clear all;
close all;
clc;
% --------------------------------------------------
% Test instance 2-D data.
% --------------------------------------------------
x1 = -3:0.1:3;
x2 = -3:0.1:3;
% meshgrid returns length(x2)-by-length(x1) matrices: rows follow x2 and
% columns follow x1.
[X1, X2] = meshgrid(x1, x2);
mu = [0, 0];
sigma = eye(2);
F = mvnpdf([X1(:) X2(:)], mu, sigma);
% Reshape back to the meshgrid layout (not length(x1)-by-length(x2), which
% is only equivalent when both axes have the same length).
F = reshape(F, size(X1));
% Draw random grid nodes as training samples.
nSamples = 40;
X = zeros(nSamples, 2);
y = zeros(1, nSamples);
for i = 1:nSamples
    xi = ceil(rand() * length(x1));
    yi = ceil(rand() * length(x2));
    X(i, 1) = x1(xi);
    X(i, 2) = x2(yi);
    % Row index selects x2, column index selects x1 (meshgrid layout).
    y(i) = F(yi, xi);
end
% Set up kernel.
K = @(x, y) exp(-norm(x-y)^2);
% Set up parameters.
epsilon = 0.00002;
C = 7;
% Call dual QP with kernel.
[alphas, alphas_dash, d] = svrTrain(X, y, K, epsilon, C);
% Calculate regression line / regression plane.
% svrProduce fills f(i,j) for the point [x1(i), x2(j)]; transpose so that
% rows follow x2 and columns follow x1, matching F and what surf expects.
f = svrProduce(X, K, alphas, alphas_dash, d, x1, x2)';
% Plot results.
svrPlot(x1, x2, F, X, y, f);
text(0, 0, 0.21, 'K(x,y) = exp(-norm(x-y)^2)', 'FontSize', 14, 'FontWeight', 'bold');
text(0, 0, 0.20, ['nSamples: ' num2str(nSamples)], 'FontSize', 14, 'FontWeight', 'bold');
% Save the figure as a 150 dpi PNG.
print('-dpng', '2D.png', '-r150');
plot figure tool
function svrPlot( x1, x2, F, X, y, f )
% svrPlot  Plot reference function, samples and SVR prediction.
%
% Input
% -----
%
% x1 ... Domain of x1.
%
% x2 ... Domain of x2; pass [] to select the 1-D line plot.
%
% F ... Reference values on the domain (vector for 1-D, matrix for 2-D).
%
% X ... Sample locations (one sample per row).
%
% y ... Sample values.
%
% f ... Predicted values on the domain; must have as many elements as F.
%
% A new full-page landscape figure is opened. The title of the last axes
% shows the root-mean-square difference between F and f.

% Text styling shared by every label in the figure.
bold = {'FontSize', 14, 'FontWeight', 'bold'};

figure();
set(gcf, 'Units', 'normalized', 'OuterPosition', [0.05 0.05 0.9 0.9]);
set(gcf, 'PaperOrientation', 'landscape');
set(gcf, 'PaperUnits', 'centimeters', 'PaperPosition', [0 0 29.7 21]);
set(gcf, 'PaperSize', [29.7 21.0]);
if (isempty(x2))
    % 1-D: reference curve, sample points and regression line in one axes.
    plot(x1, F, 'k--', 'LineWidth', 2); hold on;
    plot(X, y, 'k*', 'MarkerSize', 15);
    plot(x1, f, 'k-', 'LineWidth', 2); hold off;
    xlabel('x', bold{:});
    ylabel('f(x)', bold{:});
    legend('f(x)', 'Sample points', 'Regression line', 'Location', 'SouthEast');
    set(gca, bold{:});
    xlim([-18 18]);
    ylim([-100 500]);
else
    % 2-D, left panel: reference surface with the sample points.
    subplot(1, 2, 1);
    surf(x1, x2, F); hold on;
    plot3(X(:, 1), X(:, 2), y, 'r*', 'MarkerSize', 17, 'LineWidth', 2); hold off;
    view([-45 15]);
    xlabel('x_1', bold{:});
    ylabel('x_2', bold{:});
    zlabel('f(x_1, x_2)', bold{:});
    legend('f(x_1, x_2)', 'Sample points', 'Location', 'NorthEast');
    set(gca, bold{:});
    xlim([-5 5]);
    ylim([-5 5]);
    zlim([0 0.2]);
    grid on;
    % 2-D, right panel: predicted surface.
    subplot(1, 2, 2);
    surf(x1, x2, f);
    view([-45 15]);
    xlabel('x_1', bold{:});
    ylabel('x_2', bold{:});
    zlabel('f(x_1, x_2)', bold{:});
    legend('Regression plane', 'Location', 'NorthEast');
    set(gca, bold{:});
    xlim([-5 5]);
    ylim([-5 5]);
    zlim([0 0.2]);
    grid on;
end
% Root-mean-square error; the square root was missing, so the value shown
% under the "RMS" label was actually the mean squared error.
RMS = sqrt(mean((F(:)-f(:)).^2));
title(['RMS: ' num2str(RMS)], bold{:});
end