如图所示,
.meta
– 保存图结构,即神经网络的网络结构
.data
– 保存数据文件,即网络的权值,偏置,操作等等
.index
– 是一个不可变得字符串表,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据:“数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等等。
checkpoint
– 文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。在inference时,可以通过修改这个文件,指定使用哪个model.
保存模型:
saver = tf.train.Saver()
saver.save(sess, model_path)
其中model_path
是模型保存路径。
在工程中,我们往往需要将模型和权重固化,便于发布和预测。
使用tensorFlow
官方提供的freeze_graph.py
工具来保存相应模型。(代码中把freeze_graph.py
文件放在commom.utils.tf
路径下导入)
freeze_graph.py
先加载模型文件,从checkpoint文件读取权重数据初始化到模型里的权重变量,再将权重变量转换成权重常量,然后再通过指定的输出节点将没用于输出推理的Op节点从图中剥离掉,再重新保存到指定的文件里(用write_graphdef或Saver)。
from tensorflow.core.protobuf import saver_pb2
from common.utils.tf import freeze_graph
# save model graph
tf.train.write_graph(
sess.graph.as_graph_def(),
os.path.join(model_path),
GRAPH_PB_NAME,
as_text=False)
# generate frozen graph
freeze_graph.freeze_graph(
input_graph=os.path.join(model_path, GRAPH_PB_NAME),
input_saver=False,
input_binary=True,
input_checkpoint=os.path.join(model_path, CHECKPOINT_PREFIX),
output_node_names="viterbi_sequence,intent_prediction,intent_probs",
restore_op_name=None,
filename_tensor_name=None,
output_graph=os.path.join(model_path, FROZEN_GRAPH_PB_NAME),
clear_devices=False,
initializer_nodes="",
variable_names_whitelist="",
variable_names_blacklist="",
input_meta_graph=None,
input_saved_model_dir=None,
saved_model_tags=tf.saved_model.tag_constants.SERVING,
checkpoint_version=saver_pb2.SaverDef.V2)
其中model_path
是模型保存路径,GRAPH_PB_NAME
定义了图模型的名字。
freeze_graph主要参数(参考[4]博客中的参数说明):
input_graph
: 模型文件,可以是二进制的pb文件,或文本的meta文件,用input_binary
来指定区分。input_checkpoint
: 检查点数据文件。output_node_names
: 输出节点的名字,有多个时用逗号分开。output_graph
: 保存整合后的输出模型。在我的模型中,要求的输入有四个,分别是inputs_vocab
,inputs_feature_list
,sequence_length
,max_length
。
计算得到的输出有两个viterbi_sequence
和intent_prediction
。
ckpt = tf.train.get_checkpoint_state(arg.model_path + '/')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
with tf.Session() as sess:
saver.restore(sess, ckpt.model_checkpoint_path)
graph = tf.get_default_graph()
# 加载模型中的操作节点
inputs_vocab = graph.get_operation_by_name('inputs_vocab').outputs[0]
feature_data_list = graph.get_operation_by_name('inputs_feature_list').outputs[0]
sequence_length = graph.get_operation_by_name('sequence_length').outputs[0]
max_length = graph.get_operation_by_name('max_length').outputs[0]
# 准备测试数据(略)
# in_data = ...
# fea_data_list = ...
# length = ...
# max_len = ...
# feed 数据
feed_dict = {inputs_vocab.name: in_data,
feature_data_list.name: fea_data_list,
sequence_length.name: length,
max_length.name: max_len}
# 计算
viterbi_sequence = sess.run('viterbi_sequence:0', feed_dict)
intent_prediction = sess.run('intent_prediction:0', feed_dict)
# 读取图文件
with tf.gfile.FastGFile('./model/frozen_graph.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# We load the graph_def in the default graph
with tf.Graph().as_default() as graph:
tf.import_graph_def(
graph_def,
input_map=None,
return_elements=None,
name="",
op_dict=None,
producer_op_list=None
)
with tf.Session() as sess:
# 根据名称返回tensor数据
inputs_vocab = graph.get_tensor_by_name('inputs_vocab:0')
feature_data_list = graph.get_tensor_by_name('inputs_feature_list:0')
sequence_length = graph.get_tensor_by_name('sequence_length:0')
max_length = graph.get_tensor_by_name('max_length:0')
# 准备测试数据(略)
# in_data = ...
# fea_data_list = ...
# length = ...
# max_len = ...
# feed 数据
feed_dict = {inputs_vocab.name: in_data,
feature_data_list.name: fea_data_list,
sequence_length.name: length,
max_length.name: max_len}
# 计算结果
viterbi_sequence = graph.get_tensor_by_name('viterbi_sequence:0')
intent_prediction = graph.get_tensor_by_name('intent_prediction:0')
viterbi_sequence = sess.run(viterbi_sequence, feed_dict)
intent_prediction = sess.run(intent_prediction, feed_dict)
注意,这里如果不使用上下文管理器Graph().as_default()
,在进行预测的时候可能会报"The Session graph is empty. Add operations to the graph before calling run()…"的错误。