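MATLAB 提供了导入该数据集的函数,数据集是内置的,可以自行查看 digitTrain4DArrayData 和 digitTest4DArrayData 的源码找到数据文件的存储位置。这两个函数的第一个输出是 28×28×1×N 的图像数组,第二个输出是数字类别标签,第三个输出是每张图像的旋转角度。本例要预测的是旋转角度,所以用 ~ 忽略第二个输出,把第三个输出作为回归目标。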
% Load the built-in digit images. The second output (digit class label) is
% discarded with ~; the third output (rotation angle of each digit) is kept
% as the regression target.
[XTrain, ~, YTrain] = digitTrain4DArrayData();
[XValidation, ~, YValidation] = digitTest4DArrayData();
可以随机抽取 20 张训练图像查看:
% Show 20 randomly chosen training digits in a 4-by-5 grid.
numTrainImages = numel(YTrain);
figure
idx = randperm(numTrainImages, 20);
tile = 0;
for imgIdx = idx          % randperm returns a row vector, so this visits each index
    tile = tile + 1;
    subplot(4, 5, tile)
    imshow(XTrain(:, :, :, imgIdx))
    drawnow
end
构建网络的代码如下,其中各层函数的含义见这篇文章:
% Regression CNN for 28x28 single-channel images. It has four conv/BN/ReLU
% stages; the first two are each followed by 2x2 average pooling. These are
% followed by dropout and a fully connected layer with one output.
model = [
    imageInputLayer([28 28 1])

    % Stage 1: 8 filters, 'same' padding, then pool 28x28 -> 14x14
    convolution2dLayer(3, 8, 'Padding', 'same')
    batchNormalizationLayer()
    reluLayer()
    averagePooling2dLayer(2, 'Stride', 2)

    % Stage 2: 16 filters, then pool 14x14 -> 7x7
    convolution2dLayer(3, 16, 'Padding', 'same')
    batchNormalizationLayer()
    reluLayer()
    averagePooling2dLayer(2, 'Stride', 2)

    % Stages 3 and 4: 32 and 64 filters, no further pooling
    convolution2dLayer(3, 32, 'Padding', 'same')
    batchNormalizationLayer()
    reluLayer()
    convolution2dLayer(3, 64, 'Padding', 'same')
    batchNormalizationLayer()
    reluLayer()

    dropoutLayer(0.2)
    fullyConnectedLayer(1)
    regressionLayer() % predicting the rotation angle is regression, not classification
    ];
训练网络之前需要指定训练参数,这里的设置如下,代码中没有注释说明的参数见这篇文章:
% SGDM training options. The mini-batch size lives in one variable so that
% ValidationFrequency (counted in iterations) keeps meaning "once per epoch"
% if the batch size is changed. floor(N/batch) is the number of full
% mini-batches per epoch.
miniBatchSize = 128;
options = trainingOptions("sgdm", ...
    'MiniBatchSize', miniBatchSize, ...
    'MaxEpochs', 30, ...
    'InitialLearnRate', 0.001, ...         % initial learning rate
    'LearnRateSchedule', 'piecewise', ...  % lower the learning rate in discrete steps
    'LearnRateDropFactor', 0.1, ...        % multiply the learning rate by 0.1 at each drop
    'LearnRateDropPeriod', 20, ...         % drop once every 20 EPOCHS (not iterations)
    'Shuffle', 'every-epoch', ...
    'ValidationData', {XValidation, YValidation}, ...
    'ValidationFrequency', floor(numel(YTrain) / miniBatchSize), ... % iterations per epoch
    'Plots', 'training-progress', ...
    'verbose', true);
随后便可以训练网络,并在验证集上用均方根误差(RMSE)评估预测角度的准确程度:
% Train the network, then score it on the validation set by the
% root-mean-square error between true and predicted angles.
net = trainNetwork(XTrain, YTrain, model, options)
YPredicted = predict(net, XValidation);
predictError = YValidation - YPredicted;
squaredError = predictError .^ 2;
RMSE = sqrt(mean(squaredError))
% Load the built-in digit images. The second output (digit class label) is
% discarded with ~; the third output (rotation angle of each digit) is kept
% as the regression target.
[XTrain, ~, YTrain] = digitTrain4DArrayData();
[XValidation, ~, YValidation] = digitTest4DArrayData();
% Show 20 randomly chosen training digits in a 4-by-5 grid.
numTrainImages = numel(YTrain);
figure
idx = randperm(numTrainImages, 20);
tile = 0;
for imgIdx = idx          % randperm returns a row vector, so this visits each index
    tile = tile + 1;
    subplot(4, 5, tile)
    imshow(XTrain(:, :, :, imgIdx))
    drawnow
end
% Regression CNN for 28x28 single-channel images. It has four conv/BN/ReLU
% stages; the first two are each followed by 2x2 average pooling. These are
% followed by dropout and a fully connected layer with one output.
model = [
    imageInputLayer([28 28 1])

    % Stage 1: 8 filters, 'same' padding, then pool 28x28 -> 14x14
    convolution2dLayer(3, 8, 'Padding', 'same')
    batchNormalizationLayer()
    reluLayer()
    averagePooling2dLayer(2, 'Stride', 2)

    % Stage 2: 16 filters, then pool 14x14 -> 7x7
    convolution2dLayer(3, 16, 'Padding', 'same')
    batchNormalizationLayer()
    reluLayer()
    averagePooling2dLayer(2, 'Stride', 2)

    % Stages 3 and 4: 32 and 64 filters, no further pooling
    convolution2dLayer(3, 32, 'Padding', 'same')
    batchNormalizationLayer()
    reluLayer()
    convolution2dLayer(3, 64, 'Padding', 'same')
    batchNormalizationLayer()
    reluLayer()

    dropoutLayer(0.2)
    fullyConnectedLayer(1)
    regressionLayer() % predicting the rotation angle is regression, not classification
    ];
% SGDM training options. The mini-batch size lives in one variable so that
% ValidationFrequency (counted in iterations) keeps meaning "once per epoch"
% if the batch size is changed. floor(N/batch) is the number of full
% mini-batches per epoch.
miniBatchSize = 128;
options = trainingOptions("sgdm", ...
    'MiniBatchSize', miniBatchSize, ...
    'MaxEpochs', 30, ...
    'InitialLearnRate', 0.001, ...         % initial learning rate
    'LearnRateSchedule', 'piecewise', ...  % lower the learning rate in discrete steps
    'LearnRateDropFactor', 0.1, ...        % multiply the learning rate by 0.1 at each drop
    'LearnRateDropPeriod', 20, ...         % drop once every 20 EPOCHS (not iterations)
    'Shuffle', 'every-epoch', ...
    'ValidationData', {XValidation, YValidation}, ...
    'ValidationFrequency', floor(numel(YTrain) / miniBatchSize), ... % iterations per epoch
    'Plots', 'training-progress', ...
    'verbose', true);
% Train the network, then score it on the validation set by the
% root-mean-square error between true and predicted angles.
net = trainNetwork(XTrain, YTrain, model, options)
YPredicted = predict(net, XValidation);
predictError = YValidation - YPredicted;
squaredError = predictError .^ 2;
RMSE = sqrt(mean(squaredError))