TensorFlow是通过构造Graph的方式进行深度学习,任何操作(如卷积、池化等)都需要operator,保存和恢复操作也不例外。在tf.train.Saver()类初始化时,用于保存和恢复的save和restore operator会被加入Graph,所以类初始化操作应在搭建Graph时完成。TensorFlow会将变量保存在二进制checkpoint文件中,这类文件会将变量名称映射到张量值。
tf.train.Saver()是一个类,提供了变量、模型(也称图Graph)的保存和恢复模型方法。
Saver.save(
sess, # 必需参数,Session对象
save_path, # 必需参数,存储路径
global_step=None, # 可以是Tensor, Tensor name, 整型数
latest_filename=None, # 协议缓冲文件名,默认为'checkpoint',不用管
meta_graph_suffix='meta', # 图文件的后缀,默认为'.meta',不用管
write_meta_graph=True, # 是否保存Graph
write_state=True, # 建议选择默认值True
strip_default_attrs=False # 是否跳过具有默认值的节点
保存变量的例子:
创建变量
初始化变量
实例化tf.train.Saver()
创建Session并保存
import tensorflow as tf
# Create some variables.
v1 = tf.get_variable("v1_name", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2_name", shape=[5], initializer = tf.zeros_initializer)
inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
inc_v1.op.run()
dec_v2.op.run()
# Save the variables to disk.
save_path = saver.save(sess, "model/model.ckpt") # 返回一个保存路径的字符串
保存路径中的文件为:
checkpoint:保存当前网络状态的文件
model.data-00000-of-00001
model.index
model.meta:保存Graph结构的文件
可以发现,没有名为 ‘model/model.ckpt’ 的实体文件,表示model目录下所有的文件,其中 ‘model’ 是一个与用户交互的前缀。
从checkpoint文件中提取变量值赋给新定义的变量。
tf.reset_default_graph()
# Create some variables
# !!!variable name必须与保存时的name一致
v1 = tf.get_variable("v1_name", shape=[3])
v2 = tf.get_variable("v2_name", shape=[5])
saver = tf.train.Saver()
with tf.Session() as sess:
# Restore variables from disk
saver.restore(sess, "model/model.ckpt")
print("v1: %s" % v1.eval())
print("v2: %s" % v2.eval())
'''output
INFO:tensorflow:Restoring parameters from model/model.ckpt
v1: [ 1. 1. 1.]
v2: [-1. -1. -1. -1. -1.]
'''
向tf.train.Saver()的构造函数传递以下任意内容来轻松指定要保存或加载的名称和变量:
变量列表(要求变量与变量名之间的一一对应)
Python字典,其中,key是要使用的名称,value是要管理的变量(通过键值映射自定义变量与变量名之间的对应关系)
tf.reset_default_graph()
v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)
saver = tf.train.Saver({"v2_name": v2})
with tf.Session() as sess:
v1.initializer.run()
saver.restore(sess, "model/model.ckpt")
print("v1: %s" % v1.eval())
print("v2: %s" % v2.eval())
'''output
INFO:tensorflow:Restoring parameters from model/model.ckpt
v1: [ 0. 0. 0.]
v2: [-1. -1. -1. -1. -1.]
'''
Note:1 如果需要保存和恢复模型变量的不同子集,您可以根据需要创建任意数量的Saver对象。同一个变量可以列在多个Saver对象中。
2 变量的值只有在Saver.restore()方法运行时才会更改,这些变量不需要初始化。
3 如果您仅在会话开始时恢复模型变量的子集,则必须为其他变量运行初始化 op。
使用 inspect_checkpoint 库快速检查某个检查点的变量。
Args Description
file_name Name of the checkpoint file
tensor_name Name of the tensor in the checkpoint file
all_tensors Boolean indicating whether to print all tensors
all_tensor_names Boolean indicating whether to print all tensor names
from tensorflow.python.tools import inspect_checkpoint as ickpt
ickpt.print_tensors_in_checkpoint_file("model/model.ckpt", tensor_name="v1_name", all_tensors=False)
'''output
tensor_name: v1_name
[ 1. 1. 1.]
'''
[TensorFlow保存和恢复变量——tf.train.Saver()]
[tensorflow训练好的模型怎么调用?]
-柚子皮-
引入main.py中写好的model_fn,写一个输入接收器函数serving_input_receiver_fn后,estimator.export_saved_model。
DATADIR = '../../data/example'
PARAMS = './results/params.json'
MODELDIR = './results/model'
def serving_input_receiver_fn():
"""Serving input_fn that builds features from placeholders
Returns
-------
tf.estimator.export.ServingInputReceiver
"""
words = tf.placeholder(dtype=tf.string, shape=[None, None], name='words')
nwords = tf.placeholder(dtype=tf.int32, shape=[None], name='nwords')
receiver_tensors = {'words': words, 'nwords': nwords}
features = {'words': words, 'nwords': nwords}
return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
if __name__ == '__main__':
with open(PARAMS, 'r') as f:
params = json.load(f)
estimator = tf.estimator.Estimator(model_fn, MODELDIR, params=params)
estimator.export_saved_model('saved_model', serving_input_receiver_fn)
在export_dir中找到最新的输出模型,predictor.from_saved_model即可。
"""Reload and serve a saved model"""
import os
from tensorflow.contrib import predictor
params = {
'lang': 'chn',
}
LINE = '输入的句子'
export_dir = 'saved_model'
if __name__ == '__main__':
subdirs = [os.path.join(export_dir, x) for x in os.listdir(export_dir) if
os.path.isdir(os.path.join(export_dir, x)) and 'temp' not in str(x)]
latest = str(sorted(subdirs)[-1])
predict_fn = predictor.from_saved_model(latest)
if params['lang'] == 'chn':
words = [w.encode('utf-8') for w in LINE.strip()]
else:
words = [w.encode() for w in LINE.strip().split()]
nwords = len(words)
predictions = predict_fn({'words': [words], 'nwords': [nwords]})
print(predictions)
from: -柚子皮-
ref: