caffe简易上手—— 运行cifar例子

训练cifar训练集

cifar是一个常见的图像分类训练集,包括上万张图片及20个分类,caffe提供了一个网络用于分类cifar数据集。

cifar网络的定义在examples/cifar10目录下,训练的过程十分简单。

(以下命令均在caffe默认根目录下运行,下同)

 

1、获取训练数据

cd $CAFFE_ROOT
./data/cifar10/get_cifar10.sh
./examples/cifar10/create_cifar10.sh

 

2、开始训练

cd $CAFFE_ROOT
./examples/cifar10/train_quick.sh

 

3、训练完成后我们会得到:

  cifar10_quick_iter_4000.caffemodel.h5

  cifar10_quick_iter_4000.solverstate.h5

  此时,我们就训练得到了模型,用于后面的分类。

 

4、下面我们使用模型来分类新数据

先直接用一下别人的模型分类试一下:(默认用的ImageNet的模型)

python python/classify.py examples/images/cat.jpg foo

 

下面我们来指定自己的模型进行分类:

python python/classify.py --model_def examples/cifar10/cifar10_quick.prototxt --pretrained_model examples/cifar10/cifar10_quick_iter_4000.caffemodel.h5 --center_only  examples/images/cat.jpg foo

上面这句话的意思是,使用cifar10_quick.prototxt网络 + cifar10_quick_iter_4000.caffemodel.h5模型,对examples/images/cat.jpg图片进行分类。

 

默认的classify脚本不会直接输出结果,而是会把结果输入到foo文件里,不太直观,这里我在网上找了一个修改版,添加了一些参数,可以输出概率最高的分类。

替换python/classify.py,下载地址:http://download.csdn.net/detail/caisenchuan/9513196

 

这个脚本添加了两个参数,可以指定labels_file,然后可以直接把分类结果输出出来:

python python/classify.py --print_results --model_def examples/cifar10/cifar10_quick.prototxt --pretrained_model examples/cifar10/cifar10_quick_iter_4000.caffemodel.h5 --labels_file data/cifar10/cifar10_words.txt  --center_only  examples/images/cat.jpg foo

输出结果:

复制代码
Loading file: examples/images/cat.jpg
Classifying 1 inputs.
predict 3 inputs.
Done in 0.02 s.
Predictions : [[ 0.03903743  0.00722749  0.04582177  0.44352672  0.01203315  0.11832549
   0.02335102  0.25013766  0.03541689  0.02512246]]
python/classify.py:176: FutureWarning: sort(columns=....) is deprecated, use sort_values(by=.....)
  labels = labels_df.sort('synset_id')['name'].values
[('cat', '0.44353'), ('horse', '0.25014'), ('dog', '0.11833'), ('bird', '0.04582'), ('airplane', '0.03904')]
上面标明了各个分类的顺序和置信度
Saving results into foo
复制代码

 

Tips

最后,总结一下训练一个网络用到的相关文件:

cifar10_quick_solver.prototxt:方案配置,用于配置迭代次数等信息,训练时直接调用caffe train指定这个文件,就会开始训练

cifar10_quick_train_test.prototxt:训练网络配置,用来设置训练用的网络,这个文件的名字会在solver.prototxt里指定

cifar10_quick_iter_4000.caffemodel.h5:训练出来的模型,后面就用这个模型来做分类

cifar10_quick_iter_4000.solverstate.h5:也是训练出来的,应该是用来中断后继续训练用的文件

cifar10_quick.prototxt:分类用的网络

你可能感兴趣的:(caffe)