Matlab中的深度学习——CNN图像分类实例

我测试了一下,matlab2022a 可以运行。matlab2018之后的应该都可以。因为CNN是比较简单,比较经典的网络。matlab很早就支持了

一起来学习一个Matlab环境下运用卷积神经神经网络(CNN)实现图像分类的小栗子,力争用最通俗的语言进行解释,闲话不多说,咱直接上干货。

环境要求:Matlab2018a及以上版本(18a版本开始提供Deep Learning Toolbox的工具箱,正式拉开了Matlab进军深度学习领域的序幕);

 

数据集:大多图像分类的Benchmark数据集都会使用到MNIST,但是由于使用的实在太多了,以及分类的难度确实不大,所以我们今天分享一个相似的但是难度稍大的数据集,名字叫notMNIST,附上数据集的链接:

http://yaroslavvb.blogspot.com

Matlab中的深度学习——CNN图像分类实例_第1张图片 

 

 Matlab中的深度学习——CNN图像分类实例_第2张图片

 

代码部分:

为了尽可能的把demo写的浅显易懂,也为了其他小伙伴可以直接套用此模型来跑自己的数据,我们尽可能地不使用复杂晦涩的函数,努力做到步子小一点,注释多一点。

demo中使用的数据集(下载后和程序放在一个文件夹下,不用改路径,直接运行程序即可):

链接:https://pan.baidu.com/s/1JnS6yfnoxTNg1d0_atpGHg

提取码:hw89

% 数据集来源:
% http://yaroslavvb.blogspot.com/2011/09/notmnist-dataset.html
% matlab deep learning toolbox 官方使用说明:
% https://www.mathworks.com/help/deeplearning/ref/trainnetwork.html#d120e82918
% 知乎:夏明朗   2020-05-29

clc
clear
close all

% 读取数据
load('notMNIST_small.mat')

% 选取部分数据可视化
for i=1:1:32
    subplot(4,8,i);
    imshow(images(:,:,555*i)/255)
end

% 
X = reshape(images, [28,28,1,length(images)]); % 灰度图片的作为输入的要求为h*w**c*s,
                                               % 其中h为图片的长度,w为宽度,c为通道数,s就是数据的个数
size(X)                                        % 可以看到我们的数据集尺寸为:28*28*1*18724
Y = categorical(labels);                       % 标签的数据类型为categorical

idx = randperm(length(images));   % 产生一个和数据个数一致的随机数序列
num_train = round(0.5*length(X)); % 训练集个数,0.5表示全部数据中随机选取50%作为训练集
num_val = round(0.3*length(X));   % 验证集个数,0.3表示全部数据中随机选取30%作为验证集,故测试集自动变为剩下的20%

% 训练集,验证集和测试集数据
X_train = X(:,:,:,idx(1:num_train));
X_val = X(:,:,:,idx(num_train+1:num_train+num_val));
X_test = X(:,:,:,idx(num_train+num_val+1:end));  %这里假设,全部数据中除了

% 训练集,验证集和测试集标签
Y_train = Y(idx(1:num_train),:);
Y_val = Y(idx(num_train+1:num_train+num_val),:);
Y_test = Y(idx(num_train+num_val+1:end),:);

 以上代码主要实现了数据的导入以及数据集的随机划分,未使用一步到位的函数,这样大家就可以在工作区清楚的看到每一步是怎样实现的。这里解释randperm( )函数,其可随机产生一个和样本个数相同数目的序列,例如共有10个样本,那么randperm(10)的结果可能为

8     3    10     5     2     4     1     7     9     6

 因此我们使用这个随机序列就可以实现数据集的随机划分,本程序设置训练集数据占50%,验证集数据占30%,测试集数据就占剩余的20%。数据集准备完成后,我们就开始搭建CNN网络

%% 定义网络层
layers = [...
          imageInputLayer([28,28,1]); % 输入层,要正确输入图片的height, width 和 number of channels of the images
          batchNormalizationLayer();  % 批量归一化
          convolution2dLayer(5,20);   % 卷积层
          batchNormalizationLayer();
          reluLayer()                 % Relu激活函数 
          maxPooling2dLayer(2,'Stride',2); % 池化层
          fullyConnectedLayer(10);       % 全连接层
          softmaxLayer();                % softmax层
          classificationLayer(),...
    ];

 

这里定义的CNN网络又1个卷积层、1个池化层和1个全连接层组成,其中还使用了批量归一化操作,当然也可以加入dropout来防止过拟合。至于网络结构的设计就属于“炼丹”的范畴了,没有一个统一的标准,因人而异,需要大家自己慢慢体会了。

