cifar10数据集介绍[官网]
The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
CIFAR-10数据集包含了10个分类60000张32x32彩色图片,每个类型有6000张图片.50000用于训练10000用于测试.
数据集就介绍到这里.可以直接下载数据集,用pickle读取后再生成tfrecord格式的文件;不过tensorflow的slim框架已经帮我们完成了这些步骤,所以本文只介绍在tensorflow的slim框架下如何处理cifar-10数据.
在tensorflow/models仓库master分支的research/slim/scripts目录下执行脚本train_cifarnet_on_cifar10.sh
注意
<1>环境变量
TRAIN_DIR 存储训练生成的模型(checkpoint)
DATASET_DIR 存储数据集
<2>根据环境选择python或python3,并确认是否使用gpu;如果不使用gpu,需要将clone_on_cpu设置为true
当train_cifarnet_on_cifar10.sh执行完毕:
DATASET_DIR对应目录会有下载并转换好的数据:cifar10_test.tfrecord cifar10_train.tfrecord labels.txt
TRAIN_DIR对应目录会有训练生成的模型:checkpoint model.ckpt-100000.data-00000-of-00001 model.ckpt-100000.index model.ckpt-100000.meta
eval/Recall_5[0.993]
eval/Accuracy[0.8539]
原代码
def main(_):
    """Export the inference GraphDef of the selected model to --output_file.

    The network is built on a float32 placeholder named 'input' with shape
    [batch_size, image_size, image_size, 3]; the resulting graph definition
    is serialized and written in binary form.

    Raises:
      ValueError: if --output_file is not supplied.
    """
    if not FLAGS.output_file:
        raise ValueError('You must supply the path to save to with --output_file')
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default() as graph:
        dataset = dataset_factory.get_dataset(
            FLAGS.dataset_name, 'train', FLAGS.dataset_dir)
        num_classes = dataset.num_classes - FLAGS.labels_offset
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=num_classes,
            is_training=FLAGS.is_training)
        # Fall back to the network's own default when no size flag is set.
        image_size = FLAGS.image_size or network_fn.default_image_size
        input_shape = [FLAGS.batch_size, image_size, image_size, 3]
        placeholder = tf.placeholder(
            name='input', dtype=tf.float32, shape=input_shape)
        # Called only to add the network's ops to the graph.
        network_fn(placeholder)
        graph_def = graph.as_graph_def()
        with gfile.GFile(FLAGS.output_file, 'wb') as f:
            f.write(graph_def.SerializeToString())
修改如下:
def main(_):
    """Export an inference GraphDef that takes one encoded JPEG as input.

    Graph built here: a string placeholder named 'input' -> JPEG decode ->
    eval-mode preprocessing looked up under the model name -> added batch
    dimension -> network -> softmax named 'output'. The graph definition is
    serialized and written in binary form to --output_file.

    NOTE(review): preprocessing_factory is not defined in this snippet; it
    presumably needs `from preprocessing import preprocessing_factory` in
    the enclosing script - confirm against your slim checkout.

    Raises:
      ValueError: if --output_file is not supplied.
    """
    if not FLAGS.output_file:
        raise ValueError('You must supply the path to save to with --output_file')
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default() as graph:
        dataset = dataset_factory.get_dataset(
            FLAGS.dataset_name, 'train', FLAGS.dataset_dir)
        # The preprocessing function is selected by the model name and is
        # always built in eval mode (is_training=False).
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            FLAGS.model_name, is_training=False)
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            is_training=FLAGS.is_training)
        image_size = FLAGS.image_size or network_fn.default_image_size
        # Unlike the original float32 batch placeholder, the input is a
        # string placeholder (no shape given) fed with the encoded bytes of
        # a single image.
        encoded_jpeg = tf.placeholder(name='input', dtype=tf.string)
        # Decode the JPEG bytes into a 3-channel image.
        decoded = tf.image.decode_jpeg(encoded_jpeg, channels=3)
        # Apply the model's preprocessing at the network input size.
        preprocessed = image_preprocessing_fn(decoded, image_size, image_size)
        # The network is called on a batch, so add a leading dimension.
        batch = tf.expand_dims(preprocessed, axis=0)
        logits, _ = network_fn(batch)
        # The op name 'output' is what gets passed to --output_node_names
        # when the graph is frozen.
        tf.nn.softmax(logits, name='output')
        graph_def = graph.as_graph_def()
        with gfile.GFile(FLAGS.output_file, 'wb') as f:
            f.write(graph_def.SerializeToString())
导出模型所用的命令如下:
# Export the model's inference graph definition to cifarnet_graph_def.pb.
python3 export_inference_graph.py \
--model_name=cifarnet \
--batch_size=1 \
--dataset_name=cifar10 \
--output_file=cifarnet_graph_def.pb \
--dataset_dir=./cifar10/
关键代码如下:以后研究
freeze_graph.py
if input_meta_graph_def:
output_graph_def = graph_util.convert_variables_to_constants(
sess,
input_meta_graph_def.graph_def,
output_node_names.split(","),
variable_names_whitelist=variable_names_whitelist,
variable_names_blacklist=variable_names_blacklist)
else:
output_graph_def = graph_util.convert_variables_to_constants(
sess,
input_graph_def,
output_node_names.split(","),
variable_names_whitelist=variable_names_whitelist,
variable_names_blacklist=variable_names_blacklist)
命令如下
# Freeze the model: combine the exported graph definition with the trained
# checkpoint into a single file.
# NOTE: --input_binary=true is required because the exported GraphDef was
# written in binary form; without it freeze_graph.py reports an error.
# (The note must not follow the trailing backslash on the flag's own line:
# a "\#" there escapes the "#" instead of the newline and breaks the
# line continuation.)
python3 freeze_graph.py \
  --input_graph=cifarnet_graph_def.pb \
  --input_binary=true \
  --input_checkpoint="./cifarnet-model/model.ckpt-100000" \
  --output_graph=freezed_cifarnet.pb \
  --output_node_names=output
# The output node is named "output" because the export script ends with:
#   prediction = tf.nn.softmax(logits,name='output')
全部代码如下
"""Simple image classification with Inception.
This copy is adapted from the Inception / ImageNet 2012 example so that it
runs any frozen graph exposing an 'input:0' tensor (fed the encoded JPEG
bytes) and an 'output:0' tensor (the class scores).
This program creates a graph from a saved GraphDef protocol buffer,
and runs inference on an input JPEG image. It outputs human readable
strings of the top 5 predictions along with their probabilities.
Change the --image_file argument to any jpg image to compute a
classification of that image.
Please see the tutorial and website for a detailed description of how
the original Inception script performs image recognition.
https://tensorflow.org/tutorials/image_recognition/
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os.path
import re
import sys
import tarfile
import numpy as np
from six.moves import urllib
import tensorflow as tf
# Module-level flags handle. It is rebound to the argparse namespace in the
# __main__ block, so the functions below read the parsed command-line values.
FLAGS = tf.app.flags.FLAGS
# pylint: disable=line-too-long
# NOTE(review): DATA_URL is not referenced by the code shown in this script;
# presumably a leftover from the Inception example - confirm before removing.
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
# pylint: enable=line-too-long
class NodeLookup(object):
    """Converts integer node ID's to human readable labels."""

    def __init__(self,
                 label_path=None):
        """Builds the lookup table from a label file.

        Args:
          label_path: path to a text file holding one `<integer id>:<name>`
            entry per line. When empty, an error is logged and the lookup
            stays empty.
        """
        # Define the attribute up front so id_to_string() still works
        # (returning '') when no label file was given.
        self.node_lookup = {}
        if not label_path:
            tf.logging.fatal('please specify the label file.')
            return
        self.node_lookup = self.load(label_path)

    def load(self, label_path):
        """Loads a human readable English name for each softmax node.

        Args:
          label_path: path to the label file (`<integer id>:<name>` lines).

        Returns:
          dict from integer node ID to human-readable string.
        """
        if not tf.gfile.Exists(label_path):
            # Log the path that was actually checked; the read below is
            # still attempted.
            tf.logging.fatal('File does not exist %s', label_path)
        # Read all lines and release the file handle right away.
        with tf.gfile.GFile(label_path) as label_file:
            lines = label_file.readlines()
        return self._parse_label_lines(lines)

    @staticmethod
    def _parse_label_lines(lines):
        """Parses `<integer id>:<name>` lines into an {id: name} dict.

        Lines without a ':' are skipped. Only the first ':' separates the
        id from the name, so names may themselves contain ':'.
        """
        id_to_human = {}
        for line in lines:
            _id, separator, human = line.rstrip('\r\n').partition(':')
            if not separator:
                continue
            id_to_human[int(_id)] = human
        return id_to_human

    def id_to_string(self, node_id):
        """Returns the label for `node_id`, or '' when the id is unknown."""
        return self.node_lookup.get(node_id, '')
def create_graph(model_file=None):
    """Imports a serialized GraphDef file into the current graph.

    Args:
      model_file: path to the binary GraphDef (.pb) file. When empty or
        None, FLAGS.model_file is used instead.

    Returns:
      Nothing; the graph is imported with an empty name prefix.
    """
    pb_path = model_file or FLAGS.model_file
    with open(pb_path, 'rb') as pb_file:
        graph_def = tf.GraphDef()
        serialized = pb_file.read()
        graph_def.ParseFromString(serialized)
        # name='' keeps the node names exactly as stored in the file.
        tf.import_graph_def(graph_def, name='')
def run_inference_on_image(image, model_file=None):
    """Runs inference on an image and prints the top predictions.

    Args:
      image: Image file name. The raw bytes of the file are fed to the
        graph's 'input:0' tensor.
      model_file: optional GraphDef path, forwarded to create_graph().

    Returns:
      A tuple (predictions, top_k, top_names): the squeezed value of the
      'output:0' tensor, the ids of the FLAGS.num_top_predictions highest
      scores (best first), and the matching label strings.
    """
    if not tf.gfile.Exists(image):
        tf.logging.fatal('File does not exist %s', image)
    # Close the file deterministically instead of leaking the handle.
    with open(image, 'rb') as image_file:
        image_data = image_file.read()

    # Creates graph from saved GraphDef.
    create_graph(model_file)

    with tf.Session() as sess:
        # 'input:0': a tensor fed with the encoded bytes of the image file.
        # 'output:0': the tensor whose value is returned as the predictions.
        softmax_tensor = sess.graph.get_tensor_by_name('output:0')
        predictions = sess.run(softmax_tensor,
                               {'input:0': image_data})
        predictions = np.squeeze(predictions)

    # Creates node ID --> English string lookup.
    node_lookup = NodeLookup(FLAGS.label_file)

    # Sort best-first, then truncate. Slicing with [-k:] would select every
    # class when k == 0, so the count is applied after reversing.
    num_top = max(FLAGS.num_top_predictions, 0)
    top_k = predictions.argsort()[::-1][:num_top]
    top_names = []
    for node_id in top_k:
        human_string = node_lookup.id_to_string(node_id)
        top_names.append(human_string)
        score = predictions[node_id]
        print('id:[%d] name:[%s] (score = %.5f)' % (node_id, human_string, score))
    return predictions, top_k, top_names
def main(_):
    """Classifies the image named by --image_file.

    Raises:
      ValueError: if --image_file is empty and no model_dir flag is
        available to build the fallback image path from.
    """
    image = FLAGS.image_file
    if not image:
        # The argument parser in this script defines no --model_dir flag, so
        # reading FLAGS.model_dir directly fails with AttributeError. Keep
        # the fallback only when such a flag actually exists.
        model_dir = getattr(FLAGS, 'model_dir', None)
        if not model_dir:
            raise ValueError(
                'You must supply the image to classify with --image_file')
        image = os.path.join(model_dir, 'cropped_panda.jpg')
    run_inference_on_image(image)
if __name__ == '__main__':
    # Command-line interface: the frozen graph, the label file, the image
    # to classify, and how many predictions to print.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_file', type=str, default='/tmp/imagenet',
        help="""\
Path to the .pb file that contains the frozen weights. \
""")
    parser.add_argument(
        '--label_file', type=str, default='',
        help='Absolute path to label file.')
    parser.add_argument(
        '--image_file', type=str, default='',
        help='Absolute path to image file.')
    parser.add_argument(
        '--num_top_predictions', type=int, default=5,
        help='Display this many predictions.')
    # Arguments this parser does not recognize are passed on to tf.app.run.
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
执行命令
# Classify one image with the frozen graph, using the CIFAR-10 label file
# to turn class ids into names.
python3 classify_image.py \
--model_file=./freezed_cifarnet.pb \
--image_file=./timg.jpeg \
--label_file=./cifar10/labels.txt
结果如下
id:[2] name:[bird] (score = 0.74933)
id:[3] name:[cat] (score = 0.09537)
id:[4] name:[deer] (score = 0.09519)
id:[0] name:[airplane] (score = 0.02756)
id:[1] name:[automobile] (score = 0.01199)