用tiny-dnn完成mnist手写数字识别

tiny-dnn是一个轻量级神经网络框架,相对于caffe、tensorflow等框架它最大的特点是依赖少,易于部署,缺点是不支持GPU,无法训练大型的神经网络。 源码可在github下载https://github.com/tiny-dnn/tiny-dnn

用tiny-dnn完成mnist手写数字识别_第1张图片

 下载解压后在其vc文件夹中可以看到tiny-dnn的示例工程

用tiny-dnn完成mnist手写数字识别_第2张图片

打开工程后可以看到6个官方demo

void sample1_convnet(const string& data_dir = "../../data");
void sample2_mlp(const string& data_dir = "../../data");
void sample3_dae();
void sample4_dropout(const string& data_dir = "../../data");
void sample5_unbalanced_training_data(const string& data_dir = "../../data");
void sample6_graph();

 这里我们先运行第一个demo,其功能是使用lenet-5卷积神经网络完成mnist手写数字识别

void sample1_convnet(const string& data_dir) {
    // construct LeNet-5 architecture
    network nn;
    adagrad optimizer;

    // connection table [Y.Lecun, 1998 Table.1] 定义lenet-5网络结构
#define O true
#define X false
    static const bool connection[] = {
        O, X, X, X, O, O, O, X, X, O, O, O, O, X, O, O,
        O, O, X, X, X, O, O, O, X, X, O, O, O, O, X, O,
        O, O, O, X, X, X, O, O, O, X, X, O, X, O, O, O,
        X, O, O, O, X, X, O, O, O, O, X, X, O, X, O, O,
        X, X, O, O, O, X, X, O, O, O, O, X, O, O, X, O,
        X, X, X, O, O, O, X, X, O, O, O, O, X, O, O, O
    };
#undef O
#undef X

    nn << convolutional_layer(
            32, 32, 5, 1, 6)  /* 32x32 in, 5x5 kernel, 1-6 fmaps conv */
       << average_pooling_layer(
            28, 28, 6, 2)     /* 28x28 in, 6 fmaps, 2x2 subsampling */
       << convolutional_layer(
            14, 14, 5, 6, 16, connection_table(connection, 6, 16))
       << average_pooling_layer(10, 10, 16, 2)
       << convolutional_layer(5, 5, 5, 16, 120)
       << fully_connected_layer(120, 10);

    std::cout << "load models..." << std::endl;

    // load MNIST dataset 加载mnist数据集
    std::vector train_labels, test_labels;
    std::vector   train_images, test_images;

    std::string train_labels_path = data_dir + "/train-labels.idx1-ubyte";
    std::string train_images_path = data_dir + "/train-images.idx3-ubyte";
    std::string test_labels_path  = data_dir + "/t10k-labels.idx1-ubyte";
    std::string test_images_path  = data_dir + "/t10k-images.idx3-ubyte";

    parse_mnist_labels(train_labels_path, &train_labels);
    parse_mnist_images(train_images_path, &train_images, -1.0, 1.0, 2, 2);
    parse_mnist_labels(test_labels_path,  &test_labels);
    parse_mnist_images(test_images_path,  &test_images, -1.0, 1.0, 2, 2);

    std::cout << "start learning" << std::endl;

    progress_display disp(train_images.size());
    timer t;
    int minibatch_size = 10; //训练的batch size设置为10

    optimizer.alpha *= std::sqrt(minibatch_size); //设置学习率

    // create callback
    auto on_enumerate_epoch = [&](){
        std::cout << t.elapsed() << "s elapsed." << std::endl; //输出每轮迭代的时间

        tiny_dnn::result res = nn.test(test_images, test_labels); 

        std::cout << res.num_success << "/" << res.num_total << std::endl; //输出测试集正确率

        disp.restart(train_images.size()); //开始下一轮迭代
        t.restart();
    };

    auto on_enumerate_minibatch = [&](){
        disp += minibatch_size;
    };

    // training
    nn.train(optimizer, train_images, train_labels, minibatch_size, 20, //共计迭代20轮
                  on_enumerate_minibatch, on_enumerate_epoch);

    std::cout << "end training." << std::endl;

    // test and show results
    nn.test(test_images, test_labels).print_detail(std::cout);

    // save networks
    std::ofstream ofs("LeNet-weights"); //保存训练结果
    ofs << nn;
}

运行代码,得到训练结果:

用tiny-dnn完成mnist手写数字识别_第3张图片

经过20次迭代后,mnist测试集识别率达到99.02%

你可能感兴趣的:(字符识别)