[Study Notes] Training Inception-v3 on Your Own Data with TensorFlow

Overview

  Dammit, what a giant pit of traps. This post comes in two parts: the rant, and the useful stuff.

Part One: The Rant

  It's the weekend and my advisor has kept me in to work overtime. Ugh. Then again, who else would get kept in but this weak pup? A task that looked dead simple has dragged on for two days with no end in sight, so of course it's me who gets held back.

  When I was first handed the task of fine-tuning Inception-v3, I also figured it was no big deal: just download the pretrained model and finetune, right?

  Naturally, this pup was never going to write the code from scratch; every wheel a weakling reinvents comes out leaking air. So: open the browser, squint at a dizzying pile of results, and pick a blog post that looks reasonably simple to start from. Yep, that one.

  As it turned out, that blog's method not only skipped what needed saying and rambled on about what didn't, it also ended in an enormous trap.

  I'm posting the screenshot here for a proper diss; dear blogger, please pretend you didn't see this. Otherwise: "I, Wei Yingluo, have always had a fiery temper, and I was born hard to mess with..."

  [Screenshot 1: the steps from the blog post in question]

  Right, about the screenshot above. This pup did eventually get that blogger's training code to run, which proved the following:

  1) The code in the screenshot imports the tensorflow-hub package, which must be installed beforehand, and the post never spends a single word on that. (Installing tensorflow-hub turned out to be a major pit of its own; I thrashed around for a whole day and only climbed out after switching to a different machine...)
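  For the record, under normal network conditions the install itself is nominally a one-liner; I still can't say exactly which part of it bit me, beyond noting that the tensorflow-hub release you grab has to be compatible with your TensorFlow version:

pip install tensorflow-hub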

  2) The screenshot says to download the Inception-v3 model from the link above. Personally tested: you don't actually need to, because the code uses the Inception-v3 module wrapped by tensorflow-hub.

  3) The Inception-v3 module the code does need has to be downloaded from behind the Great Firewall, and the download is performed by the code itself. A typical domestic Ubuntu box (the usual choice for convenient GPU training) can't hop the firewall on its own, so the module can't be downloaded and the code won't run. (This pup spent an entire morning fighting the firewall problem for this code and never solved it; what finally worked was downloading the tensorflow-hub module by hand and modifying the code.)
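  For anyone stuck at the same spot, here is a sketch of the manual workaround as I understand it: fetch the module archive on a machine with unrestricted access, unpack it into a local cache directory, and point tensorflow-hub there before the code resolves the module. The paths below are placeholders, and the hash-named-directory detail is my reading of how the library caches modules:

import os

# Point tensorflow-hub at a local cache so it never tries to download.
# Assumption: the module archive was fetched elsewhere and unpacked here;
# the library keys each cache entry by the sha1 hash of the module URL,
# so the unpacked files must sit under that hash-named subdirectory.
os.environ['TFHUB_CACHE_DIR'] = '/home/me/tfhub_modules'  # placeholder path

import tensorflow_hub as hub

# With the cache populated, this resolves locally instead of over the network.
module = hub.Module('https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1')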

  4) Step four in the screenshot also throws errors when run. The correct approach is to change the default parameters in the script's main function, and the defaults to change are not the ones shown in the image. (I didn't verify this point carefully, but it is certain that the script does not run with those arguments.)
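  (For what it's worth, "changing the default parameters" just means editing the default= values where the script builds its command-line flags, rather than passing them on the command line, e.g.:

parser.add_argument(
    '--image_dir',
    type=str,
    default='data/train',  # point this at your own data
    help='Path to folders of labeled images.'
)

  The listing in part two below already has its defaults set up this way.)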

  After stepping through that whole minefield, testing with the blogger's code revealed yet another huge pit, and this one had no way out; the only option was to switch to different training code. The problem was:

  [Screenshot 2: the unsolvable error that appeared at test time]

  And so that blog post's method met its end.

  To sum up: the post left out a great deal, the omissions hid countless pits, and after stepping in every last one of them, the final test pronounced the method a dead end.

Part Two: The Useful Stuff

  Below is the code I now have that is confirmed to run end to end; the material is adapted from a reference post.

  1. Preparing the training data

  train_data_dir/class_i/*.jpg, e.g. data/train/n012345678/1.jpg, ...
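  A quick sanity check of the layout can save a wasted run; a minimal sketch, using the example path above:

import glob
import os

train_dir = 'data/train'  # example path from above
for class_dir in sorted(os.listdir(train_dir)):
    # The retraining script below globs exactly these extension spellings.
    images = []
    for ext in ('jpg', 'jpeg', 'JPG', 'JPEG'):
        images.extend(glob.glob(os.path.join(train_dir, class_dir, '*.' + ext)))
    # create_image_lists() in the script warns below 20 images per class.
    note = '  <-- fewer than 20 images' if len(images) < 20 else ''
    print('%s: %d images%s' % (class_dir, len(images), note))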

  2. Training
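  Assuming the listing below is saved as retrain.py, a typical run only needs the image directory (and since the defaults in the listing already point at data/train, even that flag is optional):

python retrain.py --image_dir data/train

  Note that this version logs TensorBoard summaries to the relative path tmp/retrain_logs, so view them with: tensorboard --logdir tmp/retrain_logs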

  Straight to the code (adjust the paths to your own setup):

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Simple transfer learning with Inception v3 or Mobilenet models.

With support for TensorBoard.

This example shows how to take a Inception v3 or Mobilenet model trained on
ImageNet images, and train a new top layer that can recognize other classes of
images.

The top layer receives as input a 2048-dimensional vector (1001-dimensional for
Mobilenet) for each image. We train a softmax layer on top of this
representation. Assuming the softmax layer contains N labels, this corresponds
to learning N + 2048*N (or 1001*N) model parameters corresponding to the
learned biases and weights.

Here's an example, which assumes you have a folder containing class-named
subfolders, each full of images for each label. The example folder flower_photos
should have a structure like this:

~/flower_photos/daisy/photo1.jpg
~/flower_photos/daisy/photo2.jpg
...
~/flower_photos/rose/anotherphoto77.jpg
...
~/flower_photos/sunflower/somepicture.jpg

The subfolder names are important, since they define what label is applied to
each image, but the filenames themselves don't matter. Once your images are
prepared, you can run the training with a command like this:


bash:
bazel build tensorflow/examples/image_retraining:retrain && \
bazel-bin/tensorflow/examples/image_retraining/retrain \
    --image_dir ~/flower_photos


Or, if you have a pip installation of tensorflow, `retrain.py` can be run
without bazel:

bash:
python tensorflow/examples/image_retraining/retrain.py \
    --image_dir ~/flower_photos


You can replace the image_dir argument with any folder containing subfolders of
images. The label for each image is taken from the name of the subfolder it's
in.

This produces a new model file that can be loaded and run by any TensorFlow
program, for example the label_image sample code.

By default this script will use the high accuracy, but comparatively large and
slow Inception v3 model architecture. It's recommended that you start with this
to validate that you have gathered good training data, but if you want to deploy
on resource-limited platforms, you can try the `--architecture` flag with a
Mobilenet model. For example:

bash:
python tensorflow/examples/image_retraining/retrain.py \
    --image_dir ~/flower_photos --architecture mobilenet_1.0_224


There are 32 different Mobilenet models to choose from, with a variety of file
size and latency options. The first number can be '1.0', '0.75', '0.50', or
'0.25' to control the size, and the second controls the input image size, either
'224', '192', '160', or '128', with smaller sizes running faster. See
https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html
for more information on Mobilenet.

To use with TensorBoard:

By default, this script will log summaries to the tmp/retrain_logs directory

Visualize the summaries with this command:

tensorboard --logdir tmp/retrain_logs

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
from datetime import datetime
import hashlib
import os.path
import random
import re
import sys
import tarfile

import numpy as np
from six.moves import urllib
import tensorflow as tf

from tensorflow.python.framework import graph_util
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import gfile
from tensorflow.python.util import compat

FLAGS = None

# These are all parameters that are tied to the particular model architecture
# we're using for Inception v3. These include things like tensor names and their
# sizes. If you want to adapt this script to work with another model, you will
# need to update these to reflect the values in the network you're using.
MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1  # ~134M


def create_image_lists(image_dir, testing_percentage, validation_percentage):
  """Builds a list of training images from the file system.

  Analyzes the sub folders in the image directory, splits them into stable
  training, testing, and validation sets, and returns a data structure
  describing the lists of images for each label and their paths.

  Args:
    image_dir: String path to a folder containing subfolders of images.
    testing_percentage: Integer percentage of the images to reserve for tests.
    validation_percentage: Integer percentage of images reserved for validation.

  Returns:
    A dictionary containing an entry for each label subfolder, with images split
    into training, testing, and validation sets within each label.
  """
  if not gfile.Exists(image_dir):
    tf.logging.error("Image directory '" + image_dir + "' not found.")
    return None
  result = {}
  sub_dirs = [x[0] for x in gfile.Walk(image_dir)]
  # The root directory comes first, so skip it.
  is_root_dir = True
  for sub_dir in sub_dirs:
    if is_root_dir:
      is_root_dir = False
      continue
    extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    file_list = []
    dir_name = os.path.basename(sub_dir)
    if dir_name == image_dir:
      continue
    tf.logging.info("Looking for images in '" + dir_name + "'")
    for extension in extensions:
      file_glob = os.path.join(image_dir, dir_name, '*.' + extension)
      file_list.extend(gfile.Glob(file_glob))
    if not file_list:
      tf.logging.warning('No files found')
      continue
    if len(file_list) < 20:
      tf.logging.warning(
          'WARNING: Folder has less than 20 images, which may cause issues.')
    elif len(file_list) > MAX_NUM_IMAGES_PER_CLASS:
      tf.logging.warning(
          'WARNING: Folder {} has more than {} images. Some images will '
          'never be selected.'.format(dir_name, MAX_NUM_IMAGES_PER_CLASS))
    label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower())
    training_images = []
    testing_images = []
    validation_images = []
    for file_name in file_list:
      base_name = os.path.basename(file_name)
      # We want to ignore anything after '_nohash_' in the file name when
      # deciding which set to put an image in, the data set creator has a way of
      # grouping photos that are close variations of each other. For example
      # this is used in the plant disease data set to group multiple pictures of
      # the same leaf.
      hash_name = re.sub(r'_nohash_.*$', '', file_name)
      # This looks a bit magical, but we need to decide whether this file should
      # go into the training, testing, or validation sets, and we want to keep
      # existing files in the same set even if more files are subsequently
      # added.
      # To do that, we need a stable way of deciding based on just the file name
      # itself, so we do a hash of that and then use that to generate a
      # probability value that we use to assign it.
      hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
      percentage_hash = ((int(hash_name_hashed, 16) %
                          (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                         (100.0 / MAX_NUM_IMAGES_PER_CLASS))
      if percentage_hash < validation_percentage:
        validation_images.append(base_name)
      elif percentage_hash < (testing_percentage + validation_percentage):
        testing_images.append(base_name)
      else:
        training_images.append(base_name)
    result[label_name] = {
        'dir': dir_name,
        'training': training_images,
        'testing': testing_images,
        'validation': validation_images,
    }
  return result


def get_image_path(image_lists, label_name, index, image_dir, category):
  """Returns a path to an image for a label at the given index.

  Args:
    image_lists: Dictionary of training images for each label.
    label_name: Label string we want to get an image for.
    index: Int offset of the image we want. This will be moduloed by the
    available number of images for the label, so it can be arbitrarily large.
    image_dir: Root folder string of the subfolders containing the training
    images.
    category: Name string of set to pull images from - training, testing, or
    validation.

  Returns:
    File system path string to an image that meets the requested parameters.

  """
  if label_name not in image_lists:
    tf.logging.fatal('Label does not exist %s.', label_name)
  label_lists = image_lists[label_name]
  if category not in label_lists:
    tf.logging.fatal('Category does not exist %s.', category)
  category_list = label_lists[category]
  if not category_list:
    tf.logging.fatal('Label %s has no images in the category %s.',
                     label_name, category)
  mod_index = index % len(category_list)
  base_name = category_list[mod_index]
  sub_dir = label_lists['dir']
  full_path = os.path.join(image_dir, sub_dir, base_name)
  return full_path


def get_bottleneck_path(image_lists, label_name, index, bottleneck_dir,
                        category, architecture):
  """Returns a path to a bottleneck file for a label at the given index.

  Args:
    image_lists: Dictionary of training images for each label.
    label_name: Label string we want to get an image for.
    index: Integer offset of the image we want. This will be moduloed by the
    available number of images for the label, so it can be arbitrarily large.
    bottleneck_dir: Folder string holding cached files of bottleneck values.
    category: Name string of set to pull images from - training, testing, or
    validation.
    architecture: The name of the model architecture.

  Returns:
    File system path string to an image that meets the requested parameters.
  """
  return get_image_path(image_lists, label_name, index, bottleneck_dir,
                        category) + '_' + architecture + '.txt'


def create_model_graph(model_info):
  """Creates a graph from saved GraphDef file and returns a Graph object.

  Args:
    model_info: Dictionary containing information about the model architecture.

  Returns:
    Graph holding the trained Inception network, and various tensors we'll be
    manipulating.
  """
  with tf.Graph().as_default() as graph:
    model_path = os.path.join(FLAGS.model_dir, model_info['model_file_name'])
    with gfile.FastGFile(model_path, 'rb') as f:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(f.read())
      bottleneck_tensor, resized_input_tensor = (tf.import_graph_def(
          graph_def,
          name='',
          return_elements=[
              model_info['bottleneck_tensor_name'],
              model_info['resized_input_tensor_name'],
          ]))
  return graph, bottleneck_tensor, resized_input_tensor


def run_bottleneck_on_image(sess, image_data, image_data_tensor,
                            decoded_image_tensor, resized_input_tensor,
                            bottleneck_tensor):
  """Runs inference on an image to extract the 'bottleneck' summary layer.

  Args:
    sess: Current active TensorFlow Session.
    image_data: String of raw JPEG data.
    image_data_tensor: Input data layer in the graph.
    decoded_image_tensor: Output of initial image resizing and preprocessing.
    resized_input_tensor: The input node of the recognition graph.
    bottleneck_tensor: Layer before the final softmax.

  Returns:
    Numpy array of bottleneck values.
  """
  # First decode the JPEG image, resize it, and rescale the pixel values.
  resized_input_values = sess.run(decoded_image_tensor,
                                  {image_data_tensor: image_data})
  # Then run it through the recognition network.
  bottleneck_values = sess.run(bottleneck_tensor,
                               {resized_input_tensor: resized_input_values})
  bottleneck_values = np.squeeze(bottleneck_values)
  return bottleneck_values


def maybe_download_and_extract(data_url):
  """Download and extract model tar file.

  If the pretrained model we're using doesn't already exist, this function
  downloads it from the TensorFlow.org website and unpacks it into a directory.

  Args:
    data_url: Web location of the tar file containing the pretrained model.
  """
  dest_directory = FLAGS.model_dir
  if not os.path.exists(dest_directory):
    os.makedirs(dest_directory)
  filename = data_url.split('/')[-1]
  filepath = os.path.join(dest_directory, filename)
  if not os.path.exists(filepath):

    def _progress(count, block_size, total_size):
      sys.stdout.write('\r>> Downloading %s %.1f%%' %
                       (filename,
                        float(count * block_size) / float(total_size) * 100.0))
      sys.stdout.flush()

    filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    tf.logging.info('Successfully downloaded %s %d bytes.', filename,
                    statinfo.st_size)
  tarfile.open(filepath, 'r:gz').extractall(dest_directory)


def ensure_dir_exists(dir_name):
  """Makes sure the folder exists on disk.

  Args:
    dir_name: Path string to the folder we want to create.
  """
  if not os.path.exists(dir_name):
    os.makedirs(dir_name)


bottleneck_path_2_bottleneck_values = {}


def create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
                           image_dir, category, sess, jpeg_data_tensor,
                           decoded_image_tensor, resized_input_tensor,
                           bottleneck_tensor):
  """Create a single bottleneck file."""
  tf.logging.info('Creating bottleneck at ' + bottleneck_path)
  image_path = get_image_path(image_lists, label_name, index,
                              image_dir, category)
  if not gfile.Exists(image_path):
    tf.logging.fatal('File does not exist %s', image_path)
  image_data = gfile.FastGFile(image_path, 'rb').read()
  try:
    bottleneck_values = run_bottleneck_on_image(
        sess, image_data, jpeg_data_tensor, decoded_image_tensor,
        resized_input_tensor, bottleneck_tensor)
  except Exception as e:
    raise RuntimeError('Error during processing file %s (%s)' % (image_path,
                                                                 str(e)))
  bottleneck_string = ','.join(str(x) for x in bottleneck_values)
  with open(bottleneck_path, 'w') as bottleneck_file:
    bottleneck_file.write(bottleneck_string)


def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
                             category, bottleneck_dir, jpeg_data_tensor,
                             decoded_image_tensor, resized_input_tensor,
                             bottleneck_tensor, architecture):
  """Retrieves or calculates bottleneck values for an image.

  If a cached version of the bottleneck data exists on-disk, return that,
  otherwise calculate the data and save it to disk for future use.

  Args:
    sess: The current active TensorFlow Session.
    image_lists: Dictionary of training images for each label.
    label_name: Label string we want to get an image for.
    index: Integer offset of the image we want. This will be modulo-ed by the
    available number of images for the label, so it can be arbitrarily large.
    image_dir: Root folder string of the subfolders containing the training
    images.
    category: Name string of which set to pull images from - training, testing,
    or validation.
    bottleneck_dir: Folder string holding cached files of bottleneck values.
    jpeg_data_tensor: The tensor to feed loaded jpeg data into.
    decoded_image_tensor: The output of decoding and resizing the image.
    resized_input_tensor: The input node of the recognition graph.
    bottleneck_tensor: The output tensor for the bottleneck values.
    architecture: The name of the model architecture.

  Returns:
    Numpy array of values produced by the bottleneck layer for the image.
  """
  label_lists = image_lists[label_name]
  sub_dir = label_lists['dir']
  sub_dir_path = os.path.join(bottleneck_dir, sub_dir)
  ensure_dir_exists(sub_dir_path)
  bottleneck_path = get_bottleneck_path(image_lists, label_name, index,
                                        bottleneck_dir, category, architecture)
  if not os.path.exists(bottleneck_path):
    create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
                           image_dir, category, sess, jpeg_data_tensor,
                           decoded_image_tensor, resized_input_tensor,
                           bottleneck_tensor)
  with open(bottleneck_path, 'r') as bottleneck_file:
    bottleneck_string = bottleneck_file.read()
  did_hit_error = False
  try:
    bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
  except ValueError:
    tf.logging.warning('Invalid float found, recreating bottleneck')
    did_hit_error = True
  if did_hit_error:
    create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
                           image_dir, category, sess, jpeg_data_tensor,
                           decoded_image_tensor, resized_input_tensor,
                           bottleneck_tensor)
    with open(bottleneck_path, 'r') as bottleneck_file:
      bottleneck_string = bottleneck_file.read()
    # Allow exceptions to propagate here, since they shouldn't happen after a
    # fresh creation
    bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
  return bottleneck_values


def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir,
                      jpeg_data_tensor, decoded_image_tensor,
                      resized_input_tensor, bottleneck_tensor, architecture):
  """Ensures all the training, testing, and validation bottlenecks are cached.

  Because we're likely to read the same image multiple times (if there are no
  distortions applied during training) it can speed things up a lot if we
  calculate the bottleneck layer values once for each image during
  preprocessing, and then just read those cached values repeatedly during
  training. Here we go through all the images we've found, calculate those
  values, and save them off.

  Args:
    sess: The current active TensorFlow Session.
    image_lists: Dictionary of training images for each label.
    image_dir: Root folder string of the subfolders containing the training
    images.
    bottleneck_dir: Folder string holding cached files of bottleneck values.
    jpeg_data_tensor: Input tensor for jpeg data from file.
    decoded_image_tensor: The output of decoding and resizing the image.
    resized_input_tensor: The input node of the recognition graph.
    bottleneck_tensor: The penultimate output layer of the graph.
    architecture: The name of the model architecture.

  Returns:
    Nothing.
  """
  how_many_bottlenecks = 0
  ensure_dir_exists(bottleneck_dir)
  for label_name, label_lists in image_lists.items():
    for category in ['training', 'testing', 'validation']:
      category_list = label_lists[category]
      for index, unused_base_name in enumerate(category_list):
        get_or_create_bottleneck(
            sess, image_lists, label_name, index, image_dir, category,
            bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
            resized_input_tensor, bottleneck_tensor, architecture)

        how_many_bottlenecks += 1
        if how_many_bottlenecks % 100 == 0:
          tf.logging.info(
              str(how_many_bottlenecks) + ' bottleneck files created.')


def get_random_cached_bottlenecks(sess, image_lists, how_many, category,
                                  bottleneck_dir, image_dir, jpeg_data_tensor,
                                  decoded_image_tensor, resized_input_tensor,
                                  bottleneck_tensor, architecture):
  """Retrieves bottleneck values for cached images.

  If no distortions are being applied, this function can retrieve the cached
  bottleneck values directly from disk for images. It picks a random set of
  images from the specified category.

  Args:
    sess: Current TensorFlow Session.
    image_lists: Dictionary of training images for each label.
    how_many: If positive, a random sample of this size will be chosen.
    If negative, all bottlenecks will be retrieved.
    category: Name string of which set to pull from - training, testing, or
    validation.
    bottleneck_dir: Folder string holding cached files of bottleneck values.
    image_dir: Root folder string of the subfolders containing the training
    images.
    jpeg_data_tensor: The layer to feed jpeg image data into.
    decoded_image_tensor: The output of decoding and resizing the image.
    resized_input_tensor: The input node of the recognition graph.
    bottleneck_tensor: The bottleneck output layer of the CNN graph.
    architecture: The name of the model architecture.

  Returns:
    List of bottleneck arrays, their corresponding ground truths, and the
    relevant filenames.
  """
  class_count = len(image_lists.keys())
  bottlenecks = []
  ground_truths = []
  filenames = []
  if how_many >= 0:
    # Retrieve a random sample of bottlenecks.
    for unused_i in range(how_many):
      label_index = random.randrange(class_count)
      label_name = list(image_lists.keys())[label_index]
      image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1)
      image_name = get_image_path(image_lists, label_name, image_index,
                                  image_dir, category)
      bottleneck = get_or_create_bottleneck(
          sess, image_lists, label_name, image_index, image_dir, category,
          bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
          resized_input_tensor, bottleneck_tensor, architecture)
      ground_truth = np.zeros(class_count, dtype=np.float32)
      ground_truth[label_index] = 1.0
      bottlenecks.append(bottleneck)
      ground_truths.append(ground_truth)
      filenames.append(image_name)
  else:
    # Retrieve all bottlenecks.
    for label_index, label_name in enumerate(image_lists.keys()):
      for image_index, image_name in enumerate(
          image_lists[label_name][category]):
        image_name = get_image_path(image_lists, label_name, image_index,
                                    image_dir, category)
        bottleneck = get_or_create_bottleneck(
            sess, image_lists, label_name, image_index, image_dir, category,
            bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
            resized_input_tensor, bottleneck_tensor, architecture)
        ground_truth = np.zeros(class_count, dtype=np.float32)
        ground_truth[label_index] = 1.0
        bottlenecks.append(bottleneck)
        ground_truths.append(ground_truth)
        filenames.append(image_name)
  return bottlenecks, ground_truths, filenames


def get_random_distorted_bottlenecks(
    sess, image_lists, how_many, category, image_dir, input_jpeg_tensor,
    distorted_image, resized_input_tensor, bottleneck_tensor):
  """Retrieves bottleneck values for training images, after distortions.

  If we're training with distortions like crops, scales, or flips, we have to
  recalculate the full model for every image, and so we can't use cached
  bottleneck values. Instead we find random images for the requested category,
  run them through the distortion graph, and then the full graph to get the
  bottleneck results for each.

  Args:
    sess: Current TensorFlow Session.
    image_lists: Dictionary of training images for each label.
    how_many: The integer number of bottleneck values to return.
    category: Name string of which set of images to fetch - training, testing,
    or validation.
    image_dir: Root folder string of the subfolders containing the training
    images.
    input_jpeg_tensor: The input layer we feed the image data to.
    distorted_image: The output node of the distortion graph.
    resized_input_tensor: The input node of the recognition graph.
    bottleneck_tensor: The bottleneck output layer of the CNN graph.

  Returns:
    List of bottleneck arrays and their corresponding ground truths.
  """
  class_count = len(image_lists.keys())
  bottlenecks = []
  ground_truths = []
  for unused_i in range(how_many):
    label_index = random.randrange(class_count)
    label_name = list(image_lists.keys())[label_index]
    image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1)
    image_path = get_image_path(image_lists, label_name, image_index, image_dir,
                                category)
    if not gfile.Exists(image_path):
      tf.logging.fatal('File does not exist %s', image_path)
    jpeg_data = gfile.FastGFile(image_path, 'rb').read()
    # Note that we materialize the distorted_image_data as a numpy array before
    # running inference on the image. This involves 2 memory copies and
    # might be optimized in other implementations.
    distorted_image_data = sess.run(distorted_image,
                                    {input_jpeg_tensor: jpeg_data})
    bottleneck_values = sess.run(bottleneck_tensor,
                                 {resized_input_tensor: distorted_image_data})
    bottleneck_values = np.squeeze(bottleneck_values)
    ground_truth = np.zeros(class_count, dtype=np.float32)
    ground_truth[label_index] = 1.0
    bottlenecks.append(bottleneck_values)
    ground_truths.append(ground_truth)
  return bottlenecks, ground_truths


def should_distort_images(flip_left_right, random_crop, random_scale,
                          random_brightness):
  """Whether any distortions are enabled, from the input flags.

  Args:
    flip_left_right: Boolean whether to randomly mirror images horizontally.
    random_crop: Integer percentage setting the total margin used around the
    crop box.
    random_scale: Integer percentage of how much to vary the scale by.
    random_brightness: Integer range to randomly multiply the pixel values by.

  Returns:
    Boolean value indicating whether any distortions should be applied.
  """
  return (flip_left_right or (random_crop != 0) or (random_scale != 0) or
          (random_brightness != 0))


def add_input_distortions(flip_left_right, random_crop, random_scale,
                          random_brightness, input_width, input_height,
                          input_depth, input_mean, input_std):
  """Creates the operations to apply the specified distortions.

  During training it can help to improve the results if we run the images
  through simple distortions like crops, scales, and flips. These reflect the
  kind of variations we expect in the real world, and so can help train the
  model to cope with natural data more effectively. Here we take the supplied
  parameters and construct a network of operations to apply them to an image.

  Cropping
  ~~~~~~~~

  Cropping is done by placing a bounding box at a random position in the full
  image. The cropping parameter controls the size of that box relative to the
  input image. If it's zero, then the box is the same size as the input and no
  cropping is performed. If the value is 50%, then the crop box will be half the
  width and height of the input. In a diagram it looks like this:

  <       width         >
  +---------------------+
  |                     |
  |   width - crop%     |
  |    <      >         |
  |    +------+         |
  |    |      |         |
  |    |      |         |
  |    |      |         |
  |    +------+         |
  |                     |
  |                     |
  +---------------------+

  Scaling
  ~~~~~~~

  Scaling is a lot like cropping, except that the bounding box is always
  centered and its size varies randomly within the given range. For example if
  the scale percentage is zero, then the bounding box is the same size as the
  input and no scaling is applied. If it's 50%, then the bounding box will be in
  a random range between half the width and height and full size.

  Args:
    flip_left_right: Boolean whether to randomly mirror images horizontally.
    random_crop: Integer percentage setting the total margin used around the
    crop box.
    random_scale: Integer percentage of how much to vary the scale by.
    random_brightness: Integer range to randomly multiply the pixel values by.
    input_width: Horizontal size of expected input image to model.
    input_height: Vertical size of expected input image to model.
    input_depth: How many channels the expected input image should have.
    input_mean: Pixel value that should be zero in the image for the graph.
    input_std: How much to divide the pixel values by before recognition.

  Returns:
    The jpeg input layer and the distorted result tensor.
  """

  jpeg_data = tf.placeholder(tf.string, name='DistortJPGInput')
  decoded_image = tf.image.decode_jpeg(jpeg_data, channels=input_depth)
  decoded_image_as_float = tf.cast(decoded_image, dtype=tf.float32)
  decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
  margin_scale = 1.0 + (random_crop / 100.0)
  resize_scale = 1.0 + (random_scale / 100.0)
  margin_scale_value = tf.constant(margin_scale)
  resize_scale_value = tf.random_uniform(tensor_shape.scalar(),
                                         minval=1.0,
                                         maxval=resize_scale)
  scale_value = tf.multiply(margin_scale_value, resize_scale_value)
  precrop_width = tf.multiply(scale_value, input_width)
  precrop_height = tf.multiply(scale_value, input_height)
  precrop_shape = tf.stack([precrop_height, precrop_width])
  precrop_shape_as_int = tf.cast(precrop_shape, dtype=tf.int32)
  precropped_image = tf.image.resize_bilinear(decoded_image_4d,
                                              precrop_shape_as_int)
  precropped_image_3d = tf.squeeze(precropped_image, squeeze_dims=[0])
  cropped_image = tf.random_crop(precropped_image_3d,
                                 [input_height, input_width, input_depth])
  if flip_left_right:
    flipped_image = tf.image.random_flip_left_right(cropped_image)
  else:
    flipped_image = cropped_image
  brightness_min = 1.0 - (random_brightness / 100.0)
  brightness_max = 1.0 + (random_brightness / 100.0)
  brightness_value = tf.random_uniform(tensor_shape.scalar(),
                                       minval=brightness_min,
                                       maxval=brightness_max)
  brightened_image = tf.multiply(flipped_image, brightness_value)
  offset_image = tf.subtract(brightened_image, input_mean)
  mul_image = tf.multiply(offset_image, 1.0 / input_std)
  distort_result = tf.expand_dims(mul_image, 0, name='DistortResult')
  return jpeg_data, distort_result


def variable_summaries(var):
  """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
  with tf.name_scope('summaries'):
    mean = tf.reduce_mean(var)
    tf.summary.scalar('mean', mean)
    with tf.name_scope('stddev'):
      stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
    tf.summary.scalar('stddev', stddev)
    tf.summary.scalar('max', tf.reduce_max(var))
    tf.summary.scalar('min', tf.reduce_min(var))
    tf.summary.histogram('histogram', var)


def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
                           bottleneck_tensor_size):
  """Adds a new softmax and fully-connected layer for training.

  We need to retrain the top layer to identify our new classes, so this function
  adds the right operations to the graph, along with some variables to hold the
  weights, and then sets up all the gradients for the backward pass.

  The set up for the softmax and fully-connected layers is based on:
  https://www.tensorflow.org/versions/master/tutorials/mnist/beginners/index.html

  Args:
    class_count: Integer of how many categories of things we're trying to
    recognize.
    final_tensor_name: Name string for the new final node that produces results.
    bottleneck_tensor: The output of the main CNN graph.
    bottleneck_tensor_size: How many entries in the bottleneck vector.

  Returns:
    The tensors for the training and cross entropy results, and tensors for the
    bottleneck input and ground truth input.
  """
  with tf.name_scope('input'):
    bottleneck_input = tf.placeholder_with_default(
        bottleneck_tensor,
        shape=[None, bottleneck_tensor_size],
        name='BottleneckInputPlaceholder')

    ground_truth_input = tf.placeholder(tf.float32,
                                        [None, class_count],
                                        name='GroundTruthInput')

  # Organizing the following ops as `final_training_ops` so they're easier
  # to see in TensorBoard
  layer_name = 'final_training_ops'
  with tf.name_scope(layer_name):
    with tf.name_scope('weights'):
      initial_value = tf.truncated_normal(
          [bottleneck_tensor_size, class_count], stddev=0.001)

      layer_weights = tf.Variable(initial_value, name='final_weights')

      variable_summaries(layer_weights)
    with tf.name_scope('biases'):
      layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
      variable_summaries(layer_biases)
    with tf.name_scope('Wx_plus_b'):
      logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
      tf.summary.histogram('pre_activations', logits)

  final_tensor = tf.nn.softmax(logits, name=final_tensor_name)
  tf.summary.histogram('activations', final_tensor)

  with tf.name_scope('cross_entropy'):
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
        labels=ground_truth_input, logits=logits)
    with tf.name_scope('total'):
      cross_entropy_mean = tf.reduce_mean(cross_entropy)
  tf.summary.scalar('cross_entropy', cross_entropy_mean)

  with tf.name_scope('train'):
    optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)
    train_step = optimizer.minimize(cross_entropy_mean)

  return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input,
          final_tensor)


def add_evaluation_step(result_tensor, ground_truth_tensor):
  """Inserts the operations we need to evaluate the accuracy of our results.

  Args:
    result_tensor: The new final node that produces results.
    ground_truth_tensor: The node we feed ground truth data
    into.

  Returns:
    Tuple of (evaluation step, prediction).
  """
  with tf.name_scope('accuracy'):
    with tf.name_scope('correct_prediction'):
      prediction = tf.argmax(result_tensor, 1)
      correct_prediction = tf.equal(
          prediction, tf.argmax(ground_truth_tensor, 1))
    with tf.name_scope('accuracy'):
      evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  tf.summary.scalar('accuracy', evaluation_step)
  return evaluation_step, prediction


def save_graph_to_file(sess, graph, graph_file_name):
  output_graph_def = graph_util.convert_variables_to_constants(
      sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
  with gfile.FastGFile(graph_file_name, 'wb') as f:
    f.write(output_graph_def.SerializeToString())
  return


def prepare_file_system():
  # Setup the directory we'll write summaries to for TensorBoard
  if tf.gfile.Exists(FLAGS.summaries_dir):
    tf.gfile.DeleteRecursively(FLAGS.summaries_dir)
  tf.gfile.MakeDirs(FLAGS.summaries_dir)
  if FLAGS.intermediate_store_frequency > 0:
    ensure_dir_exists(FLAGS.intermediate_output_graphs_dir)
  return


def create_model_info(architecture):
  """Given the name of a model architecture, returns information about it.

  There are different base image recognition pretrained models that can be
  retrained using transfer learning, and this function translates from the name
  of a model to the attributes that are needed to download and train with it.

  Args:
    architecture: Name of a model architecture.

  Returns:
    Dictionary of information about the model, or None if the name isn't
    recognized

  Raises:
    ValueError: If architecture name is unknown.
  """
  architecture = architecture.lower()
  if architecture == 'inception_v3':
    # pylint: disable=line-too-long
    data_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
    # pylint: enable=line-too-long
    bottleneck_tensor_name = 'pool_3/_reshape:0'
    bottleneck_tensor_size = 2048
    input_width = 299
    input_height = 299
    input_depth = 3
    resized_input_tensor_name = 'Mul:0'
    model_file_name = 'classify_image_graph_def.pb'
    input_mean = 128
    input_std = 128
  elif architecture.startswith('mobilenet_'):
    parts = architecture.split('_')
    if len(parts) != 3 and len(parts) != 4:
      tf.logging.error("Couldn't understand architecture name '%s'",
                       architecture)
      return None
    version_string = parts[1]
    if (version_string != '1.0' and version_string != '0.75' and
        version_string != '0.50' and version_string != '0.25'):
      tf.logging.error(
          """The Mobilenet version should be '1.0', '0.75', '0.50', or '0.25',
  but found '%s' for architecture '%s'""",
          version_string, architecture)
      return None
    size_string = parts[2]
    if (size_string != '224' and size_string != '192' and
        size_string != '160' and size_string != '128'):
      tf.logging.error(
          """The Mobilenet input size should be '224', '192', '160', or '128',
 but found '%s' for architecture '%s'""",
          size_string, architecture)
      return None
    if len(parts) == 3:
      is_quantized = False
    else:
      if parts[3] != 'quantized':
        tf.logging.error(
            "Couldn't understand architecture suffix '%s' for '%s'", parts[3],
            architecture)
        return None
      is_quantized = True
    data_url = 'http://download.tensorflow.org/models/mobilenet_v1_'
    data_url += version_string + '_' + size_string + '_frozen.tgz'
    bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
    bottleneck_tensor_size = 1001
    input_width = int(size_string)
    input_height = int(size_string)
    input_depth = 3
    resized_input_tensor_name = 'input:0'
    if is_quantized:
      model_base_name = 'quantized_graph.pb'
    else:
      model_base_name = 'frozen_graph.pb'
    model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string
    model_file_name = os.path.join(model_dir_name, model_base_name)
    input_mean = 127.5
    input_std = 127.5
  else:
    tf.logging.error("Couldn't understand architecture name '%s'", architecture)
    raise ValueError('Unknown architecture', architecture)

  return {
      'data_url': data_url,
      'bottleneck_tensor_name': bottleneck_tensor_name,
      'bottleneck_tensor_size': bottleneck_tensor_size,
      'input_width': input_width,
      'input_height': input_height,
      'input_depth': input_depth,
      'resized_input_tensor_name': resized_input_tensor_name,
      'model_file_name': model_file_name,
      'input_mean': input_mean,
      'input_std': input_std,
  }


def add_jpeg_decoding(input_width, input_height, input_depth, input_mean,
                      input_std):
  """Adds operations that perform JPEG decoding and resizing to the graph.

  Args:
    input_width: Desired width of the image fed into the recognizer graph.
    input_height: Desired height of the image fed into the recognizer graph.
    input_depth: Desired channels of the image fed into the recognizer graph.
    input_mean: Pixel value that should be zero in the image for the graph.
    input_std: How much to divide the pixel values by before recognition.

  Returns:
    Tensors for the node to feed JPEG data into, and the output of the
      preprocessing steps.
  """
  jpeg_data = tf.placeholder(tf.string, name='DecodeJPGInput')
  decoded_image = tf.image.decode_jpeg(jpeg_data, channels=input_depth)
  decoded_image_as_float = tf.cast(decoded_image, dtype=tf.float32)
  decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
  resize_shape = tf.stack([input_height, input_width])
  resize_shape_as_int = tf.cast(resize_shape, dtype=tf.int32)
  resized_image = tf.image.resize_bilinear(decoded_image_4d,
                                           resize_shape_as_int)
  offset_image = tf.subtract(resized_image, input_mean)
  mul_image = tf.multiply(offset_image, 1.0 / input_std)
  return jpeg_data, mul_image


def main(_):
  # Needed to make sure the logging output is visible.
  # See https://github.com/tensorflow/tensorflow/issues/3047
  tf.logging.set_verbosity(tf.logging.INFO)

  # Prepare necessary directories that can be used during training
  prepare_file_system()

  # Gather information about the model architecture we'll be using.
  model_info = create_model_info(FLAGS.architecture)
  if not model_info:
    tf.logging.error('Did not recognize architecture flag')
    return -1

  # Set up the pre-trained graph.
  maybe_download_and_extract(model_info['data_url'])
  graph, bottleneck_tensor, resized_image_tensor = (
      create_model_graph(model_info))

  # Look at the folder structure, and create lists of all the images.
  image_lists = create_image_lists(FLAGS.image_dir, FLAGS.testing_percentage,
                                   FLAGS.validation_percentage)
  class_count = len(image_lists.keys())
  if class_count == 0:
    tf.logging.error('No valid folders of images found at ' + FLAGS.image_dir)
    return -1
  if class_count == 1:
    tf.logging.error('Only one valid folder of images found at ' +
                     FLAGS.image_dir +
                     ' - multiple classes are needed for classification.')
    return -1

  # See if the command-line flags mean we're applying any distortions.
  do_distort_images = should_distort_images(
      FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
      FLAGS.random_brightness)

  with tf.Session(graph=graph) as sess:
    # Set up the image decoding sub-graph.
    jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding(
        model_info['input_width'], model_info['input_height'],
        model_info['input_depth'], model_info['input_mean'],
        model_info['input_std'])

    if do_distort_images:
      # We will be applying distortions, so setup the operations we'll need.
      (distorted_jpeg_data_tensor,
       distorted_image_tensor) = add_input_distortions(
           FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
           FLAGS.random_brightness, model_info['input_width'],
           model_info['input_height'], model_info['input_depth'],
           model_info['input_mean'], model_info['input_std'])
    else:
      # We'll make sure we've calculated the 'bottleneck' image summaries and
      # cached them on disk.
      cache_bottlenecks(sess, image_lists, FLAGS.image_dir,
                        FLAGS.bottleneck_dir, jpeg_data_tensor,
                        decoded_image_tensor, resized_image_tensor,
                        bottleneck_tensor, FLAGS.architecture)

    # Add the new layer that we'll be training.
    (train_step, cross_entropy, bottleneck_input, ground_truth_input,
     final_tensor) = add_final_training_ops(
         len(image_lists.keys()), FLAGS.final_tensor_name, bottleneck_tensor,
         model_info['bottleneck_tensor_size'])

    # Create the operations we need to evaluate the accuracy of our new layer.
    evaluation_step, prediction = add_evaluation_step(
        final_tensor, ground_truth_input)

    # Merge all the summaries and write them out to the summaries_dir
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                         sess.graph)

    validation_writer = tf.summary.FileWriter(
        FLAGS.summaries_dir + '/validation')

    # Set up all our weights to their initial default values.
    init = tf.global_variables_initializer()
    sess.run(init)

    # Run the training for as many cycles as requested on the command line.
    for i in range(FLAGS.how_many_training_steps):
      # Get a batch of input bottleneck values, either calculated fresh every
      # time with distortions applied, or from the cache stored on disk.
      if do_distort_images:
        (train_bottlenecks,
         train_ground_truth) = get_random_distorted_bottlenecks(
             sess, image_lists, FLAGS.train_batch_size, 'training',
             FLAGS.image_dir, distorted_jpeg_data_tensor,
             distorted_image_tensor, resized_image_tensor, bottleneck_tensor)
      else:
        (train_bottlenecks,
         train_ground_truth, _) = get_random_cached_bottlenecks(
             sess, image_lists, FLAGS.train_batch_size, 'training',
             FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
             decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
             FLAGS.architecture)
      # Feed the bottlenecks and ground truth into the graph, and run a training
      # step. Capture training summaries for TensorBoard with the `merged` op.
      train_summary, _ = sess.run(
          [merged, train_step],
          feed_dict={bottleneck_input: train_bottlenecks,
                     ground_truth_input: train_ground_truth})
      train_writer.add_summary(train_summary, i)

      # Every so often, print out how well the graph is training.
      is_last_step = (i + 1 == FLAGS.how_many_training_steps)
      if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
        train_accuracy, cross_entropy_value = sess.run(
            [evaluation_step, cross_entropy],
            feed_dict={bottleneck_input: train_bottlenecks,
                       ground_truth_input: train_ground_truth})
        tf.logging.info('%s: Step %d: Train accuracy = %.1f%%' %
                        (datetime.now(), i, train_accuracy * 100))
        tf.logging.info('%s: Step %d: Cross entropy = %f' %
                        (datetime.now(), i, cross_entropy_value))
        validation_bottlenecks, validation_ground_truth, _ = (
            get_random_cached_bottlenecks(
                sess, image_lists, FLAGS.validation_batch_size, 'validation',
                FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
                decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
                FLAGS.architecture))
        # Run a validation step and capture training summaries for TensorBoard
        # with the `merged` op.
        validation_summary, validation_accuracy = sess.run(
            [merged, evaluation_step],
            feed_dict={bottleneck_input: validation_bottlenecks,
                       ground_truth_input: validation_ground_truth})
        validation_writer.add_summary(validation_summary, i)
        tf.logging.info('%s: Step %d: Validation accuracy = %.1f%% (N=%d)' %
                        (datetime.now(), i, validation_accuracy * 100,
                         len(validation_bottlenecks)))

      # Store intermediate results
      intermediate_frequency = FLAGS.intermediate_store_frequency

      if (intermediate_frequency > 0 and (i % intermediate_frequency == 0)
          and i > 0):
        intermediate_file_name = (FLAGS.intermediate_output_graphs_dir +
                                  'intermediate_' + str(i) + '.pb')
        tf.logging.info('Save intermediate result to : ' +
                        intermediate_file_name)
        save_graph_to_file(sess, graph, intermediate_file_name)

    # We've completed all our training, so run a final test evaluation on
    # some new images we haven't used before.
    test_bottlenecks, test_ground_truth, test_filenames = (
        get_random_cached_bottlenecks(
            sess, image_lists, FLAGS.test_batch_size, 'testing',
            FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
            decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
            FLAGS.architecture))
    test_accuracy, predictions = sess.run(
        [evaluation_step, prediction],
        feed_dict={bottleneck_input: test_bottlenecks,
                   ground_truth_input: test_ground_truth})
    tf.logging.info('Final test accuracy = %.1f%% (N=%d)' %
                    (test_accuracy * 100, len(test_bottlenecks)))

    if FLAGS.print_misclassified_test_images:
      tf.logging.info('=== MISCLASSIFIED TEST IMAGES ===')
      for i, test_filename in enumerate(test_filenames):
        if predictions[i] != test_ground_truth[i].argmax():
          tf.logging.info('%70s  %s' %
                          (test_filename,
                           list(image_lists.keys())[predictions[i]]))

    # Write out the trained graph and labels with the weights stored as
    # constants.
    save_graph_to_file(sess, graph, FLAGS.output_graph)
    with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
      f.write('\n'.join(image_lists.keys()) + '\n')


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--image_dir',
      type=str,
      default='data/train',
      help='Path to folders of labeled images.'
  )
  parser.add_argument(
      '--output_graph',
      type=str,
      default='tmp/output_graph.pb',
      help='Where to save the trained graph.'
  )
  parser.add_argument(
      '--intermediate_output_graphs_dir',
      type=str,
      default='tmp/intermediate_graph/',
      help='Where to save the intermediate graphs.'
  )
  parser.add_argument(
      '--intermediate_store_frequency',
      type=int,
      default=0,
      help="""\
         How many steps to store intermediate graph. If "0" then will not
         store.\
      """
  )
  parser.add_argument(
      '--output_labels',
      type=str,
      default='tmp/output_labels.txt',
      help='Where to save the trained graph\'s labels.'
  )
  parser.add_argument(
      '--summaries_dir',
      type=str,
      default='tmp/retrain_logs',
      help='Where to save summary logs for TensorBoard.'
  )
  parser.add_argument(
      '--how_many_training_steps',
      type=int,
      default=200,
      help='How many training steps to run before ending.'
  )
  parser.add_argument(
      '--learning_rate',
      type=float,
      default=0.01,
      help='How large a learning rate to use when training.'
  )
  parser.add_argument(
      '--testing_percentage',
      type=int,
      default=10,
      help='What percentage of images to use as a test set.'
  )
  parser.add_argument(
      '--validation_percentage',
      type=int,
      default=10,
      help='What percentage of images to use as a validation set.'
  )
  parser.add_argument(
      '--eval_step_interval',
      type=int,
      default=10,
      help='How often to evaluate the training results.'
  )
  parser.add_argument(
      '--train_batch_size',
      type=int,
      default=100,
      help='How many images to train on at a time.'
  )
  parser.add_argument(
      '--test_batch_size',
      type=int,
      default=-1,
      help="""\
      How many images to test on. This test set is only used once, to evaluate
      the final accuracy of the model after training completes.
      A value of -1 causes the entire test set to be used, which leads to more
      stable results across runs.\
      """
  )
  parser.add_argument(
      '--validation_batch_size',
      type=int,
      default=100,
      help="""\
      How many images to use in an evaluation batch. This validation set is
      used much more often than the test set, and is an early indicator of how
      accurate the model is during training.
      A value of -1 causes the entire validation set to be used, which leads to
      more stable results across training iterations, but may be slower on large
      training sets.\
      """
  )
  parser.add_argument(
      '--print_misclassified_test_images',
      default=False,
      help="""\
      Whether to print out a list of all misclassified test images.\
      """,
      action='store_true'
  )
  parser.add_argument(
      '--model_dir',
      type=str,
      default='tmp/imagenet',
      help="""\
      Path to classify_image_graph_def.pb,
      imagenet_synset_to_human_label_map.txt, and
      imagenet_2012_challenge_label_map_proto.pbtxt.\
      """
  )
  parser.add_argument(
      '--bottleneck_dir',
      type=str,
      default='tmp/bottleneck',
      help='Path to cache bottleneck layer values as files.'
  )
  parser.add_argument(
      '--final_tensor_name',
      type=str,
      default='final_result',
      help="""\
      The name of the output classification layer in the retrained graph.\
      """
  )
  parser.add_argument(
      '--flip_left_right',
      default=False,
      help="""\
      Whether to randomly flip half of the training images horizontally.\
      """,
      action='store_true'
  )
  parser.add_argument(
      '--random_crop',
      type=int,
      default=0,
      help="""\
      A percentage determining how much of a margin to randomly crop off the
      training images.\
      """
  )
  parser.add_argument(
      '--random_scale',
      type=int,
      default=0,
      help="""\
      A percentage determining how much to randomly scale up the size of the
      training images by.\
      """
  )
  parser.add_argument(
      '--random_brightness',
      type=int,
      default=0,
      help="""\
      A percentage determining how much to randomly multiply the training image
      input pixels up or down by.\
      """
  )
  parser.add_argument(
      '--architecture',
      type=str,
      default='inception_v3',
      help="""\
      Which model architecture to use. 'inception_v3' is the most accurate, but
      also the slowest. For faster or smaller models, choose a MobileNet with
      the form 'mobilenet_<parameter size>_<input_size>[_quantized]'.\
      """
  )
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
1281       """,
1282       action='store_true'
1283   )
1284   parser.add_argument(
1285       '--random_crop',
1286       type=int,
1287       default=0,
1288       help="""\
1289       A percentage determining how much of a margin to randomly crop off the
1290       training images.\
1291       """
1292   )
1293   parser.add_argument(
1294       '--random_scale',
1295       type=int,
1296       default=0,
1297       help="""\
1298       A percentage determining how much to randomly scale up the size of the
1299       training images by.\
1300       """
1301   )
1302   parser.add_argument(
1303       '--random_brightness',
1304       type=int,
1305       default=0,
1306       help="""\
1307       A percentage determining how much to randomly multiply the training image
1308       input pixels up or down by.\
1309       """
1310   )
1311   parser.add_argument(
1312       '--architecture',
1313       type=str,
1314       default='inception_v3',
1315       help="""\
1316       Which model architecture to use. 'inception_v3' is the most accurate, but
1317       also the slowest. For faster or smaller models, choose a MobileNet with
1318       the form 'mobilenet_<parameter size>_<input_size>[_quantized]'. For example,
1319       'mobilenet_1.0_224' will pick a model that is 17 MB in size and takes 224
1320       pixel input images, while 'mobilenet_0.25_128_quantized' will choose a much
1321       less accurate, but smaller and faster network that's 920 KB on disk and
1322       takes 128x128 images. See https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html
1323       for more information on Mobilenet.\
1324       """)
1325   FLAGS, unparsed = parser.parse_known_args()
1326   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
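  One note on running it: every path and hyper-parameter above is exposed as an argparse default, so nothing in main() needs to be edited; the flags can simply be overridden on the command line. A typical invocation might look like this (the file name retrain.py and the flag values are placeholders for your own setup, not anything mandated by the script):

python retrain.py \
    --image_dir data/train \
    --how_many_training_steps 500 \
    --output_graph tmp/output_graph.pb \
    --output_labels tmp/output_labels.txt \
    --summaries_dir tmp/retrain_logs

  While training runs, the summaries written to tmp/retrain_logs can be plotted with tensorboard --logdir tmp/retrain_logs.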


  3. Testing

  Straight to the code (adjust the paths to your own setup):

# -*- coding: utf-8 -*-
"""
Created on Fri Oct 13 16:15:16 2017
use_output_graph
Test the retrained (transfer-learned) Inception model produced by retrain.py.
@author: Dexter
"""
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image

model_name = 'tmp/output_graph.pb'
image_dir = 'data/validation'
label_filename = 'tmp/output_labels.txt'

# Read the trained Inception-v3 graph and import it into the default Graph.
def create_graph():
    with tf.gfile.FastGFile(model_name, 'rb') as f:
        # tf.GraphDef() defines an empty graph definition.
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # Import the graph from graph_def into the current default Graph.
        tf.import_graph_def(graph_def, name='')

# Read the label file and return the labels as a list.
def load_labels(label_file_dir):
    if not tf.gfile.Exists(label_file_dir):
        # Check that the path exists before reading it; tf.logging.fatal only
        # logs, so raise explicitly instead of returning an undefined value.
        tf.logging.fatal('File does not exist %s', label_file_dir)
        raise IOError('File does not exist: %s' % label_file_dir)
    # Read all labels, strip the trailing newlines, and return them as a list.
    return [line.strip('\n') for line in tf.gfile.GFile(label_file_dir).readlines()]

# Build the graph, and load the labels once instead of once per prediction.
create_graph()
labels = load_labels(label_filename)

# Create a session. Since we restore from the frozen retrained graph,
# no variable initialization is needed.
with tf.Session() as sess:
    # Output of the retrained model's final layer, 'final_result:0'.
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')

    # Walk the image directory.
    for root, dirs, files in os.walk(image_dir):
        for file in files:
            image_path = os.path.join(root, file)
            # Load the raw image bytes. The 'DecodeJpeg/contents:0' input
            # accepts JPEG data only, so the files must be .jpg/.jpeg.
            image_data = tf.gfile.FastGFile(image_path, 'rb').read()
            # Feed in the JPEG data and get the softmax probabilities, a
            # vector of shape (1, N) where N is the number of classes.
            predictions = sess.run(softmax_tensor,
                                   {'DecodeJpeg/contents:0': image_data})
            # Flatten the result to a 1-D array.
            predictions = np.squeeze(predictions)

            # Print the image path and name, then display the image.
            print(image_path)
            img = Image.open(image_path)
            plt.imshow(img)
            plt.axis('off')
            plt.show()

            # Sort and take the five most probable classes (top-5); this
            # dataset has exactly five classes in total.
            # argsort() returns the indices that sort the array ascending.
            top_5 = predictions.argsort()[-5:][::-1]
            for label_index in top_5:
                label_name = labels[label_index]        # class name
                label_score = predictions[label_index]  # confidence score
                print('%s (score = %.5f)' % (label_name, label_score))
            print()
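  If your validation folder keeps the same class-named subfolder layout as the training data (data/validation/<class_name>/*.jpg), the per-image loop above extends naturally into an overall accuracy count. Below is a minimal sketch, meant to be appended after the script above so it can reuse create_graph(), labels, and image_dir; the folder-name-equals-true-label convention is my assumption about the data layout, not something the script enforces:

# Minimal sketch: overall accuracy over image_dir, assuming each image's true
# label is the name of its immediate subfolder (an assumption mirroring the
# training data layout, not enforced by the test script itself).
correct, total = 0, 0
with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    for root, dirs, files in os.walk(image_dir):
        true_label = os.path.basename(root)
        for file in files:
            image_data = tf.gfile.FastGFile(os.path.join(root, file), 'rb').read()
            predictions = np.squeeze(
                sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data}))
            if labels[predictions.argmax()] == true_label:
                correct += 1
            total += 1
if total > 0:
    print('Overall accuracy: %.1f%% (N=%d)' % (100.0 * correct / total, total))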

  Done.

Reposted from: https://www.cnblogs.com/EstherLjy/p/9861034.html
