matlab学习笔记——DeepLearnToolBox搭建MNIST识别网络

DeepLearnToolBox是matlab下的一个简单的深度学习工具包,接口简单易用,其代码是纯matlab编写。

使用过程非常简单,总共分两步:

  1. 在github上下载代码;
  2. 打开matlab,在matlab命令行窗口中输入:addpath(genpath('所在文件夹\DeepLearnToolbox'));

然后就可以愉快地敲代码了,下面是一个用于识别MNIST手写数字的官方示例:

function test_example_CNN
load mnist_uint8; % 加载手写数字

% 处理数据
train_x = double(reshape(train_x',28,28,60000))/255;
test_x = double(reshape(test_x',28,28,10000))/255;
train_y = double(train_y');
test_y = double(test_y');
%% 建立一个卷积神经网络
% 跑一次循环需要200秒,一个epoch可以获得11%的误差;
% 100 个epochs 之后可以获得1.2%的误差。
rand('state',0)
% 网络结构
cnn.layers = {
    struct('type', 'i')  % 输入层
    struct('type', 'c', 'outputmaps', 6, 'kernelsize', 5)  % 卷积层
    struct('type', 's', 'scale', 2)  % 上采样
    struct('type', 'c', 'outputmaps', 12, 'kernelsize', 5) % 卷积层
    struct('type', 's', 'scale', 2) % 上采样
};
% 网络初始化
cnn = cnnsetup(cnn, train_x, train_y);

% 参数
opts.alpha = 1;
opts.batchsize = 50;
opts.numepochs = 1;

% 训练
cnn = cnntrain(cnn, train_x, train_y, opts);

% 验证误差
[er, bad] = cnntest(cnn, test_x, test_y);

% 打印均方误差
figure; plot(cnn.rL);

% 如果er>=0.12 则报错
assert(er<0.12, 'Too big error');

你可能感兴趣的:(matlab学习笔记——DeepLearnToolBox搭建MNIST识别网络)