This tutorial shows how to run prediction with an image classifier trained under the tensorflow/slim framework. Two methods are provided: restoring the model directly from a training checkpoint, and loading a frozen pb graph.
Create test_image_classifier.py:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import math
import tensorflow as tf
from nets import nets_factory
from preprocessing import preprocessing_factory
slim = tf.contrib.slim
tf.app.flags.DEFINE_string(
    'master', '', 'The address of the TensorFlow master to use.')
tf.app.flags.DEFINE_string(
    'checkpoint_path', '/tmp/tfmodel/',
    'The directory where the model was written to or an absolute path to a '
    'checkpoint file.')
tf.app.flags.DEFINE_string(
    'test_path', '', 'Test image path.')
tf.app.flags.DEFINE_integer(
    'num_classes', 5, 'Number of classes.')
tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')
tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to evaluate.')
tf.app.flags.DEFINE_string(
    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
    'as `None`, then the model_name flag is used.')
tf.app.flags.DEFINE_integer(
    'test_image_size', None, 'Eval image size')
FLAGS = tf.app.flags.FLAGS

def main(_):
  if not FLAGS.test_path:
    raise ValueError('You must supply the test image path with --test_path')
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    tf_global_step = slim.get_or_create_global_step()

    ####################
    # Select the model #
    ####################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(FLAGS.num_classes - FLAGS.labels_offset),
        is_training=False)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=False)
    test_image_size = FLAGS.test_image_size or network_fn.default_image_size

    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path

    with tf.Session() as sess:
      # Decode and preprocess the test image, then add a batch dimension.
      image = open(FLAGS.test_path, 'rb').read()
      image = tf.image.decode_jpeg(image, channels=3)
      processed_image = image_preprocessing_fn(image, test_image_size, test_image_size)
      processed_images = tf.expand_dims(processed_image, 0)
      logits, _ = network_fn(processed_images)
      predictions = tf.argmax(logits, 1)
      # Restore the trained weights and run inference.
      saver = tf.train.Saver()
      saver.restore(sess, checkpoint_path)
      np_image, network_input, predictions = sess.run([image, processed_image, predictions])
      print('{} {}'.format(FLAGS.test_path, predictions[0]))


if __name__ == '__main__':
  tf.app.run()
Create classify_image.sh, which invokes test_image_classifier.py:
python test_image_classifier.py \
--checkpoint_path=train_logs/ \
--test_path=./data/flower_photos/daisy/5547758_eea9edfd54_n.jpg \
--num_classes=5 \
--model_name=inception_resnet_v2
Run:
bash classify_image.sh
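Aside from TensorFlow's logging output, the script prints one line containing the test image path and the predicted class index (see the final print statement above). With the labels.txt mapping shown further below, a correctly classified daisy image should give output along the lines of:
./data/flower_photos/daisy/5547758_eea9edfd54_n.jpg 0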
Create demo_img.py:
import tensorflow as tf
import numpy as np
import cv2
from datasets import dataset_utils
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
import matplotlib.font_manager as fm

image_dir = './data/flower_photos/daisy/5547758_eea9edfd54_n.jpg'
dataset_dir = './data/flower_photos'
model_dir = './output_model_pb/frozen_graph.pb'


# OpenCV-based demo
class TOD(object):
    def __init__(self):
        self.PATH_TO_CKPT = './output_model_pb/frozen_graph.pb'
        self.NUM_CLASSES = 5
        self.detection_graph = self._load_model()
        # labels.txt in dataset_dir maps class indices to class names
        self.label_map = dataset_utils.read_label_file(dataset_dir)

    def _load_model(self):
        # Load the frozen graph definition from the pb file.
        detection_graph = tf.Graph()
        with detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(od_graph_def, name='')
        return detection_graph

    def visualization(self, image, text):
        # Draw the "label:score" string onto the image with PIL.
        image_pil = Image.fromarray(np.uint8(image)).convert('RGB')
        draw = ImageDraw.Draw(image_pil)
        font = ImageFont.truetype(fm.findfont(fm.FontProperties(family='DejaVu Sans')), 15)  # use the DejaVu Sans font
        draw.text((10, 10), text, 'red', font)  # or e.g. 'fuchsia'
        np.copyto(image, np.array(image_pil))
        return image

    def detect(self, image, resized):
        with self.detection_graph.as_default():
            with tf.Session(graph=self.detection_graph) as sess:
                # Expand dimensions since the model expects images to have shape [1, None, None, 3]
                image_np_expanded = np.expand_dims(resized, axis=0)
                inp = self.detection_graph.get_tensor_by_name('input:0')
                # Depending on how the graph was exported, the output tensor may instead be named
                # 'InceptionResnetV2/Predictions/Reshape_1:0'.
                predictions = self.detection_graph.get_tensor_by_name('InceptionResnetV2/Logits/Predictions:0')
                x = predictions.eval(feed_dict={inp: image_np_expanded})
                label = str(self.label_map[x.argmax()])
                score = str(x.max())
                image = self.visualization(image, label + ':' + score)
                # print('Top 1 Prediction: ', x.argmax(), self.label_map[x.argmax()], x.max())
                cv2.namedWindow('detection', cv2.WINDOW_NORMAL)
                cv2.imshow('detection', image)
                cv2.waitKey(0)


if __name__ == '__main__':
    image = cv2.imread(image_dir)
    width = 299
    height = 299
    dim = (width, height)
    # Resize to the network input size and map pixel values to roughly [-1, 1].
    resized = (cv2.resize(image, dim)).astype(np.float32) / 128 - 1
    detector = TOD()
    detector.detect(image, resized)
The labels.txt file used here has the following format:
0:daisy
1:dandelion
2:roses
3:sunflowers
4:tulips
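test_image_classifier.py above only prints the numeric class index. If you also want the class name, the same labels.txt can be loaded with dataset_utils.read_label_file, the helper demo_img.py already uses. A minimal sketch, assuming the slim repo layout and dataset directory used above:
from datasets import dataset_utils

# read_label_file parses "index:name" lines from labels.txt in the dataset directory
labels_to_names = dataset_utils.read_label_file('./data/flower_photos')
pred = 0  # hypothetical index printed by test_image_classifier.py
print(labels_to_names[pred])  # -> 'daisy'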
Run:
python demo_img.py
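Note: the frozen_graph.pb loaded by demo_img.py is not produced by the steps above. One common way to create it, sketched here as an assumption based on the stock slim tooling (your checkpoint step and registered dataset name may differ), is to export the inference graph with slim's export_inference_graph.py and then freeze it with TensorFlow's freeze_graph tool:
python export_inference_graph.py \
  --alsologtostderr \
  --model_name=inception_resnet_v2 \
  --dataset_name=flowers \
  --dataset_dir=./data/flower_photos \
  --output_file=./output_model_pb/inception_resnet_v2_inf_graph.pb

python -m tensorflow.python.tools.freeze_graph \
  --input_graph=./output_model_pb/inception_resnet_v2_inf_graph.pb \
  --input_binary=true \
  --input_checkpoint=train_logs/model.ckpt-10000 \
  --output_node_names=InceptionResnetV2/Logits/Predictions \
  --output_graph=./output_model_pb/frozen_graph.pb
The output node name matches the tensor that demo_img.py looks up, and model.ckpt-10000 is only a placeholder for whatever checkpoint train_logs/ actually contains.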