使用MATLAB的trainNetwork设计一个简单的LSTM神经网络

文章目录

  • 前言
  • 一、数据集
  • 二、网络结构
  • 三、测试程序


前言

借助MATLAB的deepNetworkDesigner搭一个简单的LSTM,数据集使用mnist手写数字识别数据集。


一、数据集

mnist数据集包括60000组训练数据和对应的标签,10000组测试数据和对应标签。每个数据都是一个28x28的矩阵,可以将其看做28x28像素的灰度图像(黑底白字)。而LSTM的输入应当是一个序列,我们可以把矩阵的每一行当做一帧,把图像分为28帧输入到LSTM。
数据集可以在我上传的资源里找到。

数据的格式是这样的:

使用MATLAB的trainNetwork设计一个简单的LSTM神经网络_第1张图片
XTrain,即训练图像,是一个60000x1的cell,cell的每一个元素是一个28x28的矩阵。矩阵的每一列为一帧。直接将矩阵以图片显示是这样的:

 imshow(cell2mat(XTrain(8)))

使用MATLAB的trainNetwork设计一个简单的LSTM神经网络_第2张图片
这不是某希腊字母,而是手写数字3。我们希望按行输入,而MATLAB按列读取,因此我做了个转置。再转置一下就能看到正常的图像:

 imshow(cell2mat(XTrain(8))')

使用MATLAB的trainNetwork设计一个简单的LSTM神经网络_第3张图片
标签的格式为:

使用MATLAB的trainNetwork设计一个简单的LSTM神经网络_第4张图片
可以直接通过categorical函数实现数值到categorical的转换,比如:

使用MATLAB的trainNetwork设计一个简单的LSTM神经网络_第5张图片

输入训练数据的方式不唯一,我用的只是其中一种,详情见MathWorks官网:trainNetwork

二、网络结构

使用一层128个隐藏节点的LSTM,一层全连接,输出使用softmax。网络的输入是一个序列,输出是标签,在MATLAB中,此网络可以这样描述:

layers = [ ...
    sequenceInputLayer(inputSize)                   %sequence输入
    lstmLayer(numHiddenUnits,'OutputMode','last')   %lstm
    fullyConnectedLayer(numClasses)                 %全连接
    softmaxLayer                                    %softmax
    classificationLayer];                           %label输出

三、测试程序

完整的测试程序如下:

clear
clc
%加载数据
load('.\mnist_data_mat\XTrain.mat')
load('.\mnist_data_mat\YTrain.mat')
load('.\mnist_data_mat\XTest.mat')
load('.\mnist_data_mat\YTest.mat')

%设置参数
inputSize = 28;         %28个输入节点
numHiddenUnits = 128;   %128个隐藏节点
numClasses = 10;        %10种分类结果

layers = [ ...
    sequenceInputLayer(inputSize)                   %sequence输入
    lstmLayer(numHiddenUnits,'OutputMode','last')   %lstm
    fullyConnectedLayer(numClasses)                 %全连接
    softmaxLayer                                    %softmax
    classificationLayer];                           %label输出

options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs',5, ...
    'MiniBatchSize',60, ...
    'GradientThreshold',1, ...
    'Verbose',false, ...
    'Plots','training-progress');

net=trainNetwork(XTrain,YTrain,layers, options);    %训练

Y_pred = classify(net, XTest);                      %测试
accy = sum(Y_pred == YTest) / length(YTest);        %计算准确度

准确度为97.73%
options里的参数可以修改一下,我用同样结构的网络不同的参数做出了98.74%的准确度,仍有提升空间。这里为了节省训练时间牺牲了一些精度。
训练好的网络也上传到了资源里。

你可能感兴趣的:(MATLAB,matlab,神经网络,lstm)