深度学习之迁移学习(打造自己的图像识别模型)

 

一、概论

       迁移学习(打造自己的图像识别模型)其实就是利用已有的深度神经网络(VGG16,AlexNet , GoogLeNet等)进行简单的微调。一般有如下几种方式:

  1. 只训练全连接层。
  2. 全部网络重新训练(使用已有参数/从头开始)
  3. 只训练部分网络。

二、数据准备

    原始的图片数据和标签需要转换成tfrecord格式的文件,tfrecord, 这是一种将图像数据和标签放在一起的二进制文件,能更好的利用内存,在tensorflow中快速的复制,移动,读取,存储 等等。

    1.原始图片保存在 data_prepare/pic中      深度学习之迁移学习(打造自己的图像识别模型)_第1张图片

深度学习之迁移学习(打造自己的图像识别模型)_第2张图片

   2. 在data_prepare中执行如下命令:

    python data_convert.py -t pie/ \
    --train-shards 2 \
    --validation-shards 2 \
    --num-threads 2 \
    --dataset-name satellite

三、使用 TensorFlow Slim 微调模型

1.下载 TensorFlow Slim 的源代码

   git clone http://github.com/tensorflow/models.git

找到 models/research/ 目录中的 slim 文件夹 , 这就是要用到的 TensorFlow Slim的源代码 。

2.准备数据 

    2.1  新建satellite.py

     在 datasets/ 目录下新建一个文件 satellite.py,并将 flowers.py 文件
中的内容复制到 satellite. py 中 。

    2.2 修改是 FILE PATTERN 、 SPLITS_TO_ SIZES 、NUM CLASSES

     _FILE PATTERN = ’ satellite_%s_ *. tfrecord'

     SPLITS_TO_SIZES = {'train':4800,'validation':1200}

     _NUM_CLASSES = 6

    2.3 修改 image/format 部分

       将“'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),”修改为“'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),”

    2.4 在dataset_factory. py 文件中注册satellite数据库

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datasets import cifar10
from datasets import flowers
from datasets import imagenet
from datasets import mnist
from datasets import satellite

datasets_map = {
    'cifar10': cifar10,
    'flowers': flowers,
    'imagenet': imagenet,
    'mnist': mnist,
    'satellite': satellite,
}

3.准备训练文件夹

   3.1 在 slim 文件夹下再新建一个 satellite

   3.2 在satellite文件夹中新建一个 data 目录 ,并将 5 个转换好格式的tfrecord格式训练数据复制进去。

   3.3  在satellite文件夹中新建一个空的 train_dir 目录,用来保存训练过程中的日志和模型。

   3.4  在satellite文件夹中新建一个 pretrained 目录,在 slim 的 GitHub 页面找到 Inception V3 模型的下载地址 http://download. tensorflow.org/models/inception_v3_2016_08_28.tar.gz ,下载并解压后,会得到一个 inception_v3 .ckpt 文件,将该文件复制到 pretrained 目录下。

4.开始训练

   1.(在slim文件夹下运行)训练Logits层

python train_image_classifier.py \
  --train_dir=satellite/train_dir \
  --dataset_name=satellite \
  --dataset_split_name=train \
  --dataset_dir=satellite/data \
  --model_name=inception_v3 \
  --checkpoint_path=satellite/pretrained/inception_v3.ckpt \
  --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
  --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
  --max_number_of_steps=100000 \
  --batch_size=32 \
  --learning_rate=0.001 \
  --learning_rate_decay_type=fixed \
  --save_interval_secs=300 \
  --save_summaries_secs=2 \
  --log_every_n_steps=10 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

2.训练所有层

python train_image_classifier.py \
  --train_dir=satellite/train_dir \
  --dataset_name=satellite \
  --dataset_split_name=train \
  --dataset_dir=satellite/data \
  --model_name=inception_v3 \
  --checkpoint_path=satellite/pretrained/inception_v3.ckpt \
  --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
  --max_number_of_steps=100000 \
  --batch_size=32 \
  --learning_rate=0.001 \
  --learning_rate_decay_type=fixed \
  --save_interval_secs=300 \
  --save_summaries_secs=10 \
  --log_every_n_steps=1 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

5.验证模型的准确率

python eval_image_classifier.py \
  --checkpoint_path=satellite/train_dir \
  --eval_dir=satellite/eval_dir \
  --dataset_name=satellite \
  --dataset_split_name=validation \
  --dataset_dir=satellite/data \
  --model_name=inception_v3

 

6.识别图片

   6.1 导出模型

python export_inference_graph.py \
  --alsologtostderr \
  --model_name=inception_v3 \
  --output_file=satellite/inception_v3_inf_graph.pb \
  --dataset_name satellite

  6.2 保存模型参数

python freeze甲graph.py \
--input_graph slim/satellite/inception_v3_inf_graph.pb \
--input_checlφoint slim/satellite/train_dir/model.ckpt-5271 \
--input_binary true \
--output node names InceptionV3/Predictions/Reshape 1 \
--output_graph slim/satellite/frozen_graph .pb

6.3 frozen_graph.pb 来对单张图片进行识别

python classify_image_inception_v3.py \
  --model_path slim/models/research/slim/satellite/frozen_graph.pb \
  --label_path data_prepare/pic/label.txt \
  --image_file test_image.jpg

6.4 结果

water (score = 2.82405)
wetland (score = 1.71531)
urban (score = 0.64461)
wood (score = 0.46510)
rock (score = -0.64199)

  识别出图片是水。

你可能感兴趣的:(深度学习实战)