TensorFlow (17): Training Your Own Image Recognition Model (Based on VGG16)

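    Below we fine-tune a pretrained network so that it learns from our own data. Although the title mentions VGG16, the commands in this post fine-tune Inception V3 using the official inception_v3 checkpoint. The task is to classify satellite images of the Earth's surface into six classes: forest, water, rock, farmland, glacier, and urban area. The dataset has been uploaded to CSDN: https://download.csdn.net/download/viafcccy/11791071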

I. Dataset

First, some background on how Python parses command-line arguments:

1. Understanding from __future__ import absolute_import   https://blog.csdn.net/viafcccy/article/details/101061413

2. The argparse library   https://blog.csdn.net/viafcccy/article/details/101061661
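The script relies on one argparse behaviour that is easy to miss. A long option written with dashes, such as --tensorflow-data-dir, is stored as an attribute with underscores, so the script reads it as args.tensorflow_data_dir. The minimal sketch below shows this. The option values passed in are illustrative only.

```python
import argparse

# Same option style as data_convert.py: long option names use dashes.
parser = argparse.ArgumentParser()
parser.add_argument('-t', '--tensorflow-data-dir', default='pic/')
parser.add_argument('--train-shards', default=2, type=int)

# Passing an explicit list (instead of reading sys.argv) makes the result easy to inspect.
args = parser.parse_args(['-t', 'my_pics/', '--train-shards', '4'])
defaults = parser.parse_args([])

# argparse stores '--tensorflow-data-dir' as the attribute tensorflow_data_dir,
# and type=int converts the string '4' into the integer 4.
print(args.tensorflow_data_dir, args.train_shards)          # my_pics/ 4
print(defaults.tensorflow_data_dir, defaults.train_shards)  # pic/ 2
```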
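
The conversion script data_convert.py. It creates label.txt when the file is missing, then passes the parsed arguments to main() in src/tfrecord.py:
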
# coding:utf-8
from __future__ import absolute_import
import argparse
import os
import logging
from src.tfrecord import main

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--tensorflow-data-dir', default='pic/')
    parser.add_argument('--train-shards', default=2, type=int)
    parser.add_argument('--validation-shards', default=2, type=int)
    parser.add_argument('--num-threads', default=2, type=int)
    parser.add_argument('--dataset-name', default='satellite', type=str)
    return parser.parse_args()

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    args.tensorflow_dir = args.tensorflow_data_dir
    args.train_directory = os.path.join(args.tensorflow_dir, 'train')
    args.validation_directory = os.path.join(args.tensorflow_dir, 'validation')
    args.output_directory = args.tensorflow_dir
    args.labels_file = os.path.join(args.tensorflow_dir, 'label.txt')
    if not os.path.exists(args.labels_file):
        logging.warning('Can\'t find label.txt. Now create it.')
        all_entries = os.listdir(args.train_directory)
        dirnames = []
        for entry in all_entries:
            if os.path.isdir(os.path.join(args.train_directory, entry)):
                dirnames.append(entry)
        with open(args.labels_file, 'w') as f:
            for dirname in dirnames:
                f.write(dirname + '\n')
    main(args)  # pass args to main() in src/tfrecord.py, which writes the TFRecord files
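The label.txt block above lists class folders in whatever order os.listdir returns them, and that order is not guaranteed. Class IDs presumably follow the line order of label.txt, so a sorted variant keeps the IDs stable across machines. The sketch below uses a hypothetical helper named write_label_file, which is not part of the original code.

```python
import os
import tempfile


def write_label_file(train_directory, labels_file):
    """Hypothetical helper: write one class-folder name per line to labels_file.

    Unlike the loop in data_convert.py, the folder names are sorted, so the
    file has the same line order on every machine and every run.
    """
    dirnames = sorted(
        entry for entry in os.listdir(train_directory)
        if os.path.isdir(os.path.join(train_directory, entry))
    )
    with open(labels_file, 'w') as f:
        for dirname in dirnames:
            f.write(dirname + '\n')
    return dirnames


# Usage on a throwaway directory tree (class names are examples only).
root = tempfile.mkdtemp()
train_dir = os.path.join(root, 'train')
for name in ['water', 'forest', 'glacier']:
    os.makedirs(os.path.join(train_dir, name))
open(os.path.join(train_dir, 'README.txt'), 'w').close()  # plain files are skipped
labels = write_label_file(train_dir, os.path.join(root, 'label.txt'))
print(labels)  # ['forest', 'glacier', 'water']
```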

In a terminal, change to the directory that contains data_convert.py and run:

python data_convert.py -t pic/  --train-shards 2  --validation-shards 2  --num-threads 2  --dataset-name satellite

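This produces several errors. All of them come from incompatibilities between Python 2 and Python 3.

The tracebacks point into two files, random.py and tfrecord.py. However, random.py belongs to the standard library and should not be edited, so every fix below goes into src/tfrecord.py.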

1. TypeError in _find_image_files:

  File "E:\Deep-Learning-21-Examples\chapter_3\data_prepare\src\tfrecord.py", line 341, in _find_image_files
    random.shuffle(shuffled_index)
  File "C:\ProgramData\Anaconda3\lib\random.py", line 275, in shuffle
    x[i], x[j] = x[j], x[i]
TypeError: 'range' object does not support item assignment

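Fix: in _find_image_files, change the assignment that creates shuffled_index so that it builds a list, e.g. shuffled_index = list(range(len(filenames))). Leave random.shuffle(shuffled_index) unchanged.

Do not write random.shuffle(list(shuffled_index)) instead. That silences the error, but it shuffles a temporary copy that is thrown away immediately, so the images are never actually shuffled.

Reason: in Python 3, range() no longer returns a list. It returns a range object, which is immutable and cannot be shuffled in place.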
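The three cases can be checked directly. The sketch below shows (1) the original error, (2) the list(...) workaround that silently does nothing, and (3) the working fix. The file names are dummies, and the seed is an illustrative fixed value used only to make the shuffle repeatable.

```python
import random

filenames = ['a.jpg', 'b.jpg', 'c.jpg', 'd.jpg', 'e.jpg']

# 1) Python 3: range() is immutable, so shuffling it in place raises TypeError.
raised = False
try:
    random.shuffle(range(len(filenames)))
except TypeError as err:
    raised = True
    print('TypeError:', err)

# 2) Pitfall: shuffling a temporary list "fixes" the error but changes nothing,
#    because the shuffled copy is discarded right away.
untouched = range(len(filenames))
random.shuffle(list(untouched))
print(list(untouched))  # [0, 1, 2, 3, 4]

# 3) Fix: build a real list first, then shuffle that list in place.
shuffled_index = list(range(len(filenames)))
random.seed(12345)  # fixed seed so the shuffle is repeatable
random.shuffle(shuffled_index)
shuffled_files = [filenames[i] for i in shuffled_index]
print(shuffled_files)
```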

2. NameError in _process_image_files:

  File "C:\Users\Administrator\Desktop\DL\21\Deep-Learning-21-Examples-master\chapter_3\data_prepare\src\tfrecord.py", line 273, in _process_image_files
    for thread_index in xrange(len(ranges)):
NameError: name 'xrange' is not defined

Fix: replace every xrange in tfrecord.py with range.

Reason: Python 3 removed xrange(). Its lazy behaviour was folded into range().
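Replacing xrange with range is the simplest fix. If you would rather keep tfrecord.py runnable under Python 2 as well, a common alternative is a small alias at the top of the file. The sketch below shows that alias. The ranges list is only an illustrative stand-in for the per-thread index ranges in _process_image_files.

```python
# Compatibility alias: Python 2 keeps its built-in xrange;
# on Python 3, where xrange no longer exists, fall back to range.
try:
    xrange
except NameError:
    xrange = range

ranges = [[0, 2400], [2400, 4800]]  # illustrative per-thread index ranges
for thread_index in xrange(len(ranges)):
    print(thread_index, ranges[thread_index])
```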

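Other fixes of the same kind (line numbers refer to tfrecord.py):

line 160: with open(filename, 'rb') as f:   (read image files as binary)
lines 94 and 96: colorspace = b'RGB' and image_format = b'JPEG'
line 104: 'image/class/text': _bytes_feature(str.encode(text)),
line 106: 'image/filename': _bytes_feature(os.path.basename(str.encode(filename))),

Reason: in Python 3, str is Unicode text. The BytesList used by _bytes_feature accepts only bytes, so strings must be encoded first.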
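The edits on lines 94 to 106 all convert str values to bytes. Instead of changing each call site, you could normalise the value once inside _bytes_feature. The sketch below uses a hypothetical helper named as_bytes, which is not in the original tfrecord.py. It is plain Python and does not call TensorFlow; the path shown is an example.

```python
import os


def as_bytes(value):
    """Hypothetical helper: return value as UTF-8 bytes (bytes pass through).

    tf.train.BytesList rejects str in Python 3, so label names and file
    paths must be encoded before they are stored in a feature.
    """
    if isinstance(value, bytes):
        return value
    if isinstance(value, str):
        return value.encode('utf-8')
    raise TypeError('expected str or bytes, got %r' % type(value))


text = 'forest'
filename = os.path.join('pic', 'train', 'forest', 'forest_001.jpg')

print(as_bytes(text))                        # b'forest'
print(os.path.basename(as_bytes(filename)))  # b'forest_001.jpg'
print(as_bytes(b'JPEG'))                     # b'JPEG'
```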

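When the script finishes, the data files appear in pic/. These are the training-set shards, the validation-set shards, and the labels file label.txt. With the default of 2 shards per split, that makes five files in total.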

[Figure 1: the generated data files]

 

II. Fine-tuning the model with TensorFlow-Slim

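Download the source code

As with CIFAR-10, either clone the repository with git or download it directly from https://github.com/tensorflow/models, then open its research folder. TF-Slim lives in research/slim. This code uses tf.contrib.slim, so it requires TensorFlow 1.x.

For more details, see https://www.cnblogs.com/bmsl/p/dongbin_bmsl_01.html

Also download the official pretrained Inception V3 checkpoint: http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz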

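Once the downloads are finished, create a satellite folder inside the slim folder, with three subfolders. These are the paths used by the training command below:

satellite/data
satellite/pretrained
satellite/train_dir

Extract the downloaded archive and put the checkpoint (inception_v3.ckpt) into pretrained.

Put the five data files generated in part I into data.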
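The folder setup can also be scripted. Below is a minimal sketch, assuming the layout above. The function names prepare_workspace and check_data_dir are hypothetical. The file check uses the same satellite_%s_*.tfrecord pattern that satellite.py (below) uses to find its shards.

```python
import glob
import os
import tarfile
import tempfile


def prepare_workspace(slim_dir, checkpoint_tar=None):
    """Create satellite/{data,pretrained,train_dir} inside slim_dir.

    If checkpoint_tar (the downloaded inception_v3_2016_08_28.tar.gz) is
    given, it is extracted into satellite/pretrained. Returns the satellite path.
    """
    root = os.path.join(slim_dir, 'satellite')
    for sub in ('data', 'pretrained', 'train_dir'):
        os.makedirs(os.path.join(root, sub), exist_ok=True)
    if checkpoint_tar is not None:
        with tarfile.open(checkpoint_tar, 'r:gz') as tar:
            tar.extractall(os.path.join(root, 'pretrained'))
    return root


def check_data_dir(data_dir):
    """List the shards that the 'satellite_%s_*.tfrecord' pattern will match."""
    found = {}
    for split in ('train', 'validation'):
        pattern = os.path.join(data_dir, 'satellite_%s_*.tfrecord' % split)
        found[split] = sorted(os.path.basename(p) for p in glob.glob(pattern))
    found['label.txt'] = os.path.exists(os.path.join(data_dir, 'label.txt'))
    return found


# Usage on a scratch directory; in practice pass the path of your slim folder.
slim_dir = tempfile.mkdtemp()
root = prepare_workspace(slim_dir)
print(sorted(os.listdir(root)))  # ['data', 'pretrained', 'train_dir']
print(check_data_dir(os.path.join(root, 'data')))
```

An empty list for either split means the data files were not copied, or their names do not match the pattern.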

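In the slim/datasets folder, create a new file satellite.py as the dataset definition. It must be in slim/datasets because dataset_factory.py loads it with from datasets import satellite.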

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the satellite dataset.

The TFRecord files for this dataset are created by data_convert.py (see part I).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tensorflow as tf

from datasets import dataset_utils

slim = tf.contrib.slim

_FILE_PATTERN = 'satellite_%s_*.tfrecord'

SPLITS_TO_SIZES = {'train': 4800, 'validation': 1200}

_NUM_CLASSES = 6

_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying size.',
    'label': 'A single integer between 0 and 5',
}


def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  """Gets a dataset tuple with instructions for reading the satellite dataset.

  Args:
    split_name: A train/validation split name.
    dataset_dir: The base directory of the dataset sources.
    file_pattern: The file pattern to use when matching the dataset sources.
      It is assumed that the pattern contains a '%s' string so that the split
      name can be inserted.
    reader: The TensorFlow reader type.

  Returns:
    A `Dataset` namedtuple.

  Raises:
    ValueError: if `split_name` is not a valid train/validation split.
  """
  if split_name not in SPLITS_TO_SIZES:
    raise ValueError('split name %s was not recognized.' % split_name)

  if not file_pattern:
    file_pattern = _FILE_PATTERN
  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

  # Allowing None in the signature so that dataset_factory can use the default.
  if reader is None:
    reader = tf.TFRecordReader

  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
      'image/class/label': tf.FixedLenFeature(
          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  }

  items_to_handlers = {
      'image': slim.tfexample_decoder.Image(),
      'label': slim.tfexample_decoder.Tensor('image/class/label'),
  }

  decoder = slim.tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  labels_to_names = None
  if dataset_utils.has_labels(dataset_dir):
    labels_to_names = dataset_utils.read_label_file(dataset_dir)

  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=reader,
      decoder=decoder,
      num_samples=SPLITS_TO_SIZES[split_name],
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      num_classes=_NUM_CLASSES,
      labels_to_names=labels_to_names)
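SPLITS_TO_SIZES has to match the real number of examples in each split. Slim uses num_samples, for example, to decide how many batches eval_image_classifier.py runs. If your data differs from 4800/1200, recount it.

The sketch below counts records without TensorFlow by walking the TFRecord framing: a uint64 length, a 4-byte CRC, the payload, then another 4-byte CRC. The CRCs are skipped, not verified. The shard names in the self-check are hypothetical names that match the pattern.

```python
import glob
import os
import struct
import tempfile


def count_tfrecords(path):
    """Count the records in one TFRecord file by walking its framing.

    Each record is stored as: uint64 length, uint32 masked CRC of the length,
    `length` bytes of payload, uint32 masked CRC of the payload (little-endian).
    The CRCs are skipped, not verified.
    """
    count = 0
    with open(path, 'rb') as f:
        while True:
            header = f.read(8)
            if not header:
                return count  # clean end of file
            if len(header) != 8:
                raise ValueError('truncated length field in %s' % path)
            (length,) = struct.unpack('<Q', header)
            rest = f.read(4 + length + 4)  # length CRC + payload + payload CRC
            if len(rest) != length + 8:
                raise ValueError('truncated record in %s' % path)
            count += 1


def split_size(dataset_dir, split_name, file_pattern='satellite_%s_*.tfrecord'):
    """Total number of records over all shards of one split."""
    pattern = os.path.join(dataset_dir, file_pattern % split_name)
    return sum(count_tfrecords(p) for p in glob.glob(pattern))


# Self-check on fake shards (zeroed CRCs are fine because they are never read).
def _write_fake_shard(path, payloads):
    with open(path, 'wb') as f:
        for payload in payloads:
            f.write(struct.pack('<Q', len(payload)) + b'\0' * 4 + payload + b'\0' * 4)


demo_dir = tempfile.mkdtemp()
_write_fake_shard(os.path.join(demo_dir, 'satellite_train_00000-of-00002.tfrecord'),
                  [b'a', b'bb', b'ccc'])
_write_fake_shard(os.path.join(demo_dir, 'satellite_train_00001-of-00002.tfrecord'),
                  [b'dd', b''])
print(split_size(demo_dir, 'train'))       # 5
print(split_size(demo_dir, 'validation'))  # 0
```

On the real data, compare split_size('satellite/data', 'train') and split_size('satellite/data', 'validation') with the values in SPLITS_TO_SIZES.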

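Next, edit slim/datasets/dataset_factory.py and register our dataset there. The import of satellite and the 'satellite' entry in datasets_map are the two additions: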

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A factory-pattern class which returns classification image/label pairs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datasets import cifar10
from datasets import flowers
from datasets import imagenet
from datasets import mnist
from datasets import satellite

datasets_map = {
    'cifar10': cifar10,
    'flowers': flowers,
    'imagenet': imagenet,
    'mnist': mnist,
    'satellite': satellite,
}


def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None):
  """Given a dataset name and a split_name returns a Dataset.

  Args:
    name: String, the name of the dataset.
    split_name: A train/test split name.
    dataset_dir: The directory where the dataset files are stored.
    file_pattern: The file pattern to use for matching the dataset source files.
    reader: The subclass of tf.ReaderBase. If left as `None`, then the default
      reader defined by each dataset is used.

  Returns:
    A `Dataset` class.

  Raises:
    ValueError: If the dataset `name` is unknown.
  """
  if name not in datasets_map:
    raise ValueError('Name of dataset unknown %s' % name)
  return datasets_map[name].get_split(
      split_name,
      dataset_dir,
      file_pattern,
      reader)
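Registering a dataset only means adding a key to datasets_map. The string passed to --dataset_name must match that key exactly. To illustrate the lookup without TensorFlow, the sketch below replaces the real datasets.satellite module with a hypothetical stand-in (a SimpleNamespace whose get_split just echoes its arguments).

```python
from types import SimpleNamespace


def _stub_get_split(split_name, dataset_dir, file_pattern=None, reader=None):
    # Stand-in for satellite.get_split: reports what it was asked to load.
    return {'split': split_name, 'dataset_dir': dataset_dir}


# Hypothetical stand-in for the real `datasets.satellite` module.
satellite = SimpleNamespace(get_split=_stub_get_split)

# The key is the string that --dataset_name must match.
datasets_map = {'satellite': satellite}


def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None):
    if name not in datasets_map:
        raise ValueError('Name of dataset unknown %s' % name)
    return datasets_map[name].get_split(split_name, dataset_dir, file_pattern, reader)


print(get_dataset('satellite', 'train', 'satellite/data'))
# {'split': 'train', 'dataset_dir': 'satellite/data'}

try:
    get_dataset('Satellite', 'train', 'satellite/data')  # keys are case-sensitive
except ValueError as err:
    print(err)  # Name of dataset unknown Satellite
```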

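Then run the following command from the slim folder. It is written for the Windows command prompt, where ^ continues a line and must be the last character on that line. For that reason, the explanation of each flag is listed below the command rather than inline. On Linux or macOS, replace each ^ with \.

python train_image_classifier.py ^
  --train_dir=satellite/train_dir ^
  --dataset_name=satellite ^
  --dataset_split_name=train ^
  --dataset_dir=satellite/data ^
  --model_name=inception_v3 ^
  --checkpoint_path=satellite/pretrained/inception_v3.ckpt ^
  --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits ^
  --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits ^
  --max_number_of_steps=100000 ^
  --batch_size=32 ^
  --learning_rate=0.001 ^
  --learning_rate_decay_type=fixed ^
  --save_interval_secs=300 ^
  --save_summaries_secs=2 ^
  --log_every_n_steps=10 ^
  --optimizer=rmsprop ^
  --weight_decay=0.00004

- --train_dir: directory where checkpoints and training logs (summaries) are written
- --dataset_name: name of the dataset, i.e. the key registered in dataset_factory.py
- --dataset_split_name: which split to train on (train)
- --dataset_dir: directory containing the TFRecord files
- --model_name: network architecture to use
- --checkpoint_path: location of the pretrained checkpoint
- --checkpoint_exclude_scopes: layers that are not restored from the checkpoint
- --trainable_scopes: layers whose weights are updated during training
- --max_number_of_steps: maximum number of training steps
- --batch_size: number of images per batch
- --learning_rate: learning rate
- --learning_rate_decay_type: learning-rate schedule; fixed keeps the rate constant
- --save_interval_secs: save a checkpoint every 300 seconds
- --save_summaries_secs: write summaries (logs) every 2 seconds
- --log_every_n_steps: print training progress every 10 steps
- --optimizer: optimizer
- --weight_decay: L2 regularization (weight decay) coefficient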
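Line continuations differ between shells. A platform-neutral alternative is to launch the script from Python with an argument list, which needs no continuation characters or shell quoting. The sketch below builds that list from the flags above. The build_command helper and the slim path in the final comment are assumptions, not part of TF-Slim.

```python
import sys

# Flag values copied from the command above. They are kept as strings so that,
# e.g., 0.00004 is passed verbatim instead of being formatted as 4e-05.
TRAIN_FLAGS = {
    'train_dir': 'satellite/train_dir',
    'dataset_name': 'satellite',
    'dataset_split_name': 'train',
    'dataset_dir': 'satellite/data',
    'model_name': 'inception_v3',
    'checkpoint_path': 'satellite/pretrained/inception_v3.ckpt',
    'checkpoint_exclude_scopes': 'InceptionV3/Logits,InceptionV3/AuxLogits',
    'trainable_scopes': 'InceptionV3/Logits,InceptionV3/AuxLogits',
    'max_number_of_steps': '100000',
    'batch_size': '32',
    'learning_rate': '0.001',
    'learning_rate_decay_type': 'fixed',
    'save_interval_secs': '300',
    'save_summaries_secs': '2',
    'log_every_n_steps': '10',
    'optimizer': 'rmsprop',
    'weight_decay': '0.00004',
}


def build_command(flags, script='train_image_classifier.py'):
    """Return an argv list: [python, script, '--name=value', ...]."""
    return [sys.executable, script] + ['--%s=%s' % (k, v) for k, v in flags.items()]


cmd = build_command(TRAIN_FLAGS)
print(' '.join(cmd))
# To start training, run from the slim folder, e.g.:
#   import subprocess; subprocess.run(cmd, check=True, cwd='path/to/slim')
```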

 

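Note that, to reduce computation, we train only the output layers here. These are InceptionV3/Logits and InceptionV3/AuxLogits, the counterpart of fc8 in VGG16.

To train all parameters, simply remove the --trainable_scopes argument. Keep --checkpoint_exclude_scopes either way. The new 6-class output layers do not match the shape of the checkpoint's ImageNet output layers, so they must be initialized from scratch rather than restored.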
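Why does removing --trainable_scopes train everything, while --checkpoint_exclude_scopes must stay? Both flags select variables by name scope, roughly by prefix. The sketch below is a plain-Python approximation of that behaviour, not slim's actual code. The variable names are illustrative, and a real Inception V3 graph has many more.

```python
def split_variables(var_names, exclude_scopes, trainable_scopes=None):
    """Approximate how the two flags partition model variables by name prefix.

    restored: variables loaded from the checkpoint (not under any excluded scope)
    trained:  variables updated by the optimizer (under a trainable scope,
              or every variable when trainable_scopes is None)
    """
    restored = [v for v in var_names
                if not any(v.startswith(s) for s in exclude_scopes)]
    if trainable_scopes is None:
        trained = list(var_names)
    else:
        trained = [v for v in var_names
                   if any(v.startswith(s) for s in trainable_scopes)]
    return restored, trained


# Illustrative variable names.
VARS = [
    'InceptionV3/Conv2d_1a_3x3/weights',
    'InceptionV3/Mixed_7c/Branch_0/Conv2d_0a_1x1/weights',
    'InceptionV3/AuxLogits/Conv2d_2b_1x1/weights',
    'InceptionV3/Logits/Conv2d_1c_1x1/weights',
]
HEADS = ['InceptionV3/Logits', 'InceptionV3/AuxLogits']

# Command as given: restore the backbone, train only the two output heads.
restored, trained = split_variables(VARS, HEADS, HEADS)
print(restored)  # the two backbone variables
print(trained)   # the two head variables

# Without --trainable_scopes: still restore the backbone, but train everything.
restored_all, trained_all = split_variables(VARS, HEADS)
print(len(restored_all), len(trained_all))  # 2 4
```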