下面我们简单介绍没有验证集有验证集的区别:

没有验证集情况

% 参数
options = trainingOptions('sgdm',...   % 也可以用adam、rmsprop等方法
    'MaxEpochs',50,...                 % 最大迭代次数
    'Plots','training-progress');
net_cnn = trainNetwork(X_train,Y_train,layers,options);

 Matlab中的深度学习——CNN图像分类实例_第3张图片

 没有验证集曲线,我的电脑就一个GPU。训练非常快。Matlab2022非常强大,自动调用GPU,不需要配置。

没有验证集合训练过程会迭代到最大迭代步数,并且左侧图像中没有验证集的曲线。这里有一个问题,如果我们遇到复杂的问题,模型中又缺少防止过拟合的操作(批量归一化,dropout,正则化或大数据集都可以防止过拟合),那么最终结果极有可能会出现过拟合,这当然是我们不想出现的情况。因此,如何判断模型训练是否“恰到好处”,我们应该引入验证集帮助我们进行判断。

 有验证集的情况

% 参数
options = trainingOptions('sgdm',...                         % 也可以用adam、rmsprop等方法
                          'MiniBatchSize',128, ...
                          'MaxEpochs',50,...                 % 最大迭代次数
                          'ValidationData',{X_val,Y_val},... % 显示验证集误差
                          'Verbose',true, ...                % 命令窗口显示训练过程的各种指标
                          'Shuffle','every-epoch', ...
                          'InitialLearnRate',1e-2,...
                          'Plots','training-progress');
net_cnn = trainNetwork(X_train,Y_train,layers,options);

 Matlab中的深度学习——CNN图像分类实例_第4张图片

options 中我们加入了更多的人为设置的参数,这样稍微高级的定义会使得模型训练的更符合你的要求。加入验证集后,我们可以看到训练过程“早停”了,Epoch只迭代了8次模型就训练过程就结束了,原因如绿框内容所示,Patience为5表示:如果5步迭代中验证集的loss都没有减小的话,训练就终止了。在绝大数情况下,这种早停的机制会防止模型过度训练从而出现过拟合的现象。 

最终我们使用测试集验证模型的泛化能力即可,相应代码与结果如下

% 测试
testLabel = classify(net_cnn,X_test);
precision = sum(testLabel==Y_test)/numel(testLabel);
disp(['测试集分类准确率为',num2str(precision*100),'%'])

在单 GPU 上训练。
正在初始化输入数据归一化。
|=============================================================================|
|  轮  |  迭代  |    经过的时间     |  小批量准确度  |  验证准确度  |  小批量损失  |  验证损失  |  基础学习率  |
|     |      |  (hh:mm:ss)  |          |         |         |        |         |
|=============================================================================|
|   1 |    1 |     00:00:01 |    5.47% |  15.19% |  2.9331 | 2.4121 |  0.0100 |
|   1 |   50 |     00:00:02 |   87.50% |  87.80% |  0.4813 | 0.4667 |  0.0100 |
|   2 |  100 |     00:00:02 |   86.72% |  89.18% |  0.4289 | 0.3996 |  0.0100 |
|   3 |  150 |     00:00:03 |   91.41% |  90.24% |  0.2862 | 0.3596 |  0.0100 |
|   3 |  200 |     00:00:03 |   96.09% |  90.64% |  0.1928 | 0.3462 |  0.0100 |
|   4 |  250 |     00:00:03 |   92.19% |  90.88% |  0.3221 | 0.3343 |  0.0100 |
|   5 |  300 |     00:00:04 |   96.09% |  91.22% |  0.1528 | 0.3249 |  0.0100 |
|   5 |  350 |     00:00:04 |   95.31% |  91.42% |  0.1762 | 0.3231 |  0.0100 |
|   6 |  400 |     00:00:04 |   94.53% |  91.42% |  0.1887 | 0.3138 |  0.0100 |
|   7 |  450 |     00:00:05 |   95.31% |  91.56% |  0.1311 | 0.3226 |  0.0100 |
|   7 |  500 |     00:00:05 |   94.53% |  91.60% |  0.2010 | 0.3162 |  0.0100 |
|   8 |  550 |     00:00:06 |   97.66% |  91.51% |  0.1032 | 0.3149 |  0.0100 |
|   9 |  600 |     00:00:06 |   96.88% |  91.63% |  0.1052 | 0.3155 |  0.0100 |
|   9 |  650 |     00:00:07 |   95.31% |  91.47% |  0.1810 | 0.3241 |  0.0100 |
|  10 |  700 |     00:00:07 |   95.31% |  91.86% |  0.1714 | 0.3145 |  0.0100 |
|  11 |  750 |     00:00:07 |   99.22% |  92.01% |  0.0449 | 0.3117 |  0.0100 |
|  11 |  800 |     00:00:08 |   92.19% |  91.85% |  0.1895 | 0.3241 |  0.0100 |
|  12 |  850 |     00:00:08 |   98.44% |  92.18% |  0.0882 | 0.3132 |  0.0100 |
|  13 |  900 |     00:00:08 |   98.44% |  91.42% |  0.0625 | 0.3320 |  0.0100 |
|  14 |  950 |     00:00:09 |   98.44% |  91.86% |  0.0657 | 0.3219 |  0.0100 |
|  14 | 1000 |     00:00:09 |   95.31% |  91.95% |  0.1356 | 0.3357 |  0.0100 |
|  15 | 1050 |     00:00:10 |   98.44% |  92.15% |  0.1065 | 0.3287 |  0.0100 |
|  16 | 1100 |     00:00:10 |   99.22% |  91.83% |  0.0362 | 0.3364 |  0.0100 |
|  16 | 1150 |     00:00:10 |   97.66% |  92.02% |  0.0595 | 0.3329 |  0.0100 |
|  17 | 1200 |     00:00:11 |  100.00% |  91.67% |  0.0278 | 0.3416 |  0.0100 |
|  18 | 1250 |     00:00:11 |   99.22% |  91.85% |  0.0447 | 0.3399 |  0.0100 |
|  18 | 1300 |     00:00:12 |   97.66% |  91.51% |  0.0715 | 0.3603 |  0.0100 |
|  19 | 1350 |     00:00:12 |   97.66% |  92.01% |  0.0748 | 0.3492 |  0.0100 |
|  20 | 1400 |     00:00:12 |  100.00% |  92.01% |  0.0423 | 0.3480 |  0.0100 |
|  20 | 1450 |     00:00:13 |   98.44% |  91.95% |  0.0614 | 0.3597 |  0.0100 |
|  21 | 1500 |     00:00:13 |   98.44% |  92.01% |  0.0278 | 0.3610 |  0.0100 |
|  22 | 1550 |     00:00:13 |  100.00% |  91.81% |  0.0236 | 0.3707 |  0.0100 |
|  22 | 1600 |     00:00:14 |  100.00% |  92.02% |  0.0307 | 0.3661 |  0.0100 |
|  23 | 1650 |     00:00:14 |  100.00% |  92.17% |  0.0213 | 0.3791 |  0.0100 |
|  24 | 1700 |     00:00:14 |  100.00% |  92.04% |  0.0163 | 0.3731 |  0.0100 |
|  24 | 1750 |     00:00:15 |  100.00% |  91.94% |  0.0230 | 0.3790 |  0.0100 |
|  25 | 1800 |     00:00:15 |   98.44% |  91.90% |  0.0842 | 0.3912 |  0.0100 |
|  26 | 1850 |     00:00:16 |   99.22% |  92.22% |  0.0257 | 0.3819 |  0.0100 |
|  27 | 1900 |     00:00:16 |  100.00% |  91.81% |  0.0135 | 0.4069 |  0.0100 |
|  27 | 1950 |     00:00:16 |  100.00% |  92.11% |  0.0110 | 0.3969 |  0.0100 |
|  28 | 2000 |     00:00:17 |  100.00% |  92.06% |  0.0131 | 0.3956 |  0.0100 |
|  29 | 2050 |     00:00:17 |  100.00% |  91.76% |  0.0183 | 0.4135 |  0.0100 |
|  29 | 2100 |     00:00:17 |  100.00% |  92.22% |  0.0146 | 0.4071 |  0.0100 |
|  30 | 2150 |     00:00:18 |  100.00% |  92.01% |  0.0093 | 0.4090 |  0.0100 |
|  31 | 2200 |     00:00:18 |  100.00% |  92.04% |  0.0128 | 0.4105 |  0.0100 |
|  31 | 2250 |     00:00:18 |   99.22% |  92.10% |  0.0370 | 0.4144 |  0.0100 |
|  32 | 2300 |     00:00:19 |  100.00% |  91.94% |  0.0106 | 0.4172 |  0.0100 |
|  33 | 2350 |     00:00:19 |   99.22% |  91.97% |  0.0360 | 0.4186 |  0.0100 |
|  33 | 2400 |     00:00:20 |  100.00% |  91.95% |  0.0049 | 0.4258 |  0.0100 |
|  34 | 2450 |     00:00:20 |  100.00% |  92.08% |  0.0124 | 0.4201 |  0.0100 |
|  35 | 2500 |     00:00:20 |  100.00% |  91.85% |  0.0109 | 0.4297 |  0.0100 |
|  35 | 2550 |     00:00:21 |   99.22% |  92.22% |  0.0298 | 0.4228 |  0.0100 |
|  36 | 2600 |     00:00:21 |  100.00% |  91.97% |  0.0055 | 0.4364 |  0.0100 |
|  37 | 2650 |     00:00:21 |  100.00% |  91.72% |  0.0104 | 0.4379 |  0.0100 |
|  37 | 2700 |     00:00:22 |  100.00% |  92.06% |  0.0157 | 0.4318 |  0.0100 |
|  38 | 2750 |     00:00:22 |  100.00% |  92.06% |  0.0102 | 0.4319 |  0.0100 |
|  39 | 2800 |     00:00:22 |  100.00% |  91.97% |  0.0042 | 0.4479 |  0.0100 |
|  40 | 2850 |     00:00:23 |  100.00% |  91.95% |  0.0085 | 0.4389 |  0.0100 |
|  40 | 2900 |     00:00:23 |  100.00% |  91.88% |  0.0076 | 0.4460 |  0.0100 |
|  41 | 2950 |     00:00:24 |  100.00% |  92.08% |  0.0077 | 0.4442 |  0.0100 |
|  42 | 3000 |     00:00:24 |  100.00% |  91.94% |  0.0111 | 0.4472 |  0.0100 |
|  42 | 3050 |     00:00:24 |   99.22% |  91.95% |  0.0189 | 0.4486 |  0.0100 |
|  43 | 3100 |     00:00:25 |  100.00% |  91.95% |  0.0076 | 0.4513 |  0.0100 |
|  44 | 3150 |     00:00:25 |  100.00% |  91.97% |  0.0027 | 0.4502 |  0.0100 |
|  44 | 3200 |     00:00:25 |  100.00% |  92.02% |  0.0093 | 0.4604 |  0.0100 |
|  45 | 3250 |     00:00:26 |  100.00% |  91.88% |  0.0069 | 0.4561 |  0.0100 |
|  46 | 3300 |     00:00:26 |  100.00% |  91.92% |  0.0043 | 0.4533 |  0.0100 |
|  46 | 3350 |     00:00:26 |  100.00% |  91.92% |  0.0064 | 0.4670 |  0.0100 |
|  47 | 3400 |     00:00:27 |  100.00% |  91.94% |  0.0073 | 0.4642 |  0.0100 |
|  48 | 3450 |     00:00:27 |  100.00% |  91.92% |  0.0050 | 0.4624 |  0.0100 |
|  48 | 3500 |     00:00:28 |  100.00% |  92.24% |  0.0074 | 0.4618 |  0.0100 |
|  49 | 3550 |     00:00:28 |   99.22% |  91.97% |  0.0167 | 0.4710 |  0.0100 |
|  50 | 3600 |     00:00:28 |  100.00% |  91.86% |  0.0035 | 0.4692 |  0.0100 |
|  50 | 3650 |     00:00:29 |  100.00% |  92.06% |  0.0147 | 0.4677 |  0.0100 |
|=============================================================================|
训练结束: 已完成最大轮数。
测试集分类准确率为92.3364%

 测试集分类准确为92.0961%,效果还可以接受。如果你有兴趣,更换模型拓扑结构、优化算法、激活函数等等,可能会得到更好的分类结果。

 

今天分享的例子大家按照流程一步一步操作,肯定都能跑通,使用Matlab做深度学习适应于于以下人群:

1.熟悉Matlab但不熟悉Python语言的小伙伴;

2.刚刚接触深度学习,想要将理论迅速实现的小伙伴;

3.不需要state-of-art模型的,只需要将经典深度学习模型应用到自己研究领域的小伙伴

其中以上分享的Matlab中深度学习的代码,与Keras很像,都是“搭积木”的形式将模型拼凑起来而已,但是如果真的想进一步深入研究深度学习,或者是对现有模型进行改进创新,那么Matlab中提供的工具包集成度就显得太高了,pytorch或tensorflow或许更合适。但是不过怎么说,Matlab至少在深度学习领域迈出了一大步,可喜可贺。

如果对你有用,欢迎点赞收藏,加关注,后面会持续免费更新新内容

参考:

Matlab中的深度学习——CNN图像分类实例 - 知乎 (zhihu.com)

你可能感兴趣的:(MATLAB,CUDA,matlab,深度学习,cnn)