CIFAR-10: downloading the dataset, training, exporting the model, freezing the weights, and prediction

1. Introduction to the CIFAR-10 dataset

CIFAR-10 dataset description (from the official website):

The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.



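That covers the dataset itself. You could download it, read it with pickle and generate TFRecord files yourself. TensorFlow's slim library already does all of this, so this article only covers how slim handles the CIFAR-10 data.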
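
For comparison, reading the Python version of the dataset by hand takes only a few lines. This is a minimal sketch, and it assumes the `cifar-10-python.tar.gz` archive from the official site has been extracted. `load_cifar_batch` and `row_to_hwc` are hypothetical helper names, not slim code.

Each batch file is a pickled dict:

- its `b'data'` entry holds one 3072-value row per image: 1024 red, then 1024 green, then 1024 blue values, each plane stored row-major;
- its `b'labels'` entry holds the class ids 0-9.

Unpickling the real files needs NumPy installed, because `data` is stored as a NumPy array.

```python
import io
import pickle


def row_to_hwc(row, size=32):
    """Convert one 3072-value row into a size x size x 3 nested list.

    The row stores the red plane, then green, then blue, each row-major,
    so pixel (y, x) of channel c is at index c * size * size + y * size + x.
    """
    plane = size * size
    return [[[int(row[c * plane + y * size + x]) for c in range(3)]
             for x in range(size)]
            for y in range(size)]


def load_cifar_batch(fileobj):
    """Read one pickled CIFAR-10 batch (e.g. data_batch_1).

    Returns (images, labels): images as height x width x channel nested
    lists of ints, labels as a list of class ids.
    """
    # The batches were pickled under Python 2, so keys come back as bytes.
    batch = pickle.load(fileobj, encoding='bytes')
    images = [row_to_hwc(row) for row in batch[b'data']]
    labels = list(batch[b'labels'])
    return images, labels
```

With a real file, call it as `load_cifar_batch(open('cifar-10-batches-py/data_batch_1', 'rb'))`.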


2. Downloading and converting the data, training, and evaluation

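Run the script train_cifarnet_on_cifar10.sh under master/research/slim/scripts. It calls three scripts in turn:

- download_and_convert_data.py downloads and converts the data;
- train_image_classifier.py trains CifarNet;
- eval_image_classifier.py evaluates it.

These Python scripts are invoked by relative path, so start it from the slim directory, e.g. ./scripts/train_cifarnet_on_cifar10.sh.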

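Notes

<1> Path variables set at the top of the script

TRAIN_DIR    where the trained model (checkpoints and logs) is saved

DATASET_DIR  where the dataset is saved

<2> Use python or python3 depending on your environment, and check whether you are training on a GPU. Without a GPU, pass --clone_on_cpu=True to train_image_classifier.py in the script.

Once train_cifarnet_on_cifar10.sh finishes:

Data download

The DATASET_DIR directory contains the downloaded and converted data: cifar10_test.tfrecord, cifar10_train.tfrecord, labels.txt

Training

The TRAIN_DIR directory contains the model produced by training: checkpoint, model.ckpt-100000.data-00000-of-00001, model.ckpt-100000.index, model.ckpt-100000.meta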
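
The files above can be checked with a short sketch before moving on. Tools such as freeze_graph take the checkpoint prefix (model.ckpt-100000), not one of the .data/.index/.meta files. The small text file named `checkpoint` records the newest prefix. Inside TensorFlow, `tf.train.latest_checkpoint(train_dir)` reads it for you.

The plain-Python sketch below makes two assumptions:

- the converted dataset consists of exactly the three files listed earlier;
- the state file uses the usual `model_checkpoint_path: "..."` line.

The helper names are hypothetical.

```python
import os
import re

# The converted-data files listed above (assumed to be the complete set).
DATASET_FILES = ('cifar10_train.tfrecord', 'cifar10_test.tfrecord', 'labels.txt')


def missing_dataset_files(dataset_dir):
    """Return the expected converted-data files that are absent from dataset_dir."""
    return [name for name in DATASET_FILES
            if not os.path.isfile(os.path.join(dataset_dir, name))]


def latest_checkpoint_prefix(train_dir):
    """Return the newest checkpoint prefix recorded in train_dir/checkpoint, or None."""
    with open(os.path.join(train_dir, 'checkpoint')) as f:
        for line in f:
            match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"', line)
            if match:
                path = match.group(1)
                # A relative path is relative to the directory holding the file.
                return path if os.path.isabs(path) else os.path.join(train_dir, path)
    return None
```

The returned prefix (e.g. ./cifarnet-model/model.ckpt-100000) is what --input_checkpoint expects when freezing.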

Evaluation (eval_image_classifier.py on the test split)

eval/Recall_5[0.993]
eval/Accuracy[0.8539]
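
In this output, Accuracy is top-1 accuracy. Recall_5 is the fraction of test images whose true label is among the five highest-scoring classes. As far as I can tell, eval_image_classifier.py computes it with slim.metrics.streaming_recall_at_k and k=5. With only 10 classes that is a loose criterion, which is why it is close to 1.

A plain-Python sketch of both definitions (the function names are hypothetical):

```python
def top1_accuracy(scores, labels):
    """Fraction of examples whose highest-scoring class equals the label."""
    hits = sum(1 for s, y in zip(scores, labels)
               if max(range(len(s)), key=s.__getitem__) == y)
    return hits / len(labels)


def recall_at_k(scores, labels, k=5):
    """Fraction of examples whose label is among the k highest-scoring classes."""
    hits = 0
    for s, y in zip(scores, labels):
        top_k = sorted(range(len(s)), key=s.__getitem__, reverse=True)[:k]
        hits += y in top_k
    return hits / len(labels)
```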

3. Exporting the model and freezing the weights

    Exporting the model (inference graph)

Original code (main() in slim's export_inference_graph.py):

def main(_):
  if not FLAGS.output_file:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default() as graph:
    dataset = dataset_factory.get_dataset(FLAGS.dataset_name, 'train',
                                          FLAGS.dataset_dir)
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        is_training=FLAGS.is_training)
    image_size = FLAGS.image_size or network_fn.default_image_size
    placeholder = tf.placeholder(name='input', dtype=tf.float32,
                                 shape=[FLAGS.batch_size, image_size,
                                        image_size, 3])
    network_fn(placeholder)
    graph_def = graph.as_graph_def()
    with gfile.GFile(FLAGS.output_file, 'wb') as f:
      f.write(graph_def.SerializeToString())


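The modified version makes three changes:

- the input is now a single encoded JPEG string instead of a float batch;
- the image is decoded and preprocessed inside the graph;
- the softmax output is named 'output'.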

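# This version also needs "from preprocessing import preprocessing_factory"
# next to the other imports at the top of export_inference_graph.py.
def main(_):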
  if not FLAGS.output_file:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default() as graph:
    dataset = dataset_factory.get_dataset(FLAGS.dataset_name, 'train',
                                          FLAGS.dataset_dir)
    preprocessing_name = FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=False)

    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        is_training=FLAGS.is_training)
    image_size = FLAGS.image_size or network_fn.default_image_size
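    # Compared with the original, the placeholder is changed so that the
    # input accepts a single image. An earlier float version:
    # placeholder = tf.placeholder(name='input', dtype=tf.float32,
    #                              shape=[image_size,
    #                                     image_size, 3])
    # Final version: a scalar string holding one encoded JPEG image.
    placeholder = tf.placeholder(name='input', dtype=tf.string)
    # Decode the bytes into a uint8 tensor of shape [height, width, 3].
    image = tf.image.decode_jpeg(placeholder, channels=3)
    # Apply the model's eval-time preprocessing.
    image = image_preprocessing_fn(image, image_size, image_size)
    # The network expects a batch, so add a leading batch dimension.
    x = tf.expand_dims(image, axis=0)
    # x = tf.expand_dims(placeholder, axis=0)
    logits, end_points = network_fn(x)
    # Name the softmax node 'output' so it can be found when freezing and predicting.
    prediction = tf.nn.softmax(logits, name='output')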
    graph_def = graph.as_graph_def()
    with gfile.GFile(FLAGS.output_file, 'wb') as f:
      f.write(graph_def.SerializeToString())
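
As far as I can tell from slim's preprocessing/cifarnet_preprocessing.py, the eval branch (is_training=False) does three things:

1. casts the image to float;
2. applies tf.image.resize_image_with_crop_or_pad to 32x32;
3. applies tf.image.per_image_standardization.

The crop-or-pad step does not rescale. A large photo is center-cropped to its middle 32x32 pixels, so you may want to resize test images to 32x32 before feeding them. The plain-Python sketch below mirrors the two steps on one channel to make that concrete. The helper names are hypothetical, and this is not slim's code.

```python
import math


def crop_or_pad(grid, target_h, target_w, fill=0.0):
    """Center-crop or zero-pad a 2-D grid (list of rows) to target_h x target_w.

    Mirrors the offsets of tf.image.resize_image_with_crop_or_pad: no scaling,
    only cropping where the image is too large and padding where it is too small.
    """
    h, w = len(grid), len(grid[0])
    # Offset into the source (when cropping) and into the output (when padding).
    crop_y, crop_x = max((h - target_h) // 2, 0), max((w - target_w) // 2, 0)
    pad_y, pad_x = max((target_h - h) // 2, 0), max((target_w - w) // 2, 0)
    out = [[fill] * target_w for _ in range(target_h)]
    for y in range(min(h, target_h)):
        for x in range(min(w, target_w)):
            out[pad_y + y][pad_x + x] = grid[crop_y + y][crop_x + x]
    return out


def per_image_standardization(values):
    """(x - mean) / max(stddev, 1/sqrt(N)) over a flat list of N values,
    the formula documented for tf.image.per_image_standardization."""
    n = len(values)
    mean = sum(values) / n
    stddev = math.sqrt(sum((v - mean) ** 2 for v in values) / n)
    adjusted = max(stddev, 1.0 / math.sqrt(n))
    return [(v - mean) / adjusted for v in values]
```

The 1/sqrt(N) floor keeps a uniform image (stddev 0) from causing a division by zero.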

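The command used to export the graph follows. Since the input is now a single encoded image, --batch_size no longer affects the placeholder shape.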

# Export the graph
python3 export_inference_graph.py \
--model_name=cifarnet \
--batch_size=1 \
--dataset_name=cifar10 \
--output_file=cifarnet_graph_def.pb \
--dataset_dir=./cifar10/
    

    Freezing the weights

The key code is shown below (I have not studied it in detail yet):

freeze_graph.py (in tensorflow/python/tools):

    if input_meta_graph_def:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_meta_graph_def.graph_def,
          output_node_names.split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)
    else:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)
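
convert_variables_to_constants does two things. It reads the current value of every variable from the session, which freeze_graph has restored from --input_checkpoint, and rewrites those variable nodes as constants. It also keeps only the subgraph needed to compute output_node_names, so training-only nodes are dropped.

The toy sketch below models a graph as a plain dict to show both steps. It illustrates the idea only and is not TensorFlow's implementation. All node names and the `freeze` function are made up.

```python
def freeze(graph, checkpoint, output_names):
    """Toy freezing.

    graph maps node name -> {'op': ..., 'inputs': [...]}; checkpoint maps
    variable name -> value. Returns a new graph with only the nodes the
    outputs depend on, every 'Variable' node replaced by a 'Const' node.
    """
    # Walk backwards from the outputs to find every node they depend on.
    needed, stack = set(), list(output_names)
    while stack:
        name = stack.pop()
        if name in needed:
            continue
        needed.add(name)
        stack.extend(graph[name]['inputs'])
    frozen = {}
    for name in needed:
        node = graph[name]
        if node['op'] == 'Variable':
            # Bake the checkpointed value into the graph as a constant.
            frozen[name] = {'op': 'Const', 'inputs': [], 'value': checkpoint[name]}
        else:
            frozen[name] = {'op': node['op'], 'inputs': list(node['inputs'])}
    return frozen
```

This pruning is why output_node_names matters: anything the named outputs do not depend on (for example a training op) is left out of the frozen .pb.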

Command:

# Freeze the model
python3 freeze_graph.py \
--input_graph=cifarnet_graph_def.pb \
--input_binary=true \
--input_checkpoint="./cifarnet-model/model.ckpt-100000" \
--output_graph=freezed_cifarnet.pb \
--output_node_names=output
# --input_binary=true is required because the exported graph is binary; without it freeze_graph fails with an error.
# output_node_names is "output" because of prediction = tf.nn.softmax(logits, name='output')

4. Loading the model and checking the results

The full code, adapted from TensorFlow's Inception classify_image.py example:

"""Simple image classification with a frozen CIFAR-10 graph.

Adapted from the Inception classify_image.py example. This program imports a
frozen GraphDef protocol buffer, feeds the raw bytes of a JPEG image to the
'input:0' placeholder and prints the top predictions from 'output:0' together
with their probabilities, using the class names in labels.txt.

Change the --image_file argument to any JPEG image to classify that image.

The original ImageNet tutorial is at
https://tensorflow.org/tutorials/image_recognition/
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import sys

import numpy as np
import tensorflow as tf

FLAGS = None


class NodeLookup(object):
  """Converts integer node ID's to human readable labels."""

  def __init__(self,
               label_path=None):
    if not label_path:
      raise ValueError('Please specify the label file with --label_file.')
    self.node_lookup = self.load(label_path)

  def load(self, label_path):
    """Loads the human-readable class name for each softmax node.

    Args:
      label_path: path to labels.txt, one 'id:name' pair per line.

    Returns:
      dict from integer node ID to human-readable string.
    """
    if not tf.gfile.Exists(label_path):
      tf.logging.fatal('File does not exist %s', label_path)

    # Each line of labels.txt has the form 'id:name', e.g. '2:bird'.
    lines = tf.gfile.GFile(label_path).readlines()
    id_to_human = {}
    for line in lines:
      if line.find(':') < 0:
        continue
      _id, human = line.rstrip('\n').split(':', 1)
      id_to_human[int(_id)] = human

    return id_to_human

  def id_to_string(self, node_id):
    if node_id not in self.node_lookup:
      return ''
    return self.node_lookup[node_id]


def create_graph(model_file=None):
  """Imports the frozen GraphDef in model_file into the default graph."""
  # Creates graph from saved graph_def.pb.
  if not model_file:
    model_file = FLAGS.model_file
  with open(model_file, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')


def run_inference_on_image(image, model_file=None):
  """Runs inference on an image.

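  Args:
    image: Image file name (JPEG).
    model_file: Path to the frozen .pb file; defaults to --model_file.

  Returns:
    (predictions, top_k, top_names): the softmax scores, the indices of the
    top predictions and their class names.
  """
  if not tf.gfile.Exists(image):
    tf.logging.fatal('File does not exist %s', image)
  with open(image, 'rb') as f:
    image_data = f.read()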

  # Creates graph from saved GraphDef.
  create_graph(model_file)

  with tf.Session() as sess:
    # Tensors in the frozen graph:
    # 'input:0': scalar string placeholder that takes the encoded JPEG bytes.
    # 'output:0': softmax probabilities over the 10 CIFAR-10 classes, shape [1, 10].
    # Run the softmax tensor, feeding image_data to the input placeholder.
    softmax_tensor = sess.graph.get_tensor_by_name('output:0')
    predictions = sess.run(softmax_tensor,
                           {'input:0': image_data})
    predictions = np.squeeze(predictions)

    # Creates node ID --> English string lookup.
    node_lookup = NodeLookup(FLAGS.label_file)

    top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
    top_names = []
    for node_id in top_k:
      human_string = node_lookup.id_to_string(node_id)
      top_names.append(human_string)
      score = predictions[node_id]
      print('id:[%d] name:[%s] (score = %.5f)' % (node_id, human_string, score))
  return predictions, top_k, top_names


def main(_):
  if not FLAGS.image_file:
    raise ValueError('You must supply the image to classify with --image_file')
  run_inference_on_image(FLAGS.image_file, FLAGS.model_file)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  # --model_file: frozen GraphDef written by freeze_graph.py.
  # --label_file: 'id:name' label file written during dataset conversion.
  parser.add_argument(
      '--model_file',
      type=str,
      default='./freezed_cifarnet.pb',
      help="""\
      Path to the .pb file that contains the frozen weights. \
      """
  )
  parser.add_argument(
      '--label_file',
      type=str,
      default='',
      help='Absolute path to label file.'
  )
  parser.add_argument(
      '--image_file',
      type=str,
      default='',
      help='Absolute path to image file.'
  )
  parser.add_argument(
      '--num_top_predictions',
      type=int,
      default=5,
      help='Display this many predictions.'
  )
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
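
The line `top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]` works in three steps:

1. sort the class indices by score in ascending order;
2. keep the last k;
3. reverse them, giving the k best classes, highest first.

A pure-Python equivalent using heapq.nlargest is shown below, only to spell out what that idiom computes. `top_k_indices` is a hypothetical name, and ties may come out in a different order than with NumPy.

```python
import heapq


def top_k_indices(scores, k):
    """Indices of the k largest scores, highest first, like
    np.argsort(scores)[-k:][::-1] (tie order may differ)."""
    return heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)
```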

Run it with:

python3 classify_image.py \
--model_file=./freezed_cifarnet.pb \
--image_file=./timg.jpeg \
--label_file=./cifar10/labels.txt


The output:

id:[2] name:[bird] (score = 0.74933)
id:[3] name:[cat] (score = 0.09537)
id:[4] name:[deer] (score = 0.09519)
id:[0] name:[airplane] (score = 0.02756)
id:[1] name:[automobile] (score = 0.01199)
