深度学习DEMO提供了三个实现目标识别的卷积神经网络CNN示例。三个例子分别为:
从零开始学习如何建立CNN;
使用已经训练过的模型(迁移学习);
用于特征提取的神经网络训练。
每个DEMO都有对应的视频讲解,视频地址为:https://www.mathworks.com/videos/series/deep-learning-with-MATLAB.html
运行以上示例需要安装MATLAB自带的GPU和并行计算工具箱,DEMO 3还需要安装统计与机器学习工具箱。
下面简单介绍DEMO 1:从零开始学习如何建立CNN。
运行DownloadCIFAR10.m文件,下载DEMO运行所需要的数据。
执行以下代码将训练数据导入MATLAB;
%Please note: these are 4 of the 10 categories available
%Feel free to choose which ever you like best!
categories= {‘Deer’,‘Dog’,‘Frog’,‘Cat’};
rootFolder= ‘cifar10Train’;
imds= imageDatastore(fullfile(rootFolder, categories), …
'LabelSource', 'foldernames');
定义CNN的各层网络,这里可以根据自己的需要调整参数,下面的代码只是一个示例。
varSize= 32;
conv1= convolution2dLayer(5,varSize,‘Padding’,2,‘BiasLearnRateFactor’,2);
conv1.Weights= gpuArray(single(randn([5 5 3 varSize])*0.0001));
fc1= fullyConnectedLayer(64,‘BiasLearnRateFactor’,2);
fc1.Weights= gpuArray(single(randn([64 576])*0.1));
fc2= fullyConnectedLayer(4,‘BiasLearnRateFactor’,2);
fc2.Weights= gpuArray(single(randn([4 64])*0.1));
layers= [
imageInputLayer([varSize varSize 3]);
conv1;
maxPooling2dLayer(3,'Stride',2);
reluLayer();
convolution2dLayer(5,32,'Padding',2,'BiasLearnRateFactor',2);
reluLayer();
averagePooling2dLayer(3,'Stride',2);
convolution2dLayer(5,64,'Padding',2,'BiasLearnRateFactor',2);
reluLayer();
averagePooling2dLayer(3,'Stride',2);
fc1;
reluLayer();
fc2;
softmaxLayer()
classificationLayer()];
设置CNN的训练选项,这些参数设置会严重影响CNN的工作性能,在设置之前应当准确理解这些参数的物理意义。
opts= trainingOptions(‘sgdm’, …
'InitialLearnRate', 0.001, ...
'LearnRateSchedule', 'piecewise', ...
'LearnRateDropFactor', 0.1, ...
'LearnRateDropPeriod', 8, ...
'L2Regularization', 0.004, ...
'MaxEpochs', 10, ...
'MiniBatchSize', 100, ...
'Verbose', true);
开始训练CNN,训练时间长短与具体的硬件设备相关,一般会花费数分钟或以上。
[net, info] =trainNetwork(imds, layers, opts);
Training on singleGPU.
Initializing imagenormalization.
|=========================================================================================|
| Epoch | Iteration | Time Elapsed | Mini-batch | Mini-batch | Base Learning|
| | | (seconds) | Loss | Accuracy | Rate |
|=========================================================================================|
| 1 | 1 | 0.25 | 1.3862 | 24.00% | 0.0010 |
| 1 | 50 | 1.86 | 1.2571 | 39.00% | 0.0010 |
| 1 | 100 | 3.35 | 1.2376 | 39.00% | 0.0010 |
| 1 | 150 | 4.90 | 1.1451 | 50.00% | 0.0010 |
| 1 | 200 | 6.39 | 1.0797 | 59.00% | 0.0010 |
| 2 | 250 | 8.03 | 0.8069 | 69.00% | 0.0010 |
| 2 | 300 | 9.64 | 1.1253 | 51.00% | 0.0010 |
| 2 | 350 | 11.20 | 0.9872 | 59.00% | 0.0010 |
| 2 | 400 | 12.75 | 0.9490 | 59.00% | 0.0010 |
| 3 | 450 | 14.31 | 0.7405 | 70.00% | 0.0010 |
| 3 | 500 | 15.77 | 0.9592 | 59.00% | 0.0010 |
| 3 | 550 | 17.28 | 0.9337 | 61.00% | 0.0010 |
| 3 | 600 | 18.77 | 0.8383 | 65.00% | 0.0010 |
| 4 | 650 | 20.30 | 0.6693 | 71.00% | 0.0010 |
| 4 | 700 | 21.80 | 0.8787 | 63.00% | 0.0010 |
| 4 | 750 | 23.27 | 0.8892 | 63.00% | 0.0010 |
| 4 | 800 | 24.76 | 0.7295 | 69.00% | 0.0010 |
| 5 | 850 | 26.28 | 0.6321 | 72.00% | 0.0010 |
| 5 | 900 | 27.77 | 0.8034 | 71.00% | 0.0010 |
| 5 | 950 | 29.26 | 0.8285 | 68.00% | 0.0010 |
| 5 | 1000 | 30.75 | 0.6893 | 69.00% | 0.0010 |
| 6 | 1050 | 32.27 | 0.5741 | 76.00% | 0.0010 |
| 6 | 1100 | 33.74 | 0.7280 | 73.00% | 0.0010 |
| 6 | 1150 | 35.20 | 0.8312 | 68.00% | 0.0010 |
| 6 | 1200 | 36.69 | 0.5876 | 77.00% | 0.0010 |
| 7 | 1250 | 38.25 | 0.5598 | 75.00% | 0.0010 |
| 7 | 1300 | 39.80 | 0.6704 | 77.00% | 0.0010 |
| 7 | 1350 | 41.37 | 0.7792 | 68.00% | 0.0010 |
| 7 | 1400 | 42.87 | 0.5495 | 78.00% | 0.0010 |
| 8 | 1450 | 44.40 | 0.5561 | 79.00% | 0.0010 |
| 8 | 1500 | 45.89 | 0.6032 | 81.00% | 0.0010 |
| 8 | 1550 | 47.39 | 0.7548 | 68.00% | 0.0010 |
| 8 | 1600 | 48.90 | 0.5371 | 78.00% | 0.0010 |
| 9 | 1650 | 50.49 | 0.5247 | 80.00% | 0.0001 |
| 9 | 1700 | 52.02 | 0.5989 | 79.00% | 0.0001 |
| 9 | 1750 | 53.60 | 0.6982 | 72.00% | 0.0001 |
| 9 | 1800 | 55.17 | 0.4448 | 78.00% | 0.0001 |
| 10 | 1850 | 56.71 | 0.4927 | 79.00% | 0.0001 |
| 10 | 1900 | 58.23 | 0.5630 | 80.00% | 0.0001 |
| 10 | 1950 | 59.71 | 0.6843 | 73.00% | 0.0001 |
| 10 | 2000 | 61.18 | 0.4486 | 79.00% | 0.0001 |
|=========================================================================================|
将测试验证数据导入MATLAB。
rootFolder= ‘cifar10Test’;
imds_test= imageDatastore(fullfile(rootFolder, categories), …
'LabelSource', 'foldernames');
测试结果输出,通过随机读取一幅图片进行分类测试,如果图片的标题为绿色,则预测结果正确;如果为红色,则预测结果错误。
labels= classify(net, imds_test);
ii= randi(4000);
im= imread(imds_test.Files{ii});
imshow(im);
iflabels(ii) ==imds_test.Labels(ii)
colorText = ‘g’;
else
colorText = 'r';
end
title(char(labels(ii)),‘Color’,colorText);
DEMO下载地址:
http://page2.dfpan.com/fs/9lc2j2821f29b1676d7/