matlab中如何创建网络,使用深度网络设计器创建简单的序列分类网络

加载数据

按照 [1] 和 [2] 中的说明加载日语元音数据集。预测变量是包含不同长度序列的元胞数组,特征维度为 12。标签是由标签 1、2、...、9 组成的分类向量。

[XTrain,YTrain] = japaneseVowelsTrainData;

[XValidation,YValidation] = japaneseVowelsTestData;

查看前几个训练序列的大小。序列是具有 12 行(每个特征一行)和不同列数(每个时间步一列)的矩阵。

XTrain(1:5)

ans=5×1 cell array

{12×20 double}

{12×26 double}

{12×22 double}

{12×20 double}

{12×21 double}

定义网络架构

打开深度网络设计器。

deepNetworkDesigner

在序列到标签上暂停,然后点击打开。这会打开一个适合序列分类问题的预置网络。

matlab中如何创建网络,使用深度网络设计器创建简单的序列分类网络_第1张图片

深度网络设计器显示该预置网络。

matlab中如何创建网络,使用深度网络设计器创建简单的序列分类网络_第2张图片

您可以轻松地将此序列网络用于日语元音字母数据集。

选择 sequenceInputLayer,检查并确认 InputSize 设置为 12,与特征维度匹配。

matlab中如何创建网络,使用深度网络设计器创建简单的序列分类网络_第3张图片

选择 lstmLayer 并将 NumHiddenUnits 设置为 100。

matlab中如何创建网络,使用深度网络设计器创建简单的序列分类网络_第4张图片

选择 fullyConnectedLayer,检查并确认 OutputSize 设置为 9,即类的数目。

matlab中如何创建网络,使用深度网络设计器创建简单的序列分类网络_第5张图片

检查网络架构

要检查网络并查看层的详细信息,请点击分析。

matlab中如何创建网络,使用深度网络设计器创建简单的序列分类网络_第6张图片

导出网络架构

要将网络架构导出到工作区,请在设计器选项卡上,点击导出。深度网络设计器将网络保存为变量 layers_1。

您还可以通过选择导出 > 生成代码来生成用于构造网络架构的代码。

训练网络

指定训练选项并训练网络。

由于小批量数据存储较小且序列较短,因此更适合在 CPU 上训练。将 'ExecutionEnvironment' 设置为 'cpu'。要在 GPU(如果可用)上进行训练,请将 'ExecutionEnvironment' 设置为 'auto'(默认值)。

miniBatchSize = 27;

options = trainingOptions('adam', ...

'ExecutionEnvironment','cpu', ...

'MaxEpochs',100, ...

'MiniBatchSize',miniBatchSize, ...

'ValidationData',{XValidation,YValidation}, ...

'GradientThreshold',2, ...

'Shuffle','every-epoch', ...

'Verbose',false, ...

'Plots','training-progress');

训练网络。

net = trainNetwork(XTrain,YTrain,layers_1,options);

matlab中如何创建网络,使用深度网络设计器创建简单的序列分类网络_第7张图片

测试网络

对测试数据进行分类,并计算分类准确度。指定与训练相同的小批量大小。

YPred = classify(net,XValidation,'MiniBatchSize',miniBatchSize);

acc = mean(YPred == YValidation)

acc = 0.9432

在接下来的步骤中,您可以尝试通过使用双向 LSTM (BiLSTM) 层或创建更深的网络来提高准确度。有关详细信息,请参阅长短期记忆网络。

有关说明如何使用卷积网络对序列数据进行分类的示例,请参阅使用深度学习进行语音命令识别。

参考资料

[1] Kudo, Mineichi, Jun Toyama, and Masaru Shimbo.“Multidimensional Curve Classification Using Passing-through Regions.”Pattern Recognition Letters 20, no. 11–13 (November 1999):1103–11. https://doi.org/10.1016/S0167-8655(99)00077-X.

[2] Kudo, Mineichi, Jun Toyama, and Masaru Shimbo.Japanese Vowels Data Set.Distributed by UCI Machine Learning Repository. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

你可能感兴趣的:(matlab中如何创建网络)