MATLAB中LSTM时序分类的用法与实战

MATLAB中LSTM时序分类的用法与实战

  • 说明

本教程适用于R2018b版本的matlab(不知道R2018a有没有,但是2017版本的肯定是没有LSTM工具箱的了),所以版本低的趁这个机会卸载然后重新下载安装吧(╮(╯▽╰)╭)

引用参考
1.matlab官方文档:https://ww2.mathworks.cn/help/deeplearning/examples/classify-sequence-data-using-lstm-networks.html
2.前人博主的教程:https://blog.csdn.net/weixin_43196262/article/details/83106239

1. LSTM

LSTM(Long Short-Term Memory)是长短期记忆网络,是一种时间递归神经网络,适合于处理和预测时间序列中间隔和延迟相对较长的重要事件。其两大用途为classification和regression,本文介绍如何用LSTM做classification。

2.官方例程解释

2.1 数据集

为了使用LSTM,首先我们必须对其适用数据集有充分的了解。
数据集应为时间序列集,且应划分为训练集XTrain与测试集XTest,训练集XTrain要有对应的分类标签集YTrain作为responses。通过XTrain与YTrain训练数据后,输入XTest,我们可以得到YPred(预测标签集),并与YTest进行对比,便可知道训练结果的优良。

下面以UCI数据集中的 Japanese Vowels数据集作为例子进行解释。该数据集已经整理成.mat格式
(网址链接: https://pan.baidu.com/s/1IySfjLIuee6CrV0HN2K7Lw 提取码:3fev)

在matlab中导入,有一个结构数组,打开后可以看到其中包含了名为train,test的元胞数组,名为trainlables,testlables的数值数组,对应上述的XTrain,XTest与YTrain,YTest。

XTrain是一个1x270 cell数组,即其中包含了270个不同类别的时序集,每一个cell单元是一个12-by-N的矩阵,其中每个cell的N可以不一样(不同cell单元之间的时序长度可以不同)。12代表该时序集的维数。(即XTrian必须为元胞数组

YTrain是一个270x1 double数值数组,对应XTrain中每一个元胞单元的类别,共有9类。(注意此时YTrian还不能用,数值数组必须转为类别数组,后面说明

其物理意义可以解释如下:有9个人说出了270段话,每段话的时间长度不一样,每一个声音样本都提取了一个12维的特征向量。因此,对每一个声音样本来说(也就是每一个cell单元),每一列数据代表某一个时间点上的feature vector,其长度为12。将每一列数据沿行方向排列,构成时间尺度。

MATLAB中LSTM时序分类的用法与实战_第1张图片
XTest和YTest同理,不再说明。

2.2 用LSTM网络训练与测试

直接上代码

%% 创建变量
XTrain=mts.train;
YTrain=categorical(mts.trainlabels);%将数值数组转化为类别数组
XTest=mts.test;
YTest=categorical(mts.testlabels);

%% 构建LSTM网络
inputSize = 12;%特征的维度
numHiddenUnits = 100;%LSTM网路包含的隐藏单元数目
numClasses = 9;%label标签的种数,该例子中为人数

layers = [ ...
    sequenceInputLayer(inputSize)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

maxEpochs = 100;%最大训练周期数
miniBatchSize = 27;%分块尺寸

options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'GradientThreshold',1, ...
    'Verbose',false, ...
    'Plots','training-progress');

%% 训练
net=trainNetwork(XTrain,YTrain,layers, options)

%% 预测
YPred = classify(net,XTest, ...
    'MiniBatchSize',miniBatchSize, ...
    'SequenceLength','longest')
%% 精确度检验
acc = sum(YPred == YTest)./numel(YTest)

训练结果
MATLAB中LSTM时序分类的用法与实战_第2张图片
检验结果
acc = 0.9324

可见LSTM的分类效果还是挺优良的。
.

3.实战

实战的重点与难点在于如何导入正确的数据
结合本篇与MATLAB中外部数据读取并写入元胞数组的方法与步骤(https://blog.csdn.net/qq_42995378/article/details/87298786) ,鉴于篇幅限制,我将在下一篇文章中给出实战教程与代码。

你可能感兴趣的:(matlab,matlab,LSTM,时间序列,分类)