训练之前,确认caffe已经编译成功。本文以mstar数据库为例,介绍如何在caffe平台上训练卷及神经网络模型并进行测试。
一、准备数据
这个步骤是最繁琐,也最容易出错的一步,任何差池都会导致最终训练效果不如人意,
建议多花时间检查这部分,确保数据集的质量。
1、在data文件夹下新建文件夹mstar,并进入该目录,分别简历train和val文件夹。
2、在train和val文件夹下分别建立n个文件夹,n代表目标的类别数目。比如我要对mstar中10类目标进行分类,则分别简历10个文件夹,命名为0~9。
3、在各个文件夹下储存各个类别的目标图片,包括train和val。注意图片的命名,一定不要有重复。为方便,可以在windows下用看图王等批处理软件,对命名进行修改。例如可以改为’120100.jpg‘等,表示该图片为验证集第3类(存在第0类)的第100张图片。
4、在mstar文件夹下新建train.txt和val.txt,内容如下,表示图片及其对应的类别。这部分是挺繁琐的,可以在网上搜索批处理的方法,例如filepath2traintxt.sh以及filepath2valtxt.sh。
100001.jpeg0
100002.jpeg0
100003.jpeg0
100004.jpeg0
...
二、制作数据
给caffe的图片格式必须是lmdb格式的,caffe自带了一些小工具可以很方便的使用。
1、在examples文件夹下复制cifar10文件夹,重新命名为mstar,文件夹下的文件名中cifar10统一替换为mstar。
2、修改create_mstar.sh中的路径以及resize的尺寸,并执行,可以生成mstar_train_lmdb和mstar_val_lmdb两个文件夹。
3、如果需要均值文件,可以复制examples/imagenet/make_imagenet_mean.sh进行路径的修改,并执行,可以生成mstar_mean.binaryproto文件。
三、训练
1、编写网络配置文件,可以在mstar_quick_train_test.prototxt基础上进行修改,再次检查data层,需不需要mean_file和scale。注意这个scale是在减去均值以后,常用值是0.00390625,即把像素值从[0,255]归一化到[0,1]。
2、编写网络训练文件,可以在mstar_quick_solver.prototxt基础上进行修改。
3、执行训练,在train_quick.sh基础上修改。常用语句如:
从头训练:
$TOOLS/caffe train \
--solver=examples/mstar/mstar_quick_solver.prototxt
如果需要在预训练模型上微调,加上--snapshot参数:
$TOOLS/caffe train \
--solver=examples/mstar/mstar_quick_solver.prototxt\
--snapshot=examples/mstar/mstar_quick_iter_1000.solverstate
4、训练结束将生成两种文件,.caffemodel和.solverstate,第一个文件存储了网络参数,第二个文件存储了训练阶段的状态。
四、测试
1、准备若干测试图像,如190144.jpeg,保存在examples/images路径。
2、准备用于测试的模型,可以参照mstar_quick.prototxt。
3、在examples/mstar下新建测试脚本如test.sh,调用caffe/python/classify.py,添加需要的参数,参数的具体说明可以参见classify.py代码说明。另外,网上可以下载到改进版的classify.py,增加了labels_file,force_grayscale等参数。
# test.sh
python python/test.py --print_results --images_dim '128,128' \
--model_defexamples/mstar/mstar_quick.prototxt \
--pretrained_modelexamples/mstar/mstar_quick_iter_2760.caffemodel \
--input_scale0.00390625 \
--labels_filedata/mstar/label.txt \
--center_only--force_grayscale --gpu examples/images/190144.jpeg foo
4、在data/mstar下新建label.txt,建立类代号与类别名称的对应关系,以mstar数据集,为例,文件格式如下:
0 2S1
1 BMP2
2 BRDM2
3 BTR60
4 BTR70
5 D7
6 T62
7 T72
8 ZIL131
9 ZSU234
注意最后一行不要有多余的回车,否则改进版的classify.py会认为有新的类。
5、执行test.sh,可以看到分类的结果:
...
Loadingfile: examples/images/190144.jpeg
Classifying1 inputs.
Done in225.18 ms.
Predictions: [[ 1.53139865e-08 1.12762644e-08 6.07975963e-08 3.45553599e-05
4.94024803e-08 1.35941606e-04 6.95760318e-05 1.89337625e-05
1.81548385e-05 9.99722660e-01]]
python/test.py:196:FutureWarning: sort(columns=....) is deprecated, use sort_values(by=.....)
labels =labels_df.sort('synset_id')['name'].values
[('ZSU234','0.99972'), ('D7', '0.00014'), ('T62', '0.00007'), ('BTR60', '0.00003'),('T72', '0.00002')]
Savingresults into foo