第一种方法:(官方不推荐)
(1)引入库
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework import graph_util
(2)一般在session中初始化全局变量之后写这句代码:
# Freeze the session's variables into constants so the graph can be saved
# together with their current values. 'output_node_name' is a placeholder:
# replace it with the real name(s) of the graph's output node(s).
constant_graph = graph_util.convert_variables_to_constants(
    sess,
    sess.graph_def,
    output_node_names=['output_node_name'],
)
其中output_node_name是输出节点的名称(节点/op名,不带":0"这样的张量后缀),这个list可以包含一个或多个输出节点名称。
(3)保存模型:
# Serialize the frozen GraphDef and write it to ./model.pb as binary protobuf.
# The write must be indented inside the `with` block so the file is open
# when it runs (and is closed afterwards).
with tf.gfile.GFile('./model.pb', mode='wb') as f:
    f.write(constant_graph.SerializeToString())
第二种方法:(这是官方推荐的)
直接保存模型:
# Export a SavedModel into ./saved_model, registering `x` and `keep_prob`
# as the signature inputs and `y_conv` as the signature output.
tf.compat.v1.saved_model.simple_save(
    sess,
    "./saved_model",
    inputs={"input": x, "keep_prob": keep_prob},
    outputs={"output": y_conv},
)
第三种方法:
# Save the graph together with the variable values (frozen into constants).
from tensorflow.python.framework import graph_util

var_list = tf.global_variables()
# convert_variables_to_constants expects node (op) names such as "w",
# not tensor names such as "w:0" (which is what `Variable.name` returns),
# so pass each variable's op name.
constant_graph = graph_util.convert_variables_to_constants(
    sess,
    sess.graph_def,
    output_node_names=[v.op.name for v in var_list],
)
# Write the frozen graph as a binary protobuf to ./output/expert-graph.pb.
tf.train.write_graph(constant_graph, './output', 'expert-graph.pb', as_text=False)
具体参数如下:
# Signature reference for tf.train.write_graph (placeholder arguments,
# not meant to be executed as-is). The default is as_text=True; the calls
# in this note pass as_text=False to write a binary .pb file.
tf.train.write_graph(graph_or_graph_def, logdir, name, as_text=True)
# Writes a graph proto to a file.
# graph_or_graph_def: A `Graph` or a `GraphDef` protocol buffer.
# logdir: Directory where to write the graph. This can refer to remote
# filesystems, such as Google Cloud Storage (GCS).
# name: Filename for the graph.
# as_text: If `True`, writes the graph as an ASCII proto.
# Returns:
# The path of the output proto file.
(从内置文档摘来的,相信大家都看得懂^_^)
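如果要加载用tf.train.Saver保存的.meta文件(MetaGraphDef),就用这个(注意:它不能直接加载上面写出的GraphDef .pb文件,.pb文件应先用GraphDef.ParseFromString解析,再用tf.import_graph_def导入):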
# Signature reference for tf.train.import_meta_graph (placeholder arguments,
# not meant to be executed as-is). It loads a `MetaGraphDef` (for example a
# `.meta` file); it does not read the plain GraphDef `.pb` files written above.
tf.train.import_meta_graph(meta_graph_or_file, clear_devices=False, import_scope=None, **kwargs)
# Parameters:
# meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
# the path) containing a `MetaGraphDef`.
# clear_devices: Whether or not to clear the device field for an `Operation`
# or `Tensor` during import.
# import_scope: Optional `string`. Name scope to add. Only used when
# initializing from protocol buffer.
# **kwargs: Optional keyed arguments.
# Returns:
# A saver constructed from `saver_def` in `MetaGraphDef` or None.
# A None value is returned if no variables exist in the `MetaGraphDef`
#
# To load a GraphDef .pb file such as './output/expert-graph.pb' instead:
#     graph_def = tf.GraphDef()
#     with tf.gfile.GFile('./output/expert-graph.pb', 'rb') as f:
#         graph_def.ParseFromString(f.read())
#     tf.import_graph_def(graph_def, name='')
第四种方法:
# Save only the graph structure (GraphDef), without the variable values.
graph_def = tf.get_default_graph().as_graph_def()
# `gfile` was never imported; use tf.gfile.GFile. The write must be inside
# the `with` block so it runs while the file is open.
with tf.gfile.GFile('./output/output_graph.pb', 'wb') as f:
    f.write(graph_def.SerializeToString())
# Alternatively, write the same binary file with tf.train.write_graph:
tf.train.write_graph(graph_def, './output', 'output_graph.pb', as_text=False)