本文为原创文章转载必须注明本文出处以及附上 本文地址超链接 以及 博主博客地址:http://blog.csdn.net/qq_20259459 和 作者邮箱( [email protected] )。
(如果喜欢本文,欢迎大家关注我的博客或者动手点个赞,有需要可以邮件联系我)
接上一篇文章(阅读上一篇文章:http://blog.csdn.net/qq_20259459/article/details/54293054),我还是决定给大家写一下相关代码的注释。希望给大家带来帮助。
(一):cnn_mnist.m
function [net, info] = cnn_mnist(varargin)
%% --------------------------------------------------------------
% 主函数:cnn_mnist
% 功能: 1.初始化CNN
% 2.设置各项参数
% 3.读取和保存数据集
% 4.初始化train
% ------------------------------------------------------------------------
%CNN_MNIST Demonstrates MatConvNet on MNIST
%运行matlab文件夹下的vl_setupnn.m
run('C:\Users\Desktop\matconvnet-1.0-beta23\matconvnet-1.0-beta23\matlab/vl_setupnn.m') ;
opts.batchNormalization = false ; %选择batchNormalization的真假
opts.network = [] ; %初始化一个网络
opts.networkType = 'simplenn' ; %选择网络结构 %%% simplenn %%% dagnn
[opts, varargin] = vl_argparse(opts, varargin) ; %调用vl_argparse函数
sfx = opts.networkType ; %sfx=simplenn
if opts.batchNormalization, sfx = [sfx '-bnorm'] ; end %这里条件为假
opts.expDir = fullfile(vl_rootnn, 'data', ['mnist-baseline-' sfx]) ; %选择数据存放的路径:data\mnist-baseline-simplenn
[opts, varargin] = vl_argparse(opts, varargin) ; %调用vl_argparse函数
opts.dataDir = fullfile(vl_rootnn, 'data', 'mnist') ; %选择数据读取的路径:data\matconvnet-1.0-beta23\data\mnist
opts.imdbPath = fullfile(opts.expDir, 'imdb.mat'); %选择imdb结构体的路径:data\data\mnist-baseline-simplenn\imdb
opts.train = struct() ; %选择训练集返回为struct型
opts = vl_argparse(opts, varargin) ; %调用vl_argparse函数
%选择是否使用GPU,使用opts.train.gpus = 1,不使用:opts.train.gpus = []。
%有关GPU的安装配置请看我的博客:http://blog.csdn.net/qq_20259459/article/details/54093550
if ~isfield(opts.train, 'gpus'), opts.train.gpus = 1; end;
% --------------------------------------------------------------------
% 准备网络
% --------------------------------------------------------------------
if isempty(opts.network) %如果原网络为空:
net = cnn_mnist_init('batchNormalization', opts.batchNormalization, ... % 则调用cnn_mnist_init网络结构
'networkType', opts.networkType) ;
else %否则:
net = opts.network ; % 使用上面选择的数值带入现有网络
opts.network = [] ;
end
% --------------------------------------------------------------------
% 准备数据
% --------------------------------------------------------------------
if exist(opts.imdbPath, 'file') %如果mnist中存在imdb的结构体:
imdb = load(opts.imdbPath) ; % 载入imdb
else %否则:
imdb = getMnistImdb(opts) ; % 调用getMnistImdb函数得到imdb并保存
mkdir(opts.expDir) ;
save(opts.imdbPath, '-struct', 'imdb') ;
end
%arrayfun函数通过应用sprintf函数得到array中从1到10的元素并且将其数字标签转化为char文字型
net.meta.classes.name = arrayfun(@(x)sprintf('%d',x),1:10,'UniformOutput',false) ;
% --------------------------------------------------------------------
% 开始训练
% --------------------------------------------------------------------
switch opts.networkType %选择网络类型:
case 'simplenn', trainfn = @cnn_train ; % 1.simplenn
case 'dagnn', trainfn = @cnn_train_dag ; % 2.dagnn
end
[net, info] = trainfn(net, imdb, getBatch(opts), ... %调用训练函数,开始训练:find(imdb.images.set == 3)为验证集的样本
'expDir', opts.expDir, ...
net.meta.trainOpts, ...
opts.train, ...
'val', find(imdb.images.set == 3)) ;
% ------------------------------------------------------------------------
function fn = getBatch(opts)
%% --------------------------------------------------------------
% 函数名:getBatch
% 功能: 1.由opts返回函数
% 2.从imdb结构体取出数据
% 备注: 如果不理解Batc的意义的话,请查看我的博客:http://blog.csdn.net/qq_20259459/article/details/53943413
% ------------------------------------------------------------------------
switch lower(opts.networkType) %根据网络类型使用不同的getBatcch
case 'simplenn'
fn = @(x,y) getSimpleNNBatch(x,y) ;
case 'dagnn'
bopts = struct('numGpus', numel(opts.train.gpus)) ;
fn = @(x,y) getDagNNBatch(bopts,x,y) ;
end
% --------------------------------------------------------------------
function [images, labels] = getSimpleNNBatch(imdb, batch)
%% --------------------------------------------------------------
% 函数名:getSimpleNNBatch
% 功能: 1.由SimpleNN网络的批得到函数
% 2.batch为样本的索引值
% ------------------------------------------------------------------------
images = imdb.images.data(:,:,:,batch) ; %返回训练集
labels = imdb.images.labels(1,batch) ; %返回集标签
% --------------------------------------------------------------------
function inputs = getDagNNBatch(opts, imdb, batch)
%% --------------------------------------------------------------
% 函数名:getDagNNBatch
% 功能: 类似上面的函数,这里的网络结构是DagNN
% ------------------------------------------------------------------------
images = imdb.images.data(:,:,:,batch) ;
labels = imdb.images.labels(1,batch) ;
if opts.numGpus > 0 %使用GPU进行并行运算
images = gpuArray(images) ;
end
inputs = {'input', images, 'label', labels} ;
% --------------------------------------------------------------------
function imdb = getMnistImdb(opts)
%% --------------------------------------------------------------
% 函数名:getMnistImdb
% 功能: 1.从mnist数据集中获取data
% 2.将得到的数据减去mean值
% 3.将处理后的数据存放如imdb结构中
% ------------------------------------------------------------------------
% Preapre the imdb structure, returns image data with mean image subtracted
files = {'train-images-idx3-ubyte', ... %载入mnist数据集
'train-labels-idx1-ubyte', ...
't10k-images-idx3-ubyte', ...
't10k-labels-idx1-ubyte'} ;
if ~exist(opts.dataDir, 'dir') %如果不存在读取路径:
mkdir(opts.dataDir) ; % 建立读取路径
end
for i=1:4 %如果不存在mnist数据集则下载
if ~exist(fullfile(opts.dataDir, files{i}), 'file')
url = sprintf('http://yann.lecun.com/exdb/mnist/%s.gz',files{i}) ;
fprintf('downloading %s\n', url) ;
gunzip(url, opts.dataDir) ;
end
end
f=fopen(fullfile(opts.dataDir, 'train-images-idx3-ubyte'),'r') ; %载入第一个文件,训练数据集大小为28*28,数量为6万
x1=fread(f,inf,'uint8');
fclose(f) ;
x1=permute(reshape(x1(17:end),28,28,60e3),[2 1 3]) ; %通过permute函数将数组的维度由原来的[1 2 3]变为[2 1 3] ...
%reshape将原数据从第17位开始构成28*28*60000的数组
f=fopen(fullfile(opts.dataDir, 't10k-images-idx3-ubyte'),'r') ; %载入第二个文件,测试数据集大小为28*28,数量为1万
x2=fread(f,inf,'uint8');
fclose(f) ;
x2=permute(reshape(x2(17:end),28,28,10e3),[2 1 3]) ; %同上解释
f=fopen(fullfile(opts.dataDir, 'train-labels-idx1-ubyte'),'r') ; %载入第三个文件:训练数据集的类标签
y1=fread(f,inf,'uint8');
fclose(f) ;
y1=double(y1(9:end)')+1 ;
f=fopen(fullfile(opts.dataDir, 't10k-labels-idx1-ubyte'),'r') ; %载入第四个文件:测试数据集的类标签
y2=fread(f,inf,'uint8');
fclose(f) ;
y2=double(y2(9:end)')+1 ;
%set = 1 对应训练;set = 3 对应的是测试
set = [ones(1,numel(y1)) 3*ones(1,numel(y2))]; %numel返回元素的总数
data = single(reshape(cat(3, x1, x2),28,28,1,[])); %将x1的训练数据集和x2的测试数据集的第三个维度进行拼接组成新的数据集,并且转为single型减少内存
dataMean = mean(data(:,:,:,set == 1), 4); %求出训练数据集中所有的图像的均值
data = bsxfun(@minus, data, dataMean) ; %利用bsxfun函数将数据集中的每个元素逐个减去均值
%将数据存入imdb结构中
imdb.images.data = data ; %data的大小为[28 28 1 70000]。 (60000+10000)
imdb.images.data_mean = dataMean; %dataMean的大小为[28 28]
imdb.images.labels = cat(2, y1, y2) ; %拼接训练数据集和测试数据集的标签,拼接后的大小为[1 70000]
imdb.images.set = set ; %set的大小为[1 70000],unique(set) = [1 3]
imdb.meta.sets = {'train', 'val', 'test'} ; %imdb.meta.sets=1用于训练,imdb.meta.sets=2用于验证,imdb.meta.sets=3用于测试
%arrayfun函数通过应用sprintf函数得到array中从0到9的元素并且将其数字标签转化为char文字型
imdb.meta.classes = arrayfun(@(x)sprintf('%d',x),0:9,'uniformoutput',false) ;
本文为原创文章转载必须注明本文出处以及附上 本文地址超链接 以及 博主博客地址:http://blog.csdn.net/qq_20259459 和 作者邮箱( [email protected] )。
(如果喜欢本文,欢迎大家关注我的博客或者动手点个赞,有需要可以邮件联系我)