基于MATLAB的神经网络进行手写体数字识别(含鼠绘GUI / 数据集:MNIST)

基本介绍

  • 软件:Matlab R2018b
  • 数据集:MNIST手写体数字数据集
  • 网络:自建简单网络

数据准备

MNIST数据集还挺有名的,这里就不过多介绍了。数据集本身读取格式官网有给,怎么转换成图片格式网上也有很多,这里不再赘述。
官网:http://yann.lecun.com/exdb/mnist/
训练集包含60000个示例,测试集包含10000个示例。
测试集的前5000个示例来自原始的NIST训练集。 最后的5000个来自原始的NIST测试集。 前5000个比后5000个更干净点,识别起来更容易。
当然为了方便使用MATLAB,这里给出程序缺省的数据集:
链接:https://pan.baidu.com/s/1VItI8MdUa-oBhWjKUUB72w
提取码:tgv9
CSDN地址:https://download.csdn.net/download/garker/12413315
每一个数字都包含1000张图片,每张图片大小均为28×28×1,1代表单通道,即灰度图。
基于MATLAB的神经网络进行手写体数字识别(含鼠绘GUI / 数据集:MNIST)_第1张图片

神经网络组建

因为数据集本身特征并不多,所以不需要动用常用的神经网络,这里给出一个官方的结构形式。一共有15层。
基于MATLAB的神经网络进行手写体数字识别(含鼠绘GUI / 数据集:MNIST)_第2张图片
基于MATLAB的神经网络进行手写体数字识别(含鼠绘GUI / 数据集:MNIST)_第3张图片
这里可以看出,三层卷积,三层归一化,是相当简单的CNN网络结构了,可以当作CNN结构的入门学习好好钻研学习。
在MATLAB中的建构代码如下:

layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

这其中,各层的参数如下:
convolution2dLayer

参数 含义
FilterSize 3,3 卷积核尺寸
NumFilter 8 卷积核数量
Padding ‘same’ new_height = new_width = W / S (结果向上取整)

(W×W的输入矩阵,F×F的卷积核,步长为S=1)

BatchNormalizationLayer
归一化层采用默认数据

maxPooling2dLayer

参数 含义
PoolSize 2,2 池化尺寸
Stride 2,2 步长

fullyConnectedLayer
全连接层输出为10(0-9共10个数字)

训练神经网络

imds = imageDatastore('train_dataset', ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
%导入数据

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.8,'randomize');
%分割数据集与测试集

options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'MaxEpochs',5, ...
    'Shuffle','every-epoch', ...
    'ValidationData',imdsValidation, ...
    'ValidationFrequency',30, ...
    'Verbose',false, ...
    'Plots','training-progress');
%设置训练参数

net = trainNetwork(imdsTrain,layers,options);
%训练神经网络

基于MATLAB的神经网络进行手写体数字识别(含鼠绘GUI / 数据集:MNIST)_第4张图片
这里可以看出来基本上第三个世代就已经训练差不多了,最后的accuracy也能达到99.80%

测试数据集

YPred = classify(net,imdsValidation);
YValidation = imdsValidation.Labels;

accuracy = sum(YPred == YValidation)/numel(YValidation);

figure;
perm = randperm(10000,20);
for i = 1:20
    subplot(4,5,i);
    s = classify(net,imread(imds.Files{perm(i)}));
    imshow(imds.Files{perm(i)});title(string(s));
end

随机挑出来20个看看效果,没什么大问题:
基于MATLAB的神经网络进行手写体数字识别(含鼠绘GUI / 数据集:MNIST)_第5张图片

鼠绘输入识别的GUI

GUI的代码编写不算难,直接回调函数里面编写也比较方便。这里着重讲一下鼠绘的问题,网上查了很多资料也踩了不少坑,这里按处理顺序把比较坑的细节都放一下:

鼠绘区域

基于MATLAB的神经网络进行手写体数字识别(含鼠绘GUI / 数据集:MNIST)_第6张图片
红色区域里面只有axes1是有实际作用的,为了美观我把X、Y轴颜色改成了背景的灰色以达到隐藏的效果。此外,还需要把X、Y轴XLimMode、YLimMode设置为manual,其主要作用是锁住它们,不然在鼠绘的时候每一笔都会飘。
基于MATLAB的神经网络进行手写体数字识别(含鼠绘GUI / 数据集:MNIST)_第7张图片
此外,对该区域的鼠绘效果显示代码如下:

figure1_WindowButtonDownFcn

unction figure1_WindowButtonDownFcn(hObject, eventdata, handles)
% hObject    handle to figure1 (see GCBO)
% eventdata  reserved - to be defined in a future version of MATLAB
% handles    structure with handles and user data (see GUIDATA)
global draw_enable;
global x;
global y;
draw_enable=1;
if draw_enable
    position=get(gca,'currentpoint');
    x(1)=position(1);
    y(1)=position(3);
end

figure1_WindowButtonMotionFcn

function figure1_WindowButtonMotionFcn(hObject, eventdata, handles)
% hObject    handle to figure1 (see GCBO)
% eventdata  reserved - to be defined in a future version of MATLAB
% handles    structure with handles and user data (see GUIDATA)
global draw_enable;
global x;
global y;
if draw_enable
    position=get(gca,'currentpoint');
    x(2)=position(1);
    y(2)=position(3);
    h1 = line(x,y,'EraseMode','xor','LineWidth',5,'color','black');
    x(1)=x(2);
    y(1)=y(2);
end

figure1_WindowButtonUpFcn

function figure1_WindowButtonUpFcn(hObject, eventdata, handles)
% hObject    handle to figure1 (see GCBO)
% eventdata  reserved - to be defined in a future version of MATLAB
% handles    structure with handles and user data (see GUIDATA)
global draw_enable
draw_enable=0;

特别特别需要注意的是,这三个回调函数都是在整个GUI默认的整体面板上来的,也就是figure1。具体找到这个回调函数的如下图所示:
基于MATLAB的神经网络进行手写体数字识别(含鼠绘GUI / 数据集:MNIST)_第8张图片
没错,就是点击GUI编辑面板空白区域!

识别

识别按钮的回调函数很简单这里就不赘述了,需要特别提醒的是:
从绘制区域直接得到的并不是可直接使用图像数据,这里直接保存到默认目录一份正好也做备份用;
再者,保存好的图像的手写数据部分是深色的,背景部分是浅色的,这与我们之前的训练数据是不符的,直接用来识别肯定不会出现正确的答案,所以把这个数据读取之后再取反色,部分代码如下:

h=getframe(handles.axes1);
imwrite(h.cdata,'output.jpg','jpg');
img = imread('output.jpg');
img = imresize(img,[28,28]);
img = rgb2gray(img);
img = 255 - img; %取反色

基于MATLAB的神经网络进行手写体数字识别(含鼠绘GUI / 数据集:MNIST)_第9张图片

结论

“0-9”这十个数字逐一写了一遍感觉问题不大,但是千万别因为鼠绘区域大懒省事儿把数字写的很小这会影响到识别结果,如果实在感觉控制不好,可以在GUI编辑界面把整个界面改成按比例,这样实际使用的时候可以等比例把界面拉小,鼠绘更方便一些。

你可能感兴趣的:(Matlab神经网络实战)