在tensorflow/slim下调用pb文件进行图像识别的预测

在tensorflow/slim下调用pb文件进行图像识别的预测

  • 方法1
  • 方法2

该教程提供了在 TensorFlow-Slim（tensorflow/models/research/slim）框架下，使用训练好的图像分类模型进行预测的代码，共提供了 2 种方法：方法1 从训练得到的 checkpoint 恢复网络后进行预测；方法2 直接加载导出并冻结（freeze）后的 pb 文件进行预测。

方法1

新建 test_image_classifier.py 文件（需放在 slim 目录下，以便导入 nets 和 preprocessing 模块）：

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import math
import tensorflow as tf

from nets import nets_factory
from preprocessing import preprocessing_factory

slim = tf.contrib.slim
_flags = tf.app.flags

# --- Where to run and which weights to restore -------------------------------
_flags.DEFINE_string(
  'master', '', 'The address of the TensorFlow master to use.')
_flags.DEFINE_string(
  'checkpoint_path', '/tmp/tfmodel/',
  'The directory where the model was written to or an absolute path to a '
  'checkpoint file.')

# --- Input image -------------------------------------------------------------
_flags.DEFINE_string(
  'test_path', '', 'Test image path.')

# --- Label space -------------------------------------------------------------
_flags.DEFINE_integer(
  'num_classes', 5, 'Number of classes.')
_flags.DEFINE_integer(
  'labels_offset', 0,
  'An offset for the labels in the dataset. This flag is primarily used to '
  'evaluate the VGG and ResNet architectures which do not use a background '
  'class for the ImageNet dataset.')

# --- Network architecture and matching preprocessing -------------------------
_flags.DEFINE_string(
  'model_name', 'inception_v3', 'The name of the architecture to evaluate.')
_flags.DEFINE_string(
  'preprocessing_name', None, 'The name of the preprocessing to use. If left '
  'as `None`, then the model_name flag is used.')
_flags.DEFINE_integer(
  'test_image_size', None, 'Eval image size')

# Flag values, filled in when tf.app.run() parses the command line.
FLAGS = _flags.FLAGS


def main(_):
  """Classify one image with a slim network restored from a checkpoint.

  Builds the network selected by --model_name, applies the matching
  evaluation preprocessing to the JPEG at --test_path, restores the weights
  from --checkpoint_path and prints "<test_path> <class id>".  The printed id
  has --labels_offset added back, so it is in the dataset's label space.

  Raises:
    ValueError: if --test_path is empty or no checkpoint can be found.
  """
  if not FLAGS.test_path:
    raise ValueError('You must supply the image to classify with --test_path')

  tf.logging.set_verbosity(tf.logging.INFO)

  # latest_checkpoint() returns None for a directory without checkpoints;
  # fail here with a clear message instead of inside Saver.restore().
  if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
    checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
  else:
    checkpoint_path = FLAGS.checkpoint_path
  if not checkpoint_path:
    raise ValueError('No checkpoint found in %s' % FLAGS.checkpoint_path)

  with tf.Graph().as_default():
    ####################
    # Select the model #
    ####################
    network_fn = nets_factory.get_network_fn(
      FLAGS.model_name,
      num_classes=(FLAGS.num_classes - FLAGS.labels_offset),
      is_training=False)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
      preprocessing_name,
      is_training=False)

    test_image_size = FLAGS.test_image_size or network_fn.default_image_size

    ########################################
    # Build the input -> prediction graph  #
    ########################################
    # Read the raw bytes inside a context manager so the file is closed.
    with tf.gfile.GFile(FLAGS.test_path, 'rb') as f:
      image_data = f.read()
    image = tf.image.decode_jpeg(image_data, channels=3)
    processed_image = image_preprocessing_fn(
      image, test_image_size, test_image_size)
    # The network expects a batch: [1, height, width, 3].
    processed_images = tf.expand_dims(processed_image, 0)

    logits, _ = network_fn(processed_images)
    predictions = tf.argmax(logits, 1)

    # Restores every variable the network created from the checkpoint.
    saver = tf.train.Saver()
    with tf.Session() as sess:
      saver.restore(sess, checkpoint_path)
      # Only the class ids are needed, so only they are fetched.
      predicted = sess.run(predictions)

  # The network was built with (num_classes - labels_offset) outputs, so its
  # index i corresponds to dataset label i + labels_offset.
  print('{} {}'.format(FLAGS.test_path, predicted[0] + FLAGS.labels_offset))

if __name__ == '__main__':
  # tf.app.run parses the flags defined above, then calls main(argv).
  tf.app.run(main=main)

新建 classify_image.sh 文件，调用 test_image_classifier.py（--checkpoint_path 既可以是具体的 checkpoint 文件，也可以是其所在目录，此时使用目录中最新的 checkpoint）：

#!/usr/bin/env bash
# Classify one flower image using the latest checkpoint found in train_logs/.
python test_image_classifier.py \
  --model_name=inception_resnet_v2 \
  --num_classes=5 \
  --checkpoint_path=train_logs/ \
  --test_path=./data/flower_photos/daisy/5547758_eea9edfd54_n.jpg

运行

bash classify_image.sh

方法2

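新建 demo_img.py 文件（同样放在 slim 目录下，需要用到其中的 datasets 模块，另需安装 OpenCV、Pillow 和 matplotlib；所加载的 pb 文件的输入张量名须为 input，输出张量名须为 InceptionResnetV2/Logits/Predictions）：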

import tensorflow as tf
import numpy as np
import cv2
from datasets import dataset_utils
#from IPython import display
#import pylab
#import PIL
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
import matplotlib.font_manager as fm

# Paths used by the demo, relative to the current working directory.
image_dir = './data/flower_photos/daisy/5547758_eea9edfd54_n.jpg'  # image to classify
dataset_dir = './data/flower_photos'  # directory read for labels.txt
model_dir = './output_model_pb/frozen_graph.pb'  # frozen inference graph



#opencv
class TOD(object):
  """Top-1 image classifier backed by a frozen TF-slim graph (.pb file).

  The frozen graph is loaded once in __init__; detect() runs it on an
  already-preprocessed image and shows the predicted label and score drawn
  onto the original image in an OpenCV window.
  """

  def __init__(self, model_path=model_dir, label_dir=dataset_dir,
               num_classes=5):
    """Load the frozen graph and the id -> label map.

    Args:
      model_path: path of the frozen GraphDef; defaults to the module-level
        model_dir.
      label_dir: directory passed to dataset_utils.read_label_file
        (presumably containing labels.txt); defaults to dataset_dir.
      num_classes: stored as NUM_CLASSES; not otherwise used by this class.
    """
    self.PATH_TO_CKPT = model_path
    self.NUM_CLASSES = num_classes
    self.detection_graph = self._load_model()
    self.label_map = dataset_utils.read_label_file(label_dir)

  def _load_model(self):
    """Parse the serialized GraphDef at self.PATH_TO_CKPT into a new tf.Graph.

    Nodes are imported with name='' so tensors keep their exported names
    (e.g. 'input:0') and can be looked up by them.
    """
    detection_graph = tf.Graph()
    with detection_graph.as_default():
      od_graph_def = tf.GraphDef()
      with tf.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:
        od_graph_def.ParseFromString(fid.read())
      tf.import_graph_def(od_graph_def, name='')
    return detection_graph

  def visualization(self, image, str):
    """Draw text in red at the top-left corner of ``image`` (in place).

    Args:
      image: HxWx3 uint8 array in OpenCV BGR order (as from cv2.imread);
        overwritten with the annotated pixels.
      str: the text to draw.  The name shadows the builtin inside this
        method but is kept so keyword callers keep working.

    Returns:
      The same ``image`` array, now annotated.
    """
    text = str
    # PIL interprets arrays as RGB while OpenCV images are BGR.  Convert
    # before drawing and back afterwards so 'red' is red in cv2.imshow
    # (drawing directly on BGR data renders it blue).
    image_pil = Image.fromarray(
      cv2.cvtColor(np.uint8(image), cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(image_pil)
    font = ImageFont.truetype(
      fm.findfont(fm.FontProperties(family='DejaVu Sans')), 15)
    draw.text((10, 10), text, 'red', font)
    np.copyto(image, cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR))
    return image

  def detect(self, image, resized):
    """Classify ``resized`` and display the top-1 result drawn on ``image``.

    Args:
      image: original BGR image, used only for display; annotated in place.
      resized: HxWx3 model input, already resized and scaled the way the
        frozen graph expects (the demo uses 299x299 in [-1, 1]).
    """
    with self.detection_graph.as_default():
      with tf.Session(graph=self.detection_graph) as sess:
        # The input placeholder expects a batch: [1, H, W, 3].
        image_np_expanded = np.expand_dims(resized, axis=0)
        inp = self.detection_graph.get_tensor_by_name('input:0')
        predictions = self.detection_graph.get_tensor_by_name(
          'InceptionResnetV2/Logits/Predictions:0')
        scores = sess.run(predictions, feed_dict={inp: image_np_expanded})

    # With a batch of one, the flat argmax/max are the top-1 class and score.
    label = str(self.label_map[scores.argmax()])
    caption = label + ":" + str(scores.max())
    image = self.visualization(image, caption)

    cv2.namedWindow("detection", cv2.WINDOW_NORMAL)
    cv2.imshow("detection", image)
    cv2.waitKey(0)
    cv2.destroyWindow("detection")



if __name__ == '__main__':
  image = cv2.imread(image_dir)
  if image is None:
    # cv2.imread returns None (no exception) for missing/unreadable files.
    raise IOError('Could not read image: {}'.format(image_dir))

  # Network input size; note cv2.resize takes (width, height).
  width = 299
  height = 299
  dim = (width, height)

  # The network was trained on RGB images but OpenCV loads BGR, so convert
  # the model input only; the BGR `image` is kept for display with cv2.
  rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  # Map pixel values [0, 255] to [-1, 1] like slim's Inception preprocessing
  # ((x / 255 - 0.5) * 2).  np.float no longer exists in NumPy >= 1.24.
  # NOTE: slim's eval preprocessing also central-crops 87.5% of the image
  # before resizing; this demo resizes the whole image.
  resized = cv2.resize(rgb, dim).astype(np.float32) / 127.5 - 1.0

  detector = TOD()
  detector.detect(image, resized)

其中用到的 labels.txt 文件（放在 dataset_dir，即 ./data/flower_photos 目录下）的格式为，每行一个“类别编号:类别名”：

0:daisy
1:dandelion
2:roses
3:sunflowers
4:tulips

运行:

python demo_img.py

你可能感兴趣的:(代码,tensorflow,图像识别,python,opencv)