Deploying TensorFlow models as .pb files

Disclaimer: this article represents my personal views only; if it contains errors, please judge for yourself. If it accidentally includes someone else's original content, please contact me and I will remove it. This post took real effort to write; please credit the source when reposting.

1. Save the model as a .ckpt file

saver = tf.train.Saver()
save_path = saver.save(sess, model_path)  # save the model; model_path is the path the model is saved to
Saving produces four files: checkpoint, SFCN.ckpt.data-00000-of-00001, SFCN.ckpt.index, and SFCN.ckpt.meta.
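To load these weights back later, a minimal sketch, assuming the same graph-building code has been run first and model_path is the same .ckpt prefix used above:

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, model_path)  # restores the weights saved above; no initializer needed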
2. Save the events file, where event is the directory the events are written to

summary_writer = tf.summary.FileWriter(event, graph=sess.graph)

In a terminal, run tensorboard --logdir=<event path> (note the single =).

Then open TensorBoard in a browser to visualize the graph and see which nodes it contains.
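Names such as "Placeholder" and "Placeholder_2" below are what TensorFlow auto-generates when placeholders are created without an explicit name. As a sketch, naming the inputs yourself makes them much easier to find in TensorBoard; the names images and keep_prob and the shape here are illustrative, not from the original model:

images = tf.placeholder(tf.float32, shape=[None, 256, 256, 3], name="images")  # hypothetical shape
keep_prob = tf.placeholder(tf.float32, name="keep_prob")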

3. Convert the ckpt files into a .pb file. Note that the line output_node_names = "Placeholder,Placeholder_2,keep_probabilty,conv2d_transpose_3" must list every node you want to keep, and there must be no spaces between the node names.

import tensorflow as tf
from tensorflow.python.framework import graph_util
def freeze_graph(input_checkpoint, output_graph):
    '''
    :param input_checkpoint: path prefix of the input ckpt files
    :param output_graph: path where the PB model is saved
    :return:
    '''
    # Specify the output node names; they must be nodes that exist in the original model.
    # Use the final output nodes; their names can be looked up in TensorBoard.
    output_node_names = "Placeholder,Placeholder_2,keep_probabilty,conv2d_transpose_3"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph()  # get the default graph
    input_graph_def = graph.as_graph_def()  # a serialized GraphDef representing the current graph

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # restore the graph and load the weights
        output_graph_def = graph_util.convert_variables_to_constants(  # freeze: turn variables into constants
            sess=sess,
            input_graph_def=input_graph_def,  # equivalent to sess.graph_def
            output_node_names=output_node_names.split(","))  # multiple output nodes are separated by commas

        with tf.gfile.GFile(output_graph, "wb") as f:  # save the model
            f.write(output_graph_def.SerializeToString())  # serialize and write out
        print("%d ops in the final graph." % len(output_graph_def.node))  # number of ops in the final graph


input_checkpoint = "./checkpoint/SFCN.ckpt"  # location of the input ckpt files
output_graph = "node.pb"  # file name of the output .pb
freeze_graph(input_checkpoint, output_graph)
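To check that the freeze worked, a small sketch that reloads node.pb and prints every op name, so you can confirm the four nodes listed in output_node_names survived:

graph_def = tf.GraphDef()
with tf.gfile.GFile("node.pb", "rb") as f:
    graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as g:
    tf.import_graph_def(graph_def, name="")
    for op in g.get_operations():
        print(op.name)  # look for Placeholder, Placeholder_2, keep_probabilty, conv2d_transpose_3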

4. Load the .pb file. Note that the four tensors

y_in = sess.graph.get_tensor_by_name("Placeholder:0")
images_in = sess.graph.get_tensor_by_name("Placeholder_2:0")
keep_probability_in = sess.graph.get_tensor_by_name("keep_probabilty:0")
logits_out = sess.graph.get_tensor_by_name("conv2d_transpose_3:0")

must correspond to the nodes from step 3, with ":0" appended to each name: a node written as "Placeholder" in step 3 becomes "Placeholder:0" here.

import time

import tensorflow as tf

# trainData, trainLabel, trainIndex, smallImage and getNext_batch are assumed to be
# defined elsewhere; model_path is the path of the frozen .pb file.
global graph
graph = tf.get_default_graph()
with graph.as_default():
    output_graph_def = tf.GraphDef()
    with open(model_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        tf.import_graph_def(output_graph_def, name="")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  # effectively a no-op: a frozen graph has no variables
        # Tensor names are the node names from step 3 with ":0" appended
        y_in = sess.graph.get_tensor_by_name("Placeholder:0")
        images_in = sess.graph.get_tensor_by_name("Placeholder_2:0")
        keep_probability_in = sess.graph.get_tensor_by_name("keep_probabilty:0")
        logits_out = sess.graph.get_tensor_by_name("conv2d_transpose_3:0")
        batch_size = 1
        testTime = 0
        predictLabel = tf.zeros(trainLabel.shape)  # buffer for the predictions
        predictLabel = sess.run(predictLabel)      # materialize it as a numpy array
        for i in range(0, smallImage):
            realbatch_array, real_labels, real_index = getNext_batch(trainData, trainLabel, trainIndex, i)
            testStart = time.time()
            yy = sess.run(logits_out,
                          feed_dict={images_in: realbatch_array, y_in: real_labels, keep_probability_in: 1.0})
            predictLabel[i, ...] = yy
            testEnd = time.time()
            testTime1 = testEnd - testStart
            testTime += testTime1  # accumulate total inference time
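For deployment it is often cleaner to load the frozen graph once and keep the session open; a minimal sketch of such a wrapper, using the same four tensor names as above (the class name FrozenModel is illustrative, not from the original code):

import tensorflow as tf

class FrozenModel(object):
    """Sketch: load a frozen .pb once and run inference repeatedly."""
    def __init__(self, pb_path):
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(pb_path, "rb") as f:
            graph_def.ParseFromString(f.read())
        self.graph = tf.Graph()
        with self.graph.as_default():
            tf.import_graph_def(graph_def, name="")
        self.sess = tf.Session(graph=self.graph)
        self.y_in = self.graph.get_tensor_by_name("Placeholder:0")
        self.images_in = self.graph.get_tensor_by_name("Placeholder_2:0")
        self.keep_prob = self.graph.get_tensor_by_name("keep_probabilty:0")
        self.logits_out = self.graph.get_tensor_by_name("conv2d_transpose_3:0")

    def predict(self, images, labels):
        # Mirrors the loop above, which feeds the label placeholder even at test time
        return self.sess.run(self.logits_out,
                             feed_dict={self.images_in: images,
                                        self.y_in: labels,
                                        self.keep_prob: 1.0})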

