支持向量机(Support Vector Machine,SVM)最先由Cortes和Vapnik提出,它是一种有监督的模式识别方法。它的主要思想是建立一个分类决策面。SVM利用核函数将数据映射到高维空间,使其尽可能的线性可分。常用的核函数包括线性核函数、多项式核、径向基核(RBF)、傅里叶核、样条核和Sigmoid核函数等。通过比较这些核函数适用的数据特点,无论样本数据特点是高维还是低维,数据量大还是小,RBF核函数展现了很好的分类性能。因此,选择RBF作为SVM的分类核函数。
优化问题取决于两个重要参数c和g,这两个参数会影响SVM的预测性能。SVM预测问题取决于两个重要参数c和g,这两个参数会影响SVM的预测性能。为了提高模型的预测性能,引入网格式搜索法(GS)优化模型建立过程中的两个重要参数。同时避免模型过学习和欠学习的现象发生,采用5倍交叉验证法以训练集最小均方根误差为适应度函数来进行参数寻优。当达到最小均方根误差时,所得到的c和g为最佳参数。GS中,以0.5为间隔进行全局搜索,c和g的范围均是(2-10, 210)
SVM预测过程为:
(1)输入数据,规定训练输入、训练输出、预测输入和预测输出
(2)为加快网络收敛速度,进行数据归一化处理
(3)参数寻优,网格数搜索开始
(4)得到最优参数建立预测模型,避免模型过学习和欠学习的现象发生,采用5倍交叉验证法以训练集最小均方根误差为适应度函数来进行参数寻优。当达到最小均方根误差时,所得到的c和g为最佳参数。
(5)预测数据输入
(6)得出预测结果
部分代码
%% 清空环境变量 close all; clear all; clc; format compact; %% 数据的提取和预处理 data=xlsread('筛选后数据'); ts = data((1:320),1);%训练集输出 tsx = data((1:320),2:end);%训练集输入 tts=data((321:end),1);%预测集输出 ttx= data((321:end),2:end);%预测集输入 % 数据预处理,将原始数据进行归一化 ts = ts'; tsx = tsx'; tts=tts'; ttx=ttx'; % mapminmax为matlab自带的映射函数 % 对ts进行归一化 [TS,TSps] = mapminmax(ts,-1,1); %矢量归一化 [TTS,TTSps]= mapminmax(tts,-1,1); TS = TS'; TTS=TTS'; % mapminmax为matlab自带的映射函数 % 对tsx进行归一化 [TSX,TSXps] = mapminmax(tsx,-1,1); %特征值归一化 [TTX,TTXps] = mapminmax(ttx,-1,1); % 对TSX进行转置,以符合libsvm工具箱的数据格式要求 TSX = TSX'; TTX = TTX'; %% 选择回归预测分析最佳的SVM参数c&g % 进行参数选择: [bestmse,bestc,bestg] = SVMcgForRegress(TS,TSX,-10,10,-10,10); % 打印参数选择结果 disp('打印参数选择结果'); str = sprintf( 'Best Cross Validation MSE = %g Best c = %g Best g = %g',bestmse,bestc,bestg); disp(str); %% 利用回归预测分析最佳的参数进行SVM网络训练 cmd = ['-c ', num2str(bestc), ' -g ', num2str(bestg) , ' -s 3 -p 0.01']; model = svmtrain(TS,TSX,cmd); %% SVM网络回归预测 [predict,mse] = svmpredict(TS,TSX,model); [predict_2,mse_2] = svmpredict(TTS,TTX,model); predict = mapminmax('reverse',predict',TSps); predict_2 = mapminmax('reverse',predict_2',TTSps); predict = predict'; predict_2 =predict_2' % 均方根误差计算 N = length(tts); RMSE = sqrt((sum((tts-predict_2').^2))/N) % % 相关系数 % N = length(tts); % YUCE_R2 = (N*sum(predict_2'.*tts)-sum(predict_2)*sum(tts))^2/((N*sum((predict_2).^2)-(sum(predict_2'))^2)*(N*sum((tts).^2)-(sum(tts))^2)) %% 结果分析(测试集) figure; plot(tts,'-o'); hold on; plot(predict_2,'r-^'); legend('实际负荷','预测负荷'); hold off; title('SVM预测输出图','FontSize',12); xlabel('2019年11月20日-2019年12月30日','FontSize',12); ylabel('负荷(KW)','FontSize',12);
结果展示
编辑
编辑
微信公众号“matlab学习之家”