[Tensorflow][原创]tensorflow保存PB模型的几种方法总结

第一种方法:(官方不推荐)

(1)引入库

from tensorflow.examples.tutorials.mnist import input_data

(2)一般在seession初始化全局变量下写这句代码

constant_graph=graph_util.convert_variables_to_constants(sess,

sess.graph_def, ['output_node_name'])

其中output_node_name是输出节点的名称,这个list可以包含输入输入多个节点名称

(3)保存模型:

with tf.gfile.FastGFile('./model.pb', mode='wb') as f:

        f.write(constant_graph.SerializeToString())

第二种方法:(这是官方推荐的)

直接保存模型:

tf.compat.v1.saved_model.simple_save(sess,

            "./saved_model",

            inputs={"input": x, 'keep_prob':keep_prob},

            outputs={"output": y_conv})

第三种方法:

# 保存图表并保存变量参数

from tensorflow.python.framework import graph_util

var_list=tf.global_variables()

constant_graph = graph_util.convert_variables_to_constants(sess,

sess.graph_def,output_node_names=[var_list[i].name for i in range(len(var_list))]) # 保存图表并保存变量参数

tf.train.write_graph(constant_graph, './output', 'expert-graph.pb', as_text=False)

具体参数看这:

tf.train.write_graph(graph_or_graph_def, logdir, name, as_text=True)

# Writes a graph proto to a file.

#      graph_or_graph_def: A `Graph` or a `GraphDef` protocol buffer.

#      logdir: Directory where to write the graph. This can refer to remote

#        filesystems, such as Google Cloud Storage (GCS).

#      name: Filename for the graph.

#      as_text: If `True`, writes the graph as an ASCII proto.


#    Returns:

#      The path of the output proto file.

(从内置文档摘来的,相信大家都看得懂^_^)

如果要加载的话就用这个:

tf.train.import_meta_graph(meta_graph_or_file, clear_devices=False, import_scope=None)

#参数如下

#meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including

#       the path) containing a `MetaGraphDef`.

#    clear_devices: Whether or not to clear the device field for an `Operation`

#        or `Tensor` during import.

#     import_scope: Optional `string`. Name scope to add. Only used when

#        initializing from protocol buffer.

#      **kwargs: Optional keyed arguments.


#    Returns:

#      A saver constructed from `saver_def` in `MetaGraphDef` or None.


#      A None value is returned if no variables exist in the `MetaGraphDef`

第四种方法:

# 只保留图表

graph_def = tf.get_default_graph().as_graph_def()

with gfile.GFile('./output/output_graph.pb', 'wb') as f:

    f.write(graph_def.SerializeToString())

# 或者

tf.train.write_graph(graph_def, './output', 'output_graph.pb', as_text=False)

你可能感兴趣的:([Tensorflow][原创]tensorflow保存PB模型的几种方法总结)