lstm多输入时间序列预测

lstm 多输入时间序列预测

clc
close all
clear all
%加载数据,重构为行向量
load data.mat
IN_train = data((1:4197),2:9)';
OUT_train = data((1:4197),10)';

% 测试集――10个样本
IN_test = data((4198:end),2:9)';
OUT_test = data((4198:end),10)';
N = size(IN_test,2);
[in_train, ps_input] = mapminmax(IN_train,0,1);
in_test = mapminmax('apply',IN_test,ps_input);
[out_train, ps_output] = mapminmax(OUT_train,0,1);

xvalidation = in_train(:,3498:end);
in_train = in_train(:,1:3497);
yvalidation = out_train(:,3498:end);
out_train=out_train(:,1:3497);
rng('default')%设置随机种子
%%
%创建LSTM回归网络,指定LSTM层的隐含单元个数96*3
%序列预测,因此,输入一维,输出一维
%rng('default')
numFeatures = size(in_train,1);
numResponses = 1;
numHiddenUnits =138;
 
layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits)
    fullyConnectedLayer(numResponses)
    regressionLayer];
 
%指定训练选项,求解器设置为adam, 250 轮训练。
%1梯度阈值设置为 1。指定初始学习率 0.005,在 125 轮训练后通过乘以因子 0.2 来降低学习率。
%2梯度阈值设置为 1。指定初始学习率 0.005,在 125 轮训练后通过乘以因子 0.2 来降低学习率。
options = trainingOptions('adam', ...
    'MaxEpochs',100, ...
    'GradientThreshold',1, ...
    'InitialLearnRate',0.01, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',55, ...
    'ValidationData',{xvalidation,yvalidation}, ...
    'ValidationFrequency',10, ...
    'LearnRateDropFactor',0.2, ...
    'Verbose',0, ...
    'Plots','training-progress');
%训练LSTM
net = trainNetwork(in_train,out_train,layers,options);
net = resetState(net);
net = predictAndUpdateState(net,in_train);
YPred = predict(net,in_test,'MiniBatchSize',1);
%%
% 5. 数据反归一化
T_sim = mapminmax('reverse',YPred,ps_output);


%% V. 性能评价
%%
% 1. 相对绝对误差MAE
mae = sum(abs(T_sim - OUT_test))./N;
% 2.均方根误差RMSE
rmse = sqrt(sum((T_sim - OUT_test).^2)./N);
%%
% 2. 相关系数R
R = sqrt((N * sum(T_sim .* OUT_test) - sum(T_sim) * sum(OUT_test))^2 / ((N * sum((T_sim).^2) - (sum(T_sim))^2) * (N * sum((OUT_test).^2) - (sum(OUT_test))^2))); 
%%
% 3. 结果对比
result = [OUT_test' T_sim']
%% VI. 绘图
figure
plot(1:N,OUT_test,'b',1:N,T_sim,'r')
legend('真实值','预测值')
xlabel('预测样本')
ylabel('真实值')
string = {'LSTM预测结果对比';['R=' num2str(R)]};
title(string)

你可能感兴趣的:(matlab,lstm,深度学习,神经网络)