部署上线tensorflow的pb模型

将tensorflow的ckpt模型转化为pb模型,可以大大提高网络预测速度,是进行部署的第一步。怎么做参考:这里。我看网上资料较少,我写一下怎么读取pb模型进行测试,通常落地会采用c++这种更底层的语言。

具体怎么写需要根据网络的测试代码来写,每个网络输入输出不一样,我在下面贴一个写好的只作为参考。

总体步骤:

1.读入pb文件

def freeze_graph_test2(pb_path, test_path):
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name="")
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())


这部分代码直接复制,tf的常规操作,把pb_path替换成自己的pb文件地址就行
2.定义输入

  keep_probability = sess.graph.get_tensor_by_name(name="keep_probabilty:0")
            image =sess.graph.get_tensor_by_name(name="input_image:0")
            _, valid_records = scene_parsing.read_dataset(FLAGS.data_dir)


去看一下自己网络的测试代码,比如我的代码是这样:


所以用get_tensor_by_name取出这几个节点,注意后面必须加上索引号,如:0
3.定义输出

pred_annotation = sess.graph.get_tensor_by_name("inference/prediction:0")


inference/prediction:0是网络最后一层的名字
4.运行
到这步就跟测试代码一样就行了
 

pred = sess.run(pred_annotation, feed_dict={image: valid_images,keep_probability: 1.0})

所有代码:

import tensorflow as tf
from tensorflow.python.framework import graph_util
import os
import time
from datetime import timedelta
import numpy as np
import TensorflowUtils as utils
import read_MITSceneParsingData as scene_parsing
import datetime
import BatchDatsetReader as dataset
from six.moves import xrange

IMAGE_SIZE = 448

def freeze_graph_test2(pb_path, test_path):
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name="")
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            keep_probability = sess.graph.get_tensor_by_name(name="keep_probabilty:0")
            image =sess.graph.get_tensor_by_name(name="input_image:0")
            _, valid_records = scene_parsing.read_dataset(FLAGS.data_dir)
            pred_annotation = sess.graph.get_tensor_by_name("inference/prediction:0")
            

            image_options = {'resize': True, 'resize_size': IMAGE_SIZE}
            validation_dataset_reader = dataset.BatchDatset(valid_records, image_options)
            sess = tf.Session()
            valid_images, valid_annotations = validation_dataset_reader.get_random_batch(FLAGS.batch_size)


            pred = sess.run(pred_annotation, feed_dict={image: valid_images,keep_probability: 1.0})
            print("len:",valid_annotations[0].shape)
            valid_annotations = np.squeeze(valid_annotations, axis=3)
            for itr in range(FLAGS.batch_size):
                utils.save_image(valid_images[itr].astype(np.uint8), FLAGS.logs_dir, name="inp_" + str(5+itr))
                utils.save_image(valid_annotations[itr].astype(np.uint8), FLAGS.logs_dir, name="gt_" + str(5+itr))
                utils.save_image(pred[itr].astype(np.uint8), FLAGS.logs_dir, name="pred_" + str(5+itr))
                       
                print("Saved image: %d" % itr)
if __name__ == '__main__':
    out_pb_path = "../checkpoints/frozen_model.pb"
    test_dir = "data/cnews/cnews.test.txt"
    freeze_graph_test2(pb_path=out_pb_path,test_path=test_dir)
    

 

你可能感兴趣的:(AR深度学习项目)