交叉验证(Cross-Validation)的基本思想:
将原数据进行分组,一部分做为训练集,另一部分做为验证集,首先用训练集对不同参数的模型进行训练,再利用验证集来测试训练得到的模型,进而用验证集的测试误差来衡量模型中的参数。
常用的交叉验证的方法:
( 1) Hold-out 方法
Hold-out 方法即将原数据随机的分成两组,一组用作训练数据集,另一组用作验证数据集。
训练集训练模型,验证集则用于验证训练得到的模型,验证集的测试误差则为衡量标准。
Hold-out 方法依赖于单一的数据分割,并没有出现数据的交叉。实验结果高度依赖数据集的分割,验证结果容易出现不稳定的情况。为此,取多次 Hold-out 验证结果的均值则成为广义的交叉验证方法的评价标准。但在多次采用 Hold-out 验证时,由于随机分割数据,可能造成部分数据的信息无法充分利用,进而造成实验效果的偏差。
( 2)留一验证法( Leave-one-out, LOO CV)
LOO CV是经典交叉验证方法的一种,
即若原数据为 N 个样本,则每个数据单独作为验证集,剩下的 N-1个样本作为训练集。实验数据在训练时的充分利用是 LOOCV 方法的优点,但是在数据比较大时,高成本的计算复杂度成为 LOO CV 方法的局限性,然而小样本数据在采用 LOO CV 方法时可以得到很好的效果。
LOO CV 方法可以进一步得到改进,形成 LMO (Leave-m-out) CV[40],即数据中每 M个数据均有机会成为验证集,剩余的 − MN 个数据作为训练集。 M > 1时, LMO CV 在一定程度上改善了 LOO CV 计算复杂度过高的不足。
( 3) K 折交叉验证法( K-flod Cross Validation)
在上述方法的基础上,提出了 K 折交叉验证法,即将原数据平均分成 K 组,每一组均可作为验证集,剩余的 K −1组作为训练集。 K 个验证集的测试误差的均值为此方法的性能评价指标。 K 折交叉验证法不仅充分利用了数据的信息,有效的避免了过拟合和欠拟合状态的发生,得到的结果具有说服性,而且降低了计算复杂度。这一优势使得 K折交叉验证法成为最常用的交叉验证方法。然而 K 值的选取问题,困扰着 K 折交叉验证法的使用。一般而言, K 的选取范围为 5 到 10。
MATLAB代码例子:
function [bestacc,bestc,bestg] = SVMcgForClass(train_label,train,cmin,cmax,gmin,gmax,v,cstep,gstep,accstep) %SVMcg cross validation by faruto if nargin < 10 accstep = 4.5; end if nargin < 8 cstep = 0.8; gstep = 0.8; end if nargin < 7 v = 5; end if nargin < 5 gmax = 8; gmin = -8; end if nargin < 3 cmax = 8; cmin = -8; end % X:c Y:g cg:CVaccuracy [X,Y] = meshgrid(cmin:cstep:cmax,gmin:gstep:gmax); [m,n] = size(X); cg = zeros(m,n); eps = 10^(-4); % record acc with different c & g,and find the bestacc with the smallest c bestc = 1; bestg = 0.1; bestacc = 0; basenum = 2; for i = 1:m for j = 1:n cmd = ['-v ',num2str(v),' -c ',num2str( basenum^X(i,j) ),' -g ',num2str( basenum^Y(i,j) )]; cg(i,j) = svmtrain(train_label, train, cmd); if cg(i,j) <= 55 continue; end if cg(i,j) > bestacc bestacc = cg(i,j); bestc = basenum^X(i,j); bestg = basenum^Y(i,j); end if abs( cg(i,j)-bestacc )<=eps && bestc > basenum^X(i,j) bestacc = cg(i,j); bestc = basenum^X(i,j); bestg = basenum^Y(i,j); end end end % to draw the acc with different c & g figure; [C,h] = contour(X,Y,cg,70:accstep:100); clabel(C,h,'Color','r'); xlabel('log2c','FontSize',12); ylabel('log2g','FontSize',12); firstline = 'SVC参数选择结果图(等高线图)[GridSearchMethod]'; secondline = ['Best c=',num2str(bestc),' g=',num2str(bestg), ... ' CVAccuracy=',num2str(bestacc),'%']; title({firstline;secondline},'Fontsize',12); grid on; figure; meshc(X,Y,cg); % mesh(X,Y,cg); % surf(X,Y,cg); axis([cmin,cmax,gmin,gmax,30,100]); xlabel('log2c','FontSize',12); ylabel('log2g','FontSize',12); zlabel('Accuracy(%)','FontSize',12); firstline = 'SVC参数选择结果图(3D视图)[GridSearchMethod]'; secondline = ['Best c=',num2str(bestc),' g=',num2str(bestg), ... ' CVAccuracy=',num2str(bestacc),'%']; title({firstline;secondline},'Fontsize',12);