matlab 2022a predict()函数的使用方法--深度学习

predict

使用训练过的深度学习神经网络预测响应.

Syntax 

Y = predict(net,images)
Y = predict(net,sequences)
Y = predict(net,features)
Y = predict(net,X1,...,XN)
Y = predict(net,mixed)
[Y1,...,YM] = predict(___)
___ = predict(___,Name=Value)

 

描述

你可以在CPU或GPU上使用经过训练的深度学习神经网络进行预测。使用GPU需要并行计算工具箱™和支持的GPU设备。有关支持的设备的信息,请参见GPU计算要求(并行计算工具箱)。使用ExecutionEnvironment名称-值参数指定硬件需求。

使用此函数可以使用经过训练的SeriesNetwork或DAGNetwork对象预测响应。有关使用dlnetwork对象预测响应的信息,请参见predict。 

 

Y = predict(net,images) predicts the responses of the specified images using the trained network net.

example
Y = predict(net,sequences) predicts the responses of the specified sequences using the trained network net.

Y = predict(net,features) predicts the responses of the specified feature data using the trained network net.

Y = predict(net,X1,...,XN) predicts the responses for the data in the numeric or cell arrays X1, …, XN for the multi-input network net. The input Xi corresponds to the network input net.InputNames(i).

Y = predict(net,mixed) predicts the responses using the trained network net with multiple inputs of mixed data types.

[Y1…,YM] = predict(___)使用前面的任何输入参数预测多输出网络的M个输出的响应。输出Yj对应于网络输出net.OutputNames(j)。要为分类输出层返回分类输出,请将ReturnCategorical选项设置为1 (true)。

 ___ = predict(___,Name=Value) predicts the responses with additional options specified by one or more name-value arguments.

 例子:

使用训练过的卷积神经网络预测数字响应

 加载预先训练的网络digitsRegressionNet。该网络是一种预测手写体数字旋转角度的回归卷积神经网络。

load digitsRegressionNet

查看网络层次。网络的输出层是一个回归层。

layers = net.Layers

输出结果:

layers = 
  18x1 Layer array with layers:

     1   'imageinput'         Image Input           28x28x1 images with 'zerocenter' normalization
     2   'conv_1'             2-D Convolution       8 3x3x1 convolutions with stride [1  1] and padding 'same'
     3   'batchnorm_1'        Batch Normalization   Batch normalization with 8 channels
     4   'relu_1'             ReLU                  ReLU
     5   'avgpool2d_1'        2-D Average Pooling   2x2 average pooling with stride [2  2] and padding [0  0  0  0]
     6   'conv_2'             2-D Convolution       16 3x3x8 convolutions with stride [1  1] and padding 'same'
     7   'batchnorm_2'        Batch Normalization   Batch normalization with 16 channels
     8   'relu_2'             ReLU                  ReLU
     9   'avgpool2d_2'        2-D Average Pooling   2x2 average pooling with stride [2  2] and padding [0  0  0  0]
    10   'conv_3'             2-D Convolution       32 3x3x16 convolutions with stride [1  1] and padding 'same'
    11   'batchnorm_3'        Batch Normalization   Batch normalization with 32 channels
    12   'relu_3'             ReLU                  ReLU
    13   'conv_4'             2-D Convolution       32 3x3x32 convolutions with stride [1  1] and padding 'same'
    14   'batchnorm_4'        Batch Normalization   Batch normalization with 32 channels
    15   'relu_4'             ReLU                  ReLU
    16   'dropout'            Dropout               20% dropout
    17   'fc'                 Fully Connected       1 fully connected layer
    18   'regressionoutput'   Regression Output     mean-squared-error with response 'Response'

加载测试图像。

XTest = digitTest4DArrayData;

利用预测函数预测输入数据的响应。

YTest = predict(net,XTest);

随机查看一些测试图像及其预测。

numPlots = 9;
idx = randperm(size(XTest,4),numPlots);

sz = size(XTest,1);
offset = sz/2;

figure
tiledlayout("flow")

for i = 1:numPlots
    nexttile
    imshow(XTest(:,:,:,idx(i)))
    title("Observation " + idx(i))

    hold on
    plot(offset*[1-tand(YTest(idx(i))) 1+tand(YTest(idx(i)))],[sz 0],"r--")
    hold off
end

matlab 2022a predict()函数的使用方法--深度学习_第1张图片

 利用训练过的LSTM网络预测序列的数值响应

加载预训练的网络freqNet。该网络是一个预测波形频率的LSTM回归神经网络。

load freqNet

 查看网络层次。网络的输出层是一个回归层。

net.Layers
ans = 
  4x1 Layer array with layers:

     1   'sequenceinput'      Sequence Input      Sequence input with 3 dimensions
     2   'lstm'               LSTM                LSTM with 100 hidden units
     3   'fc'                 Fully Connected     1 fully connected layer
     4   'regressionoutput'   Regression Output   mean-squared-error with response 'Response'

加载测试序列。

load WaveformData
X = data;

利用预测函数预测输入数据的响应。因为网络是使用截断到每个小批的最短序列长度的序列进行训练的,所以也可以通过将SequenceLength选项设置为“最短”来截断测试序列。

Y = predict(net,X,SequenceLength="shortest");

在一个图中想象最初的几个预测。

figure
tiledlayout(2,2)
for i = 1:4
    nexttile
    stackedplot(X{i}',DisplayLabels="Channel " + (1:3))

    xlabel("Time Step")
    title("Predicted Frequency: " + string(Y(i)))
end

matlab 2022a predict()函数的使用方法--深度学习_第2张图片

 

Predict responses using trained deep learning neural network - MATLAB predict - MathWorks 中国

你可能感兴趣的:(MATLAB,matlab,深度学习,开发语言)