基于MATLAB 自带手写数字集的CNN(LeNet5)手写数字识别

文章目录

  • 前言
  • 使用步骤
  • 识别结果


前言


利用MATLAB实践MNIST手写数字识别,下载手写数据集的准备工作有些麻烦。MATLAB 2021版可以直接调用MNIST部分数据进行CNN手写数字识别实践。直接上程序。

一、使用步骤

1.训练、测试及保存网络

代码如下(示例):

%CNN 手写数字识别程序
%$2021/8/6 GAVIN$%
%查阅MATLAB trainNetwork()帮助文档
%参阅网址:
%https://blog.csdn.net/weixin_43935696/article/details/109372278
%https://blog.csdn.net/qq_40166660/article/details/111992192


%%读取MATLAB自带数字图像数据集,数据集有10000幅0-9图像,各数字有1000幅图像
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet', ...
    'nndemos','nndatasets','DigitDataset');
imds=imageDatastore(digitDatasetPath, 'FileExtensions',...
{'.png','.jpg','.tif'},'IncludeSubfolders',true,... 
'LabelSource','foldernames');
%countEachLabel(imds)%显示标签文件夹及其中的文件数,10*2 table
%a=readimage(imds,1);%读数据集中的第一幅图像
%whos a %查看图像数据结构及数据类型:28x28,uint8
figure
numImages = 10000;
perm = randperm(numImages,20);%随机选取20个数,perm是1*20double数组
for i = 1:20  %此循环是MATLAB自带示例
    subplot(4,5,i);
    imshow(imds.Files{perm(i)});%显示随机选取的20幅图像
    drawnow;
end
numTrainingFiles = 750;%每个数字图像文件夹中1000幅图像中的750幅用于训练
[imdsTrain,imdsTest] = splitEachLabel(imds,numTrainingFiles,'randomize');
%%将每个标签文件夹中的文件随机拆分为两组,750个为imdsTrain,其余为imdsTest

%%一般的CNN网络
layers = [
    imageInputLayer([28 28 1],'Name','imageinput')
    
    convolution2dLayer(3,8,'Name','conv1','Padding','same')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name','relu1')  
    
    maxPooling2dLayer(2,'Stride',2,'Name','pool1')
    
    convolution2dLayer(3,16,'Padding','same','Name','conv2')
    batchNormalizationLayer('Name','bn2')
    reluLayer('Name','relu2')     
    
    maxPooling2dLayer(2,'Stride',2,'Name','pool2')
    
    convolution2dLayer(3,32,'Padding','same','Name','conv3')
    batchNormalizationLayer('Name','bn3')
    reluLayer('Name','relu3')  
    
    fullyConnectedLayer(10,'Name','fullc' )
    softmaxLayer('Name','soft')  
    classificationLayer('Name','classoutput')];

%%改进的LeNet-5 网络
% layers = [...        
%     imageInputLayer([28 28 1],"Name","imageinput")
% 
%     convolution2dLayer([5 5],6,"Name","conv1","Padding","same")
%     tanhLayer("Name","tanh1")
% 
%     maxPooling2dLayer([2 2],"Name","maxpool1","Stride",[2 2])
% 
%     convolution2dLayer([5 5],16,"Name","conv2")
%     tanhLayer("Name","tanh2")
% 
%     maxPooling2dLayer([2 2],"Name","maxpool","Stride",[2 2])
% 
%     fullyConnectedLayer(120,"Name","fc1")
%     fullyConnectedLayer(84,"Name","fc2")
%     fullyConnectedLayer(10,"Name","fc")
%     softmaxLayer("Name","softmax")
%     classificationLayer("Name","classoutput")];

figure;plot(layerGraph(layers))%显示网络结构图

options = trainingOptions('sgdm', ...         %MATLAB示例优化器
    'LearnRateSchedule','piecewise', ...      %学习率
    'LearnRateDropFactor',0.2, ...             
    'LearnRateDropPeriod',5, ...
    'MaxEpochs',20, ...           %最大学习整个数据集的次数
    'MiniBatchSize',128, ...  %一个batch有128个样本,训练一轮样本需要迭代7500/128次
    'Plots','training-progress',... %画出整个训练过程
    'Verbose',0);%不在命令窗显示训练过程信息             

net = trainNetwork(imdsTrain,layers,options);
%save Minist_LeNet5 net 
save zzg_cnn_1 net 

YPred = classify(net,imdsTest);%在测试集上输出网络预测结果
YTest = imdsTest.Labels;
accuracy=sum(YPred == YTest)/numel(YTest);%网络在测试集的精度
disp(accuracy);%0.9936,训练耗时10'54"
%LeNet-5:精度0.9820,训练耗时7'40"

2.手写数字识别

代码如下(示例):

%测试自己在画图板上写的数字
%load Minist_LeNet5 net;   %导入训练好的LeNet5网络
load zzg_cnn_1 net %导入训练好的CNN网络
imds_zzg=imageDatastore('F:\2021MATLAB\shouxieshibie\zixieshuzi',...
'FileExtensions',{'.jpg','.png','.bmp'},'IncludeSubfolders',false,... 
'LabelSource','foldernames');
len=numel(imds_zzg.Files);%读取imds_zzg中文件个数
%len=length(readall(imds_zzg));%读取imds_zzg中文件个数

figure('Name','CNN手写数字识别,版权所有:周志刚','NumberTitle','off');
for i = 1:len 
    test_image =readimage(imds_zzg,i); %导入手写体数字图片
    subplot(8,5,i);
    imshow(imds_zzg.Files{i});
    pause(1)
    %drawnow;
    shape = size(test_image);
    dimension=numel(shape);
    if dimension > 2
    test_image = rgb2gray(test_image);      %灰度化
    end
    test_image = imresize(test_image, [28,28]); %保证输入为28*28
    test_image = imcomplement(test_image);%反转,使得输入网络时一定要保证图片 背景是黑色,数字部分是白色
    result = classify(net, test_image);  
    disp(result);%在命令窗显示识别结果
    title(['CNN识别结果:' char(result)])%在手写数字图像标题显示识别结果
end

二 识别结果

基于MATLAB 自带手写数字集的CNN(LeNet5)手写数字识别_第1张图片

你可能感兴趣的:(机器学习)