tensorflow官方教程:运用模型对类别进行预测

tensorflow官方教程:运用模型对类别进行预测


本文主要包含如下内容:

  • tensorflow官方教程运用模型对类别进行预测
    • python版本
    • C代码

  本教程将会教你如何使用Inception-v3。你将学会如何用Python或者C++把图像分为1000个类别.


python版本


  本段代码为tensorflow的教程代码.在开始运用模型Inception-v3对图像类别进行预测之前, 需要下载tensorflow/model.
  该python代码位于:models/tutorials/image/imagenet/classify_image.py中,执行代码即可进行预测:

cd models/tutorials/image/imagenet
python classify_image.py

# 测试结果如下:
giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca (score = 0.89107)
indri, indris, Indri indri, Indri brevicaudatus (score = 0.00779)
lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens (score = 0.00296)
custard apple (score = 0.00147)
earthstar (score = 0.00117)

  
  其中, classify_image.py的核心代码为加载模型/前向传播预测结果

def create_graph():
  """Creates a graph from saved GraphDef file and returns a saver."""       # 加载模型
  # Creates graph from saved graph_def.pb.
  with tf.gfile.FastGFile(os.path.join(
      FLAGS.model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')


def run_inference_on_image(image):      # 前向传播网络,预测图像类别
  """Runs inference on an image.

  Args:
    image: Image file name.

  Returns:
    Nothing
  """
  if not tf.gfile.Exists(image):
    tf.logging.fatal('File does not exist %s', image)
  image_data = tf.gfile.FastGFile(image, 'rb').read()       # 读入图像数据

  # Creates graph from saved GraphDef.      加载模型
  create_graph()

  with tf.Session() as sess:
    # Some useful tensors:
    # 'softmax:0': A tensor containing the normalized prediction across
    #   1000 labels.
    # 'pool_3:0': A tensor containing the next-to-last layer containing 2048
    #   float description of the image.
    # 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG
    #   encoding of the image.
    # Runs the softmax tensor by feeding the image_data as input to the graph.
    softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')         # 捕获输出
    predictions = sess.run(softmax_tensor,
                           {'DecodeJpeg/contents:0': image_data})       # 前向传播
    predictions = np.squeeze(predictions)

    # Creates node ID --> English string lookup.
    node_lookup = NodeLookup()      # 获得ID对应类别

    top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
    for node_id in top_k:
      human_string = node_lookup.id_to_string(node_id)
      score = predictions[node_id]
      print('%s (score = %.5f)' % (human_string, score))        # 打印结果

# 代码思路:首先读入输入图像,加载测试模型,然后前向传播捕获对应输出,并打印对应结果。

C++代码


  对应的C++代码位于/tensorflow/tensorflow/examples/label_image/main.cc   参考网站
  要完成对图像的预测,首先需要下载模型,将网址复制到网站上下载网络模型, 然后将其解压到指定目录:

    https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz
    tar -zxvf inception_v3_2016_08_28_frozen.pb.tar.gz -C tensorflow/examples/label_image/data

  接下来,运用tensorflow源码进行编译,在终端中编译例子步,生成并执行二进制可执行文件:

    bazel build tensorflow/examples/label_image/...
    bazel-bin/tensorflow/examples/label_image/label_image

  它使用了框架自带的示例图片,输出的结果大致是这样:

    I tensorflow/examples/label_image/main.cc:250] military uniform (653): 0.834306
    I tensorflow/examples/label_image/main.cc:250] mortarboard (668): 0.0218695
    I tensorflow/examples/label_image/main.cc:250] academic gown (401): 0.0103581
    I tensorflow/examples/label_image/main.cc:250] pickelhaube (716): 0.00800814
    I tensorflow/examples/label_image/main.cc:250] bulletproof vest (466): 0.00535085

  这里,我们使用的默认图像是 Admiral Grace Hopper,网络模型正确地识别出她穿着一套军服,分数高达0.8。
  

你可能感兴趣的:(Deep-Learning)