建立与读取.pb文件

#coding=utf-8
import tensorflow as tf
from tensorflow.python.framework import graph_util

# Build a tiny graph: output = x * v1, with x fed at run time.
x = tf.placeholder(shape=[1], dtype=tf.float32, name='x')

# A single variable 'v1' initialised from a normal distribution with mean 1.
variable_1 = tf.get_variable('v1', [1], tf.float32,
                             initializer=tf.random_normal_initializer(mean=1))

output = tf.multiply(x, variable_1, name='mul')

initial_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(initial_op)
    # Take the GraphDef (protobuf description) of the current default graph.
    graph_def = tf.get_default_graph().as_graph_def()
    # Freeze the graph: variables reachable from the output node are replaced
    # by constants holding their current session values. The output node name
    # is taken from the op itself instead of repeating the 'mul' literal, so
    # the two can never drift apart (e.g. if TF uniquifies it to 'mul_1').
    out_graph = graph_util.convert_variables_to_constants(
        sess, graph_def, [output.op.name])
    print(sess.run(output, {x: [5]}))
    print(sess.run(variable_1))
    # Serialize the frozen GraphDef to bytes and write it to the .pb file.
    with tf.gfile.GFile('./model.pb', 'wb') as f:
        f.write(out_graph.SerializeToString())
结果（原文此处的运行结果截图已缺失；程序依次打印 x=[5] 时 x×v1 的值，以及 v1 的随机初始值）：


读取.pb文件:

#coding=utf-8
import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    model_filename = 'model.pb'
    # Read the serialized GraphDef from the .pb file. GFile replaces the
    # deprecated FastGFile.
    with gfile.GFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Load the GraphDef into the default graph. import_graph_def returns a
    # *list* with one entry per name in return_elements, so unpack the single
    # tensor requested; calling .name on the list itself would raise
    # AttributeError.
    v1, = tf.import_graph_def(graph_def, return_elements=['v1:0'])
    print(tf.get_default_graph().as_graph_def())
    # Imported nodes carry the 'import/' name prefix, as the lookup below shows.
    print(tf.get_default_graph().get_tensor_by_name('import/x:0'))
    print(v1.name)



你可能感兴趣的:(tensorflow)