matlab2019a中LSTM网络使用方法及源码示例(Deep Learning Toolbox系列篇6)

此示例说明如何使用长短期记忆 (LSTM) 网络对序列数据进行分类。

要训练深度神经网络以对序列数据进行分类,可以使用 LSTM 网络。LSTM 网络允许您将序列数据输入网络,并根据序列数据的各个时间步进行预测。

此示例使用 [1] 和 [2] 中所述的日语元音数据集。此示例训练一个 LSTM 网络,旨在根据表示连续说出的两个日语元音的时序数据来识别说话者。训练数据包含九个说话者的时序数据。每个序列有 12 个特征,且长度不同。该数据集包含 270 个训练观测值和 370 个测试观测值。

源码:


%% 通用matlab脚本三连
clear
clc
close all

%% 加载序列数据
% 加载日语元音训练数据。XTrain 是包含 270 个不同长度的 12 维序列的元胞数组。
% Y 是对应于九个说话者的标签 "1"、"2"、...、"9" 的分类向量。
% XTrain 中的条目是具有 12 行(每个特征一行)和不同列数(每个时间步一列)的矩阵。
[XTrain,YTrain] = japaneseVowelsTrainData;
XTrain(1:5)

%% 在绘图中可视化第一个时序。每行对应一个特征。
figure
plot(XTrain{1}')
xlabel("Time Step")
title("Training Observation 1")
legend("Feature " + string(1:12),'Location','northeastoutside')

%% 准备要填充的数据
numObservations = numel(XTrain);
for i=1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,2);
end

%% 按序列长度对数据进行排序。
[sequenceLengths,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
YTrain = YTrain(idx);

%% 在条形图中查看排序的序列长度。
figure
bar(sequenceLengths)
ylim([0 30])
xlabel("Sequence")
ylabel("Length")
title("Sorted Data")

%% 选择小批量大小 27 以均匀划分训练数据,并减少小批量中的填充量。下图说明了添加到序列中的填充。
miniBatchSize = 27;

%% 定义 LSTM 网络架构
inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;

layers = [ ...
    sequenceInputLayer(inputSize)
    bilstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]

maxEpochs = 100;
miniBatchSize = 27;

%% 指定训练选项
options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'GradientThreshold',1, ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'SequenceLength','longest', ...
    'Shuffle','never', ...
    'Verbose',0, ...
    'Plots','training-progress');

%% 训练 LSTM 网络
net = trainNetwork(XTrain,YTrain,layers,options);

%% 测试 LSTM 网络
[XTest,YTest] = japaneseVowelsTestData;
XTest(1:3)

%% LSTM 网络 net 已使用相似长度的小批量序列进行训练
numObservationsTest = numel(XTest);
for i=1:numObservationsTest
    sequence = XTest{i};
    sequenceLengthsTest(i) = size(sequence,2);
end
[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
YTest = YTest(idx);

%% 对测试数据进行分类
miniBatchSize = 27;
YPred = classify(net,XTest, ...
    'MiniBatchSize',miniBatchSize, ...
    'SequenceLength','longest');

%% 计算预测值的分类准确度。
acc = sum(YPred == YTest)./numel(YTest)

运行结果:

matlab2019a中LSTM网络使用方法及源码示例(Deep Learning Toolbox系列篇6)_第1张图片

matlab2019a中LSTM网络使用方法及源码示例(Deep Learning Toolbox系列篇6)_第2张图片

matlab2019a中LSTM网络使用方法及源码示例(Deep Learning Toolbox系列篇6)_第3张图片

本文着重讲解一下matlab的深度学习训练方式

在matlab中,训练一个网络有一个常用的函数trainNetwork;此函数通常有四个参数

%% 训练 LSTM 网络
net = trainNetwork(XTrain,YTrain,layers,options);

 其中,XTrain为训练的输入数据集,YTrain为训练数据对应的标签;

layers为网络的架构,也就是指网络每层的处理模式。

options为超参数的设置,包括学习率,优化方法,迭代次数以及批量大小问题等。

1. 构建网络模型;

inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;

layers = [ ...
    sequenceInputLayer(inputSize) %输入层
    bilstmLayer(numHiddenUnits,'OutputMode','last') %第一层隐层,LSTM架构
    fullyConnectedLayer(numClasses) % 第二层隐层,全连接层
    softmaxLayer %softmax处理
    classificationLayer] %分类层

 如上代码片段所示,逐行对每一层的网络处理模式进行说明,并将此构成的数组形式赋予一个变量(此中为layers)。

2. 指定训练的超参数

%% 指定训练选项
options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'GradientThreshold',1, ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'SequenceLength','longest', ...
    'Shuffle','never', ...
    'Verbose',0, ...
    'Plots','training-progress');

其中参数设定的格式必须由trainingOptions进行格式打包,即options的数据类型必须为TrainingOptions派,此示例中options的数据类型为TrainingOptionsADAM。另外本示例数据集方面的XTrain与XTest都是matlab中的cell数据类型。

 更详细的matlab深度学习训练方法(trainNetwork用法)请参考matlab官方文档或本博客Deep Learning ToolBox系列7。

 

你可能感兴趣的:(编程,深度学习,matlab)