MNIST数据集还挺有名的,这里就不过多介绍了。数据集本身读取格式官网有给,怎么转换成图片格式网上也有很多,这里不再赘述。
官网:http://yann.lecun.com/exdb/mnist/
训练集包含60000个示例,测试集包含10000个示例。
测试集的前5000个示例来自原始的NIST训练集。 最后的5000个来自原始的NIST测试集。 前5000个比后5000个更干净点,识别起来更容易。
当然为了方便使用MATLAB,这里给出程序缺省的数据集:
链接:https://pan.baidu.com/s/1VItI8MdUa-oBhWjKUUB72w
提取码:tgv9
CSDN地址:https://download.csdn.net/download/garker/12413315
每一个数字都包含1000张图片,每张图片大小均为28×28×1,1代表单通道,即灰度图。
因为数据集本身特征并不多,所以不需要动用常用的神经网络,这里给出一个官方的结构形式。一共有15层。
这里可以看出,三层卷积,三层归一化,是相当简单的CNN网络结构了,可以当作CNN结构的入门学习好好钻研学习。
在MATLAB中的建构代码如下:
layers = [
imageInputLayer([28 28 1])
convolution2dLayer(3,8,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,32,'Padding','same')
batchNormalizationLayer
reluLayer
fullyConnectedLayer(10)
softmaxLayer
classificationLayer];
这其中,各层的参数如下:
convolution2dLayer
参数 | 值 | 含义 |
---|---|---|
FilterSize | 3,3 | 卷积核尺寸 |
NumFilter | 8 | 卷积核数量 |
Padding | ‘same’ | new_height = new_width = W / S (结果向上取整) |
(W×W的输入矩阵,F×F的卷积核,步长为S=1)
BatchNormalizationLayer
归一化层采用默认数据
maxPooling2dLayer
参数 | 值 | 含义 |
---|---|---|
PoolSize | 2,2 | 池化尺寸 |
Stride | 2,2 | 步长 |
fullyConnectedLayer
全连接层输出为10(0-9共10个数字)
imds = imageDatastore('train_dataset', ...
'IncludeSubfolders',true,'LabelSource','foldernames');
%导入数据
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.8,'randomize');
%分割数据集与测试集
options = trainingOptions('sgdm', ...
'InitialLearnRate',0.01, ...
'MaxEpochs',5, ...
'Shuffle','every-epoch', ...
'ValidationData',imdsValidation, ...
'ValidationFrequency',30, ...
'Verbose',false, ...
'Plots','training-progress');
%设置训练参数
net = trainNetwork(imdsTrain,layers,options);
%训练神经网络
这里可以看出来基本上第三个世代就已经训练差不多了,最后的accuracy也能达到99.80%。
YPred = classify(net,imdsValidation);
YValidation = imdsValidation.Labels;
accuracy = sum(YPred == YValidation)/numel(YValidation);
figure;
perm = randperm(10000,20);
for i = 1:20
subplot(4,5,i);
s = classify(net,imread(imds.Files{perm(i)}));
imshow(imds.Files{perm(i)});title(string(s));
end
GUI的代码编写不算难,直接回调函数里面编写也比较方便。这里着重讲一下鼠绘的问题,网上查了很多资料也踩了不少坑,这里按处理顺序把比较坑的细节都放一下:
红色区域里面只有axes1是有实际作用的,为了美观我把X、Y轴颜色改成了背景的灰色以达到隐藏的效果。此外,还需要把X、Y轴的XLimMode、YLimMode设置为manual,其主要作用是锁住它们,不然在鼠绘的时候每一笔都会飘。
此外,对该区域的鼠绘效果显示代码如下:
figure1_WindowButtonDownFcn
unction figure1_WindowButtonDownFcn(hObject, eventdata, handles)
% hObject handle to figure1 (see GCBO)
% eventdata reserved - to be defined in a future version of MATLAB
% handles structure with handles and user data (see GUIDATA)
global draw_enable;
global x;
global y;
draw_enable=1;
if draw_enable
position=get(gca,'currentpoint');
x(1)=position(1);
y(1)=position(3);
end
figure1_WindowButtonMotionFcn
function figure1_WindowButtonMotionFcn(hObject, eventdata, handles)
% hObject handle to figure1 (see GCBO)
% eventdata reserved - to be defined in a future version of MATLAB
% handles structure with handles and user data (see GUIDATA)
global draw_enable;
global x;
global y;
if draw_enable
position=get(gca,'currentpoint');
x(2)=position(1);
y(2)=position(3);
h1 = line(x,y,'EraseMode','xor','LineWidth',5,'color','black');
x(1)=x(2);
y(1)=y(2);
end
figure1_WindowButtonUpFcn
function figure1_WindowButtonUpFcn(hObject, eventdata, handles)
% hObject handle to figure1 (see GCBO)
% eventdata reserved - to be defined in a future version of MATLAB
% handles structure with handles and user data (see GUIDATA)
global draw_enable
draw_enable=0;
特别特别需要注意的是,这三个回调函数都是在整个GUI默认的整体面板上来的,也就是figure1。具体找到这个回调函数的如下图所示:
没错,就是点击GUI编辑面板空白区域!
识别按钮的回调函数很简单这里就不赘述了,需要特别提醒的是:
从绘制区域直接得到的并不是可直接使用图像数据,这里直接保存到默认目录一份正好也做备份用;
再者,保存好的图像的手写数据部分是深色的,背景部分是浅色的,这与我们之前的训练数据是不符的,直接用来识别肯定不会出现正确的答案,所以把这个数据读取之后再取反色,部分代码如下:
h=getframe(handles.axes1);
imwrite(h.cdata,'output.jpg','jpg');
img = imread('output.jpg');
img = imresize(img,[28,28]);
img = rgb2gray(img);
img = 255 - img; %取反色
“0-9”这十个数字逐一写了一遍感觉问题不大,但是千万别因为鼠绘区域大懒省事儿把数字写的很小这会影响到识别结果,如果实在感觉控制不好,可以在GUI编辑界面把整个界面改成按比例,这样实际使用的时候可以等比例把界面拉小,鼠绘更方便一些。