引言:LIBSVM是台湾大学林智仁(Lin Chih-Jen)教授等开发设计的一个简单、易于使用和快速有效的SVM模式识别与回归的软件包,他不但提供了编译好的可在Windows系列系统的执行文件,还提供了源代码,方便改进、修改以及在其它操作系统上应用;该软件对SVM所涉及的参数调节相对比较少,提供了很多的默认参数,利用这些默认参数可以解决很多问题;并提供了交互检验(Cross Validation)的功能。该软件可以解决C-SVM、ν-SVM、ε-SVR和ν-SVR等问题,包括基于一对一算法的多类模式识别问题。
我们在进行科学研究的时候会经常使用SVM对数据进行分类,MATLAB自带的SVM函数调参麻烦,且只支持分类问题,不支持回归问题。因此,林教授开发的功能更为强大的LIBSVM就是我们的不二选择。
LIBSVM支持MATLAB、Python、C等编译语言,今天我将讲解的是在MATLAB环境下调用LIBSVM。
其中:LIBSVM工具包下载:https://www.csie.ntu.edu.tw/~cjlin/libsvm/,具体安装过程这里就不详细介绍了,大家可以参考其他博客。
模型训练:model = svmtrain(label,data,'libsvm_options');
模型预测:[predicted_label,accuary] = svmpredict(label_test,data_test,model,'libsvm_options')
其中libsvm_options为可选参数,其具体内容如下:
-s 设置svm类型:
0 – C-SVC
1 – v-SVC
2 – one-class-SVM
3 – ε-SVR
4 – n – SVR
-t 设置核函数类型, 默认值为2
0 — 线性核: μ‘∗ν
1 — 多项式核: (γ∗μ‘∗ν+coef0)degree
2 — RBF核: exp(–γ∗∥μ−ν∥2)
3 — sigmoid 核: tanh(γ∗μ‘∗ν+coef0)
-d degree: 核函数中的degree设置(针对多项式核函数)(默认3);
-g r(gama): 核函数中的gamma函数设置(针对多项式/rbf/sigmoid核函数)(默认1/ k);
-r coef0: 核函数中的coef0设置(针对多项式/sigmoid核函数)((默认0);
-c cost: 设置C-SVC, e -SVR和v-SVR的参数(损失函数)(默认1);
-n nu: 设置v-SVC, 一类SVM和v- SVR的参数(默认0.5);
-p p: 设置e -SVR 中损失函数p的值(默认0.1);
-m cachesize: 设置cache内存大小, 以MB为单位(默认40);
-e eps: 设置允许的终止判据(默认0.001);
-h shrinking: 是否使用启发式, 0或1(默认1);
-wi weight: 设置第几类的参数C为weight*C (C-SVC中的C) (默认1);
-v n: n-fold交互检验模式, n为fold的个数, 必须大于等于2;
-b 概率估计: 是否计算SVC或SVR的概率估计, 可选值0或1, 默认0;
例:
分类问题:
model = svmtrain(label_train,data_train,'-s 0 -t 2 -c 0.1 -g 0.1');
[predicted_label,accuray] = svmpredict(label_test,data_test,model)
回归问题:
model = svmtrain(label_train,data_train,'-s 3 -t 2-c 0.1 -g 0.1 -p 0.01')
predicted_label = svmpredict(label_test,data_test,model)
针对SVM中的参数优化,Python环境下的LIBSVM中有寻优函数grid.py帮助大家寻找最优的c和g:
但是MATLAB环境下的LIBSVM却没有这个功能,因此今天这里就给大家分享在MATLAB环境下实现LIBSVM参数c和g的自动寻优:
function [best_c,best_g,best_acc] = SvmSearchParas(data,label,c_max,c_min,c_step,g_max,g_min,g_step,v)
%--------------------------------------------------------------------------
%The function looks for the SVM's most important parameters c and g
%The Author:等等登登-Ande
%The Email:[email protected]
%The Blog:qq_35166974
%%
%Initialization parameter
if nargin < 9
v = 10;
end
if nargin < 8
v = 10;
g_step = 1;
end
if nargin < 7
v = 10;
g_step = 1;
c_step = 1;
end
if nargin < 6
v = 10;
g_step = 1;
c_step = 1;
g_min = -5;
end
if nargin < 5
v = 10;
g_step = 1;
c_step = 1;
g_min = -5;
g_max = 5;
end
if nargin < 4
v = 10;
g_step = 1;
c_step = 1;
g_min = -5;
g_max = 5;
c_min = -5;
end
if nargin < 3
v = 10;
g_step = 1;
c_step = 1;
g_min = -5;
g_max = 5;
c_min = -5;
c_max = 5;
end
if nargin < 2
warning('You did not enter enough parameters!');
end
%%
%Parameter optimization
[mesh1,mesh2] = meshgrid(c_min:c_step:c_max,g_min:g_step:g_max);
[raw,col] = size(mesh1);
acc = zeros(raw,col);
for i=1:raw
for j=1:col
cg_paras = ['-v ',num2str(v),'-c ',num2str(2.^mesh1(i,j)),' ','-g ',num2str(2.^mesh2(i,j))];
acc(i,j) = libsvmtrain(double(label),double(data),cg_paras);
end
end
best_acc = max(max(acc));
[label_i,label_j] = find(acc==best_acc);
best_c = 2.^mesh1(label_i,label_j);
best_g = 2.^mesh2(label_i,label_j);
figure
mesh(mesh1,mesh2,acc);
xlabel('log2c');
ylabel('log2g');
zlabel('Accuracy')