【Tensorflow】使用Slim模块训练Inceptionv3

环境:Ubuntu16.04

           Tensorflow-gpu 1.12


一、获取源码

Slim模块是包含在tensorflow的另一个代码仓库内的。

git clone https://github.com/tensorflow/models/

为了和我的tensorflow版本保持一致,我切换到r1.12分支。

git checkout r1.12

二、建立数据集

把相同类别的图片放入同一个文件夹下,成为数据集的子文件夹。

然后仿造models/slim/datasets/download_and_convert_flowers.py完成训练集合验证集的分割,这个脚本是从网下下载数据集,如果是使用自己的数据集,把下载的代码去掉。

可以根据自己的数据集修改下面的几个值。

【Tensorflow】使用Slim模块训练Inceptionv3_第1张图片

然后仿造models/research/slim/datasets/flowers.py写一个自己数据集的脚本文件,主要也是修改下面这几个地方。

【Tensorflow】使用Slim模块训练Inceptionv3_第2张图片

最后不要忘记在models/research/slim/datasets/dataset_factory.py里注册自己的数据集。

然后运行脚本生成tfrecord文件。

python download_and_convert_data.py \
  --dataset_name=yourdatasetname \
  --dataset_dir=yourdatasetnamepath

就会生成

【Tensorflow】使用Slim模块训练Inceptionv3_第3张图片

这个的多个tfrecord文件,和一个labels.txt文件。

三、下载预训练模型

wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
tar -xvf inception_v3_2016_08_28.tar.gz
rm inception_v3_2016_08_28.tar.gz

四、训练模型

从头训练:

python train_image_classifier.py \
  --train_dir=/home/tmp/train_logs \
  --dataset_name=yourdatasetname \
  --dataset_split_name=train \
  --dataset_dir=yourdatasetpath \
  --model_name=inception_v3

 fine-tuning:

python train_image_classifier.py \
  --train_dir=/home/tmp/train_logs \
  --dataset_name=yourdatasetname \
  --dataset_split_name=train \
  --dataset_dir=yourdatasetpath \
  --model_name=inception_v3 \
  --checkpoint_path = /home/checkpoints/inception_v3.ckpt \
  --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
  --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits

另外还有很多可以设置或修改的参数请参考models/research/slim/train_image_classifier.py。

打印出这样的log就代表开始正常训练了。

【Tensorflow】使用Slim模块训练Inceptionv3_第4张图片

打开tensorboard可以观察loss曲线变换,及时调整训练策略。

tensorboard --logdir=/home/tmp/train_logs

五、验证模型性能

python eval_image_classifier.py \
   --alsologtostderr \
   --checkpoint_path = /home/tmp/train_logs/model.ckpt-11273 \
   --dataset_dir=yourdatasetpath \
   --dataset_split_name=validation \
   --model_name=inception_v3

【Tensorflow】使用Slim模块训练Inceptionv3_第5张图片

可以根据自己需要验证哪些性能指标修改models/research/slim/eval_image_classifier.py里面的

这个部分。

例如,数据集的类别少于5的话,可以修改Recall_5这个指标,还可以添加自己想要验证的指标。

具体支持哪些指标的验证要参考slim.metrics。

六、导出推理图

首先要保存一个GraphDef出来。

python export_inference_graph.py \
  --alsologtostderr \
  --model_name=inception_v3 \
  --output_file=/home/tmp/train_logs/inception_v3_inf_graph.pb

这一步必须这样做是因为在inceptionv3的网络定义里输入是由构建网络的时候传入的。而在训练时是直接传入的一个batch的数据,shape是[32, 299, 299, 3],因此训练得到的模型输入是固定的,这样如果直接用训练的ckpt去freeze网络得到的输入就是不对的。

它是在export_inference_graph.py里构建的网络里才传入了一个placeholder,所以必须用这个网络得到一个GraphDef得到的输入的shape才会是[?, 299, 299, 3]。

然后就冻结导出的graph。

这里需要tensorflow的源码。

git clone https://github.com/tensorflow/tensorflow/
git checkout r1.12

 先用bazel编译freeze_graph。

bazel build tensorflow/python/tools:freeze_graph

冻结。

bazel-bin/tensorflow/python/tools/freeze_graph \
  --input_graph=/home/tmp/train_logs/inception_v3_inf_graph.pb \
  --input_checkpoint=/home/tmp/train_logs/model.ckpt-14057 \
  --input_binary=true \
  --output_graph=/home/tmp/train_logs/frozen_inception_v3.pb \
  --output_node_names=InceptionV3/Predictions/Reshape_1

七、测试

可以自己写一个python脚本测试,也可以用tensorflow的label_image来测试。

bazel build tensorflow/examples/label_image:label_image

bazel-bin/tensorflow/examples/label_image/label_image \
  --image=/home/dataset/test/001.jpg \
  --input_layer=input \
  --output_layer=InceptionV3/Predictions/Reshape_1 \
  --graph=/home/tmp/train_logs/frozen_inception_v3.pb \
  --labels=/home/dataset/labels.txt \
  --input_mean = 0 \
  --input_std = 255

 

你可能感兴趣的:(深度学习,tensorflow,tensorflow,models,slim,inceptionv3,freeze_graph)