文章目录
利用MATLAB实践MNIST手写数字识别时,下载并准备手写数据集比较麻烦。MATLAB 2021版自带一个与MNIST类似的手写数字图像数据集(DigitDataset,共10000幅28×28灰度图像),可以直接调用它进行CNN手写数字识别实践。下面直接给出程序。
训练与测试网络的代码如下(示例):
%CNN 手写数字识别程序
%$2021/8/6 GAVIN$%
%查阅MATLAB trainNetwork()帮助文档
%参阅网址:
%https://blog.csdn.net/weixin_43935696/article/details/109372278
%https://blog.csdn.net/qq_40166660/article/details/111992192
% Load the digit image data set that ships with MATLAB.
% Per the author's notes it holds 10000 images of the digits 0-9,
% 1000 images per digit, one sub-folder per digit.
digitDatasetPath = fullfile(matlabroot, 'toolbox', 'nnet', 'nndemos', ...
    'nndatasets', 'DigitDataset');
% Each sub-folder name becomes the label of the images it contains.
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames', ...
    'FileExtensions', {'.png','.jpg','.tif'});
% Optional inspection helpers (left disabled):
%countEachLabel(imds)   % list each label folder and its file count (10x2 table)
%a=readimage(imds,1);   % read the first image of the data set
%whos a                 % inspect its size and class: 28x28, uint8
% Preview a random sample of the data set in a 4x5 grid
% (loop adapted from the MATLAB example).
figure
% Ask the datastore how many images it holds instead of hard-coding 10000,
% so the random indices can never point past the end of imds.Files.
numImages = numel(imds.Files);
% randperm(n,k) errors when k > n, so cap the sample at the data set size.
numShown = min(20, numImages);
perm = randperm(numImages, numShown); % unique random indices, 1-by-numShown
for i = 1:numShown
    subplot(4,5,i);
    imshow(imds.Files{perm(i)}); % show the i-th randomly selected image
    drawnow;                     % refresh so images appear one by one
end
% Randomly split the files of every label folder into two groups:
% numTrainingFiles of them go to imdsTrain, the remainder to imdsTest.
% With 1000 images per digit this is 750 for training per digit.
numTrainingFiles = 750;
[imdsTrain, imdsTest] = splitEachLabel(imds, numTrainingFiles, 'randomize');
% A plain CNN: three conv / batch-norm / ReLU stages (8, 16 and 32 filters
% of size 3x3, 'same' padding), the first two followed by 2x2 max pooling
% with stride 2, then a 10-way fully connected + softmax classifier.
inputStage = imageInputLayer([28 28 1], 'Name', 'imageinput');

stage1 = [
    convolution2dLayer(3, 8, 'Name', 'conv1', 'Padding', 'same')
    batchNormalizationLayer('Name', 'bn1')
    reluLayer('Name', 'relu1')
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'pool1')];

stage2 = [
    convolution2dLayer(3, 16, 'Padding', 'same', 'Name', 'conv2')
    batchNormalizationLayer('Name', 'bn2')
    reluLayer('Name', 'relu2')
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'pool2')];

stage3 = [
    convolution2dLayer(3, 32, 'Padding', 'same', 'Name', 'conv3')
    batchNormalizationLayer('Name', 'bn3')
    reluLayer('Name', 'relu3')];

classifierHead = [
    fullyConnectedLayer(10, 'Name', 'fullc')
    softmaxLayer('Name', 'soft')
    classificationLayer('Name', 'classoutput')];

% Stack the stages vertically into one column array of layers.
layers = [inputStage; stage1; stage2; stage3; classifierHead];

% Modified LeNet-5 network (alternative; uncomment to use it instead)
% layers = [...
% imageInputLayer([28 28 1],"Name","imageinput")
%
% convolution2dLayer([5 5],6,"Name","conv1","Padding","same")
% tanhLayer("Name","tanh1")
%
% maxPooling2dLayer([2 2],"Name","maxpool1","Stride",[2 2])
%
% convolution2dLayer([5 5],16,"Name","conv2")
% tanhLayer("Name","tanh2")
%
% maxPooling2dLayer([2 2],"Name","maxpool","Stride",[2 2])
%
% fullyConnectedLayer(120,"Name","fc1")
% fullyConnectedLayer(84,"Name","fc2")
% fullyConnectedLayer(10,"Name","fc")
% softmaxLayer("Name","softmax")
% classificationLayer("Name","classoutput")];

% Draw the network structure in its own figure.
figure
plot(layerGraph(layers))
% Training configuration, collected as name-value pairs and expanded below.
trainArgs = { ...
    'MaxEpochs', 20, ...                    % maximum number of passes over the training set
    'MiniBatchSize', 128, ...               % 128 samples per batch, i.e. about 7500/128 iterations per epoch
    'LearnRateSchedule', 'piecewise', ...   % step-wise learning-rate schedule
    'LearnRateDropFactor', 0.2, ...         % factor applied at each drop
    'LearnRateDropPeriod', 5, ...           % number of epochs between drops
    'Verbose', 0, ...                       % do not print training progress to the command window
    'Plots', 'training-progress'};          % plot the whole training process
% 'sgdm' is the optimizer used in the MATLAB example.
options = trainingOptions('sgdm', trainArgs{:});
% Train the network on the training split with the layers and options above.
net = trainNetwork(imdsTrain, layers, options);

% Persist the trained network to zzg_cnn_1.mat for later use.
%save Minist_LeNet5 net
save('zzg_cnn_1', 'net')

% Evaluate on the held-out test split.
YPred = classify(net, imdsTest);   % predicted labels for the test images
YTest = imdsTest.Labels;           % ground-truth labels
accuracy = nnz(YPred == YTest) / numel(YTest); % fraction of correct predictions
disp(accuracy);
% Results recorded by the author:
%   this CNN : accuracy 0.9936, training time 10'54"
%   LeNet-5  : accuracy 0.9820, training time 7'40"
用训练好的网络识别自己手写数字的代码如下(示例):
% Classify digits the user drew (e.g. in Paint) with the trained CNN and
% show every image together with its predicted label.
%load Minist_LeNet5 net; % alternative: load the trained LeNet-5 network
load zzg_cnn_1 net % load the trained CNN from zzg_cnn_1.mat
imds_zzg = imageDatastore('F:\2021MATLAB\shouxieshibie\zixieshuzi', ...
    'FileExtensions',{'.jpg','.png','.bmp'},'IncludeSubfolders',false, ...
    'LabelSource','foldernames');
len = numel(imds_zzg.Files); % number of image files in the datastore
%len=length(readall(imds_zzg)); % alternative way to count the files

% Size the subplot grid from the file count. The layout stays 8x5 for up to
% 40 images; beyond that extra rows are added, because a fixed 8x5 grid
% makes subplot() error out on the 41st image.
numCols = 5;
numRows = max(8, ceil(len / numCols));

figure('Name','CNN手写数字识别,版权所有:周志刚','NumberTitle','off');
for i = 1:len
    test_image = readimage(imds_zzg, i); % read the i-th hand-written image
    subplot(numRows, numCols, i);
    imshow(imds_zzg.Files{i});
    pause(1) % one-second delay so the images appear one after another
    %drawnow;

    % Reduce multi-channel images to a single grayscale channel.
    if ndims(test_image) > 2
        test_image = rgb2gray(test_image);
    end
    % Bring every image class (logical, uint16, double, ...) onto the
    % 0-255 uint8 scale. uint8 input passes through unchanged.
    % NOTE(review): the 28x28 uint8 format of the training images is taken
    % from the author's notes - confirm against the data set.
    test_image = im2uint8(test_image);
    test_image = imresize(test_image, [28,28]); % network input must be 28x28
    % Invert so the background is black and the digit white, which is what
    % the network must receive; this assumes the drawing has dark strokes
    % on a light background.
    test_image = imcomplement(test_image);

    result = classify(net, test_image);
    disp(result); % print the predicted label in the command window
    title(['CNN识别结果:' char(result)]) % show the prediction above the image
end