这个实验拿来入门最好了,很轻松,执行作者写好的脚本就可以;
官网链接:http://caffe.berkeleyvision.org/gathered/examples/mnist.html
跟着这个实验,可以观察caffe使用的流程,深入一点,可以看看作者写的shell脚本做了哪些事情;
下面我们一步步来执行,并作简单的分析:
1 首先准备数据集
执行下面的命令。$CAFFE_ROOT是指caffe安装的目录,比如说我的是/home/sloanqin/caffe-master/
cd $CAFFE_ROOT ./data/mnist/get_mnist.sh //执行下载mnist数据集脚本 ./examples/mnist/create_mnist.sh //将下载的mnist数据集转换成lmdb格式的数据
我们来看看这两个脚本做了什么:
1.1 执行脚本get_mnist.sh:
get_mnist.sh脚本文件如下:
<pre name="code" class="plain"># This scripts downloads the mnist data and unzips it. DIR="$( cd "$(dirname "$0")" ; pwd -P )" cd $DIR echo "Downloading..." wget --no-check-certificate http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz wget --no-check-certificate http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz wget --no-check-certificate http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz wget --no-check-certificate http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz echo "Unzipping..." gunzip train-images-idx3-ubyte.gz gunzip train-labels-idx1-ubyte.gz gunzip t10k-images-idx3-ubyte.gz gunzip t10k-labels-idx1-ubyte.gz # Creation is split out because leveldb sometimes causes segfault # and needs to be re-created. echo "Done."
作用:下载下面四个文件,并解压;
train-images-idx3-ubyte.gz训练集 train-labels-idx1-ubyte.gz训练集的标签 t10k-images-idx3-ubyte.gz测试集 t10k-labels-idx1-ubyte.gz测试集的标签
2 定义caffe的网络结构:.prototxt 文件
caffe的网络结构定义在后缀名为.prototxt的文件中,我们根据自己的需要定义自己的网络结构;
在这个实验中,我们使用作者已经为我们定义好的lenet网络结构,大家可以在下面的目录中找到该文件:
$CAFFE_ROOT/examples/mnist/lenet_train_test.prototxt
在我的电脑上,目录是/home/sloanqin/examples/mnist/lenet_train_test.prototxt
在后续的工作中,定义好自己的网络结构是最关键的,直接决定了性能,这里我们就不多说了;
3 定义caffe运算的时候的一些规则:solver.prototxt 文件
该文件在下面的目录中:
$CAFFE_ROOT/examples/mnist/lenet_solver.prototxt
文件内容如下:作者给出了英文注释,我再给出中文的注释
# The train/test net protocol buffer definition net: "examples/mnist/lenet_train_test.prototxt" # test_iter specifies how many forward passes the test should carry out. # In the case of MNIST, we have test batch size 100 and 100 test iterations, # covering the full 10,000 testing images. test_iter: 100 # Carry out testing every 500 training iterations. test_interval: 500 # The base learning rate, momentum and the weight decay of the network. base_lr: 0.01 momentum: 0.9 weight_decay: 0.0005 # The learning rate policy lr_policy: "inv" gamma: 0.0001 power: 0.75 # Display every 100 iterations display: 100 # The maximum number of iterations max_iter: 10000 # snapshot intermediate results snapshot: 5000 snapshot_prefix: "examples/mnist/lenet" # solver mode: CPU or GPU solver_mode: GPU中文注释版本:
<pre name="code" class="plain"># The train/test net protocol buffer definition net: "examples/mnist/lenet_train_test.prototxt" # test_iter specifies how many forward passes the test should carry out. # In the case of MNIST, we have test batch size 100 and 100 test iterations, # covering the full 10,000 testing images. test_iter: 100 // 这个参数指定测试的时候送入多少个 // 这里说明一个知识:GPU在计算的时候,每次迭代是多张图片,我们称为一个batch // 作者提到:test batch size 100,就是说每个包有100张图片 // 这里设置 st_iter=100,就是测试的时候一共输入100*100=10000张图片 //所以,test_iter 的英文翻译就是:测试时迭代次数 # Carry out testing every 500 training iterations. test_interval: 500 //定义每500次迭代,做一次测试 # The base learning rate, momentum and the weight decay of the network. base_lr: 0.01 // 定义了刚开始的学习速率是0.01 momentum: 0.9 weight_decay: 0.0005 # The learning rate policy lr_policy: "inv" gamma: 0.0001 power: 0.75 # Display every 100 iterations display: 100 //每迭代100次,显示一次计算结果 # The maximum number of iterations max_iter: 10000 //设置最大的迭代次数 # snapshot intermediate results snapshot: 5000 // 保存中间运行时得到的参数结果,这里设置成每5000次迭代保存一次,这样运行中间断掉了,我们可以从断掉的地方继续开始 snapshot_prefix: "examples/mnist/lenet" # solver mode: CPU or GPU solver_mode: GPU //使用CPU还是GPU计算4 执行命令进行训练
最后一步就是执行脚本开始训练:
cd $CAFFE_ROOT
./examples/mnist/train_lenet.sh
我们打开这个脚本,可以看到特别简单,就一行:
./build/tools/caffe train --solver=examples/mnist/lenet_solver.prototxt
这行代码的意思:调用./build/tools/caffe目录下面的train函数,train函数的输入参数是solver.prototxt文件的路径:--solver=examples/mnist/lenet_solver.prototxt
5 结果
运行的过程中,可以卡到test 的准确率在不断上升;运行结束后,会生成模型文件:lenet_iter_10000.caffemodel
还有一个文件是snapshot 保存的:lenet_iter_10000.solverstate
原文链接:http://write.blog.csdn.net/postedit/49147935