# Run from the Caffe root so that the later commands (create_mnist.sh,
# train_lenet.sh, caffe.bin test), which use root-relative paths, still work.
# The script downloads and unpacks the MNIST data set.
./data/mnist/get_mnist.sh
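get_mnist.sh 脚本用于下载并解压 MNIST 数据集。
原始数据集包括四个文件：train-images-idx3-ubyte（训练图片）、train-labels-idx1-ubyte（训练标签）、t10k-images-idx3-ubyte（测试图片）和 t10k-labels-idx1-ubyte（测试标签）。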
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <stdint.h>

#include <fstream>  // NOLINT(readability/streams)
#include <string>
#include <vector>

#include "opencv2/core/core.hpp"
#include "opencv2/highgui/highgui.hpp"
#include "opencv2/imgproc/imgproc.hpp"
using std::string;
// Grid layout of the output image. convert_image() tiles MNIST images into a
// FLAGS_rows x FLAGS_cols grid (the defaults give 25 x 40 = 1000 tiles).
// Number of tile rows: output height = rows * (image rows from the file header).
DEFINE_int32(rows, 25, "The rows of index in image");
// Number of tile columns: output width = cols * (image cols from the file header).
DEFINE_int32(cols, 40, "The cols of index in image");
// Number of images to skip at the start of the input file before filling tiles.
DEFINE_int32(offset, 0, "The offset of index in raw image");
// Reverses the byte order of a 32-bit value, converting between big-endian
// and little-endian representations. Applying it twice yields the input.
uint32_t swap_endian(uint32_t val) {
  uint32_t reversed = 0;
  // Peel bytes off the low end of `val` and push them onto the low end of
  // `reversed`; after four rounds the first byte taken is the most significant.
  for (int byte_index = 0; byte_index < 4; ++byte_index) {
    reversed = (reversed << 8) | (val & 0xFFu);
    val >>= 8;
  }
  return reversed;
}
//数据集转换函数,输入参数:MNIST数据集文件,图片文件
void convert_image(const char* image_filename, const char* png_filename) {
// Open files
std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
CHECK(image_file) << "Unable to open file " << image_filename;
// Read the magic and the meta data
uint32_t magic;
uint32_t num_items;
uint32_t rows;
uint32_t cols;
//读取魔数
image_file.read(reinterpret_cast<char*>(&magic), 4);
magic = swap_endian(magic);
CHECK_EQ(magic, 2051) << "Incorrect image file magic.";
//读取数据条目总数
image_file.read(reinterpret_cast<char*>(&num_items), 4);
num_items = swap_endian(num_items);
//读取行数
image_file.read(reinterpret_cast<char*>(&rows), 4);
rows = swap_endian(rows);
//读取列数
image_file.read(reinterpret_cast<char*>(&cols), 4);
cols = swap_endian(cols);
//命令行参数读取
const int flag_rows = FLAGS_rows;
const int flag_cols = FLAGS_cols;
const int offset = FLAGS_offset;
const int width = flag_cols*cols;
const int height = flag_rows*rows;
char* pixels = new char[rows * cols];
cv::Mat tp = cv::Mat::zeros(height, width, CV_8UC1);
//使用读取MINST数据,写入到opencv中的Mat类对象中
image_file.seekg(offset*rows*cols, std::ios::cur);
for(int i=0; ifor(int j=0; jif(!image_file.eof()) {
image_file.read(pixels, rows * cols);
for(int k=0; kfor(int l=0; l(k + i*rows, j*cols + l) = (int)pixels[k*cols+l];
}
}
}
else {
for(int k=0; kfor(int l=0; l(k + i*rows, j*cols + l) = 0;
}
}
}
}
}
//调用opencv中的函数保存图片
cv::imwrite(png_filename, tp);
}
// Entry point. Parses gflags, then converts the MNIST image file given as the
// first positional argument into the image file given as the second.
// Returns 0 on success and 1 when the positional arguments are wrong.
int main(int argc, char** argv) {
#ifndef GFLAGS_GFLAGS_H_
  // Presumably for older gflags releases that expose their API in namespace
  // `google`: alias it so the `gflags::` calls below resolve either way.
  namespace gflags = google;
#endif
  FLAGS_alsologtostderr = 1;
  // Help text shown by --help and on a usage error.
  gflags::SetUsageMessage("This script converts the MNIST dataset to\n"
        "image(png) format.\n"
        "Usage:\n"
        " convert_mnist_data [FLAGS] input_image_file "
        "output_png_file\n"
        "The MNIST dataset could be downloaded at\n"
        " http://yann.lecun.com/exdb/mnist/\n"
        "You should gunzip them after downloading, "
        "or directly use data/mnist/get_mnist.sh\n");
  // Removes recognized flags from argv, leaving the program name plus the
  // positional arguments.
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  if (argc != 3) {
    // The second argument restricts which flags are listed to those defined
    // in files whose path matches it.
    gflags::ShowUsageWithFlagsRestrict(argv[0],
        "examples/mnist/convert_mnist_data");
    // Report the usage error to callers such as shell scripts.
    return 1;
  }
  google::InitGoogleLogging(argv[0]);
  convert_image(argv[1], argv[2]);
  return 0;
}
下载得到的原始数据集是二进制文件，需要先转换为 LevelDB 或 LMDB 格式才能被 Caffe 读取，
因此需要（在 Caffe 根目录下）运行以下脚本：
# Run from the Caffe root: converts the raw binary MNIST files into the
# mnist_train_lmdb / mnist_test_lmdb databases under examples/mnist.
./examples/mnist/create_mnist.sh
运行完成后，examples/mnist 目录下会生成 mnist_train_lmdb 和 mnist_test_lmdb 两个目录，每个目录中都包含 data.mdb 和 lock.mdb 两个文件。接下来运行训练脚本：
# Run from the Caffe root: starts training with the solver settings in
# examples/mnist/lenet_solver.prototxt.
examples/mnist/train_lenet.sh
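本例使用 CPU 模式运行（需将 lenet_solver.prototxt 中的 solver_mode 设置为 CPU）。
训练开始时，首先打印训练超参数文件 examples/mnist/lenet_solver.prototxt 的内容，
该文件中指定了所用的 CNN 网络描述文件（examples/mnist/lenet_train_test.prototxt）。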
接着解析网络描述文件中的网络参数，创建训练网络：
首先创建 mnist 数据层，
它产生两个输出：data 为图片数据，label 为标签数据；
该层打开训练集 LMDB，随着各层依次创建，日志中所需的内存量逐层累计增加；
然后依次创建中间层（卷积层、池化层、全连接层等），
最后一层为 loss 层。
然后创建测试网络，
测试网络额外添加了 accuracy 层，用于计算分类准确率。
随着迭代次数增加，loss 逐渐下降，
训练结束时输出最终的 loss 值和 accuracy 值。
# Run from the Caffe root: score the trained LeNet weights with the network
# defined in lenet_train_test.prototxt, running 100 test iterations.
./build/tools/caffe.bin test \
  --model=examples/mnist/lenet_train_test.prototxt \
  --weights=examples/mnist/lenet_iter_10000.caffemodel \
  --iterations=100
import numpy as np
import struct
import matplotlib.pyplot as plt
import Image
# Dump every image of an MNIST IDX3 image file to individual BMP files.
# NOTE(review): the input is the t10k (test) set but the output goes to
# mnist_train/train_*.bmp -- confirm which set/directory is intended.
# NOTE(review): `import Image` is the legacy PIL spelling; with Pillow it
# must be `from PIL import Image`.

IDX3_IMAGE_MAGIC = 2051
IDX3_HEADER = '>IIII'  # big-endian: magic, image count, rows, cols


def parse_idx_images(buf):
    """Decode an IDX3 image file held in memory.

    Args:
        buf: the complete file contents as bytes.

    Returns:
        A read-only uint8 numpy array of shape (num_images, rows, cols), with
        rows/cols taken from the file header rather than assuming 28x28.

    Raises:
        ValueError: if the magic number is not 2051 or the buffer holds fewer
            pixel bytes than the header promises.
    """
    magic, num_images, num_rows, num_cols = struct.unpack_from(IDX3_HEADER, buf, 0)
    if magic != IDX3_IMAGE_MAGIC:
        raise ValueError('not an IDX3 image file (magic number %d)' % magic)
    header_size = struct.calcsize(IDX3_HEADER)
    total = num_images * num_rows * num_cols
    if len(buf) - header_size < total:
        raise ValueError('truncated IDX3 file: expected %d pixel bytes' % total)
    pixels = np.frombuffer(buf, dtype=np.uint8, count=total, offset=header_size)
    return pixels.reshape(num_images, num_rows, num_cols)


def main(filename='t10k-images-idx3-ubyte',
         out_pattern='data/mnist/mnist_train/train_%s.bmp'):
    """Write image number i of `filename` to `out_pattern % i` as BMP."""
    import os  # local import keeps this script block self-contained

    # `with` closes the file even if reading fails.
    with open(filename, 'rb') as binfile:
        buf = binfile.read()
    # Image.save() does not create missing directories.
    out_dir = os.path.dirname(out_pattern)
    if out_dir and not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    for index, pixels in enumerate(parse_idx_images(buf)):
        Image.fromarray(pixels).save(out_pattern % index, 'bmp')


if __name__ == '__main__':
    main()
import os
import sys
import numpy as np
import matplotlib.pyplot as plt

# pycaffe lives in <caffe_root>/python. That directory must be on sys.path
# *before* `import caffe`, otherwise the insert cannot affect the import.
# os.path.join supplies the separator that plain concatenation dropped
# ('/home/caffe' + 'python' == '/home/caffepython').
caffe_root = '/home/caffe'
sys.path.insert(0, os.path.join(caffe_root, 'python'))
import caffe  # noqa: E402 -- depends on the sys.path tweak above

MODEL_FILE = '../mnist/lenet.prototxt'
PRETRAINED = '../mnist/lenet_iter_10000.caffemodel'
IMAGE_FILE = 'demo.bmp'

# Select the CPU backend up front, before the net is built and run.
caffe.set_mode_cpu()

# color=False loads the image as a single channel.
input_image = caffe.io.load_image(IMAGE_FILE, color=False)
net = caffe.Classifier(MODEL_FILE, PRETRAINED)
prediction = net.predict([input_image], oversample=False)
# %-formatting prints the same text on Python 2 and Python 3.
print('predicted class: %d' % prediction[0].argmax())
测试所用图片（demo.bmp）：
测试结果（输出的预测类别）：