tf.train.Saver()
是一个类,提供了变量、模型(也称图Graph)的保存和恢复模型方法。
TensorFlow是通过构造Graph的方式进行深度学习,任何操作(如卷积、池化等)都需要operator,保存和恢复操作也不例外。在tf.train.Saver()
类初始化时,用于保存和恢复的save
和restore
operator会被加入Graph。所以,下列类初始化操作应在搭建Graph时完成。
saver = tf.train.Saver()
TensorFlow的保存和恢复分为两种:
TensorFlow会讲变量保存在二进制checkpoint文件中,这类文件会将变量名称映射到张量值。
下面是保存变量的例子:
tf.train.Saver()
import tensorflow as tf
# Create some variables.
v1 = tf.get_variable("v1_name", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2_name", shape=[5], initializer = tf.zeros_initializer)
inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
inc_v1.op.run()
dec_v2.op.run()
# Save the variables to disk.
save_path = saver.save(sess, “model/model.ckpt”) # 返回一个保存路径的字符串
print(“Model saved in path: %s” % save_path)
‘’‘output
Model saved in path: model/model.ckpt
‘’’
保存路径中的文件为:
- checkpoint:保存当前网络状态的文件
- model.data-00000-of-00001
- model.index
- model.meta:保存Graph结构的文件
可以发现,没有名为 ‘model/model.ckpt’ 的实体文件,其中 ‘model’ 是一个与用户交互的前缀。
关于函数saver.save()
,常用的参数就是前三个:
save(
sess, # 必需参数,Session对象
save_path, # 必需参数,存储路径
global_step=None, # 可以是Tensor, Tensor name, 整型数
latest_filename=None, # 协议缓冲文件名,默认为'checkpoint',不用管
meta_graph_suffix='meta', # 图文件的后缀,默认为'.meta',不用管
write_meta_graph=True, # 是否保存Graph
write_state=True, # 建议选择默认值True
strip_default_attrs=False # 是否跳过具有默认值的节点
从checkpoint文件中提取变量值赋给新定义的变量。
tf.reset_default_graph()
# Create some variables
# !!!variable name必须与保存时的name一致
v1 = tf.get_variable("v1_name", shape=[3])
v2 = tf.get_variable("v2_name", shape=[5])
saver = tf.train.Saver()
with tf.Session() as sess:
# Restore variables from disk
saver.restore(sess, “model/model.ckpt”)
print(“v1: %s” % v1.eval())
print(“v2: %s” % v2.eval())
‘’‘output
INFO:tensorflow:Restoring parameters from model/model.ckpt
v1: [ 1. 1. 1.]
v2: [-1. -1. -1. -1. -1.]
‘’’
variable().eval()
eval(session=None)
In a session computes and returns the value of this variable.
向tf.train.Saver()
的构造函数传递以下任意内容来轻松指定要保存或加载的名称和变量:
tf.reset_default_graph()
v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)
saver = tf.train.Saver({"v2_name": v2})
with tf.Session() as sess:
v1.initializer.run()
saver.restore(sess, "model/model.ckpt")
print("v1: %s" % v1.eval())
print("v2: %s" % v2.eval())
'''output
INFO:tensorflow:Restoring parameters from model/model.ckpt
v1: [ 0. 0. 0.]
v2: [-1. -1. -1. -1. -1.]
'''
variable().initializer
The initializer operation for this variable.
Note
Saver
对象。同一个变量可以列在多个Saver对象中。Saver.restore()
方法运行时才会更改,这些变量不需要初始化。我们可以使用 inspect_checkpoint 库快速检查某个检查点的变量。
prints tensors in a checkpoint file.
If no
tensor_name
is provided, prints the tensor names and shapes in
the checkpoint file.If
tensor_name
is provided, prints the content of the tensor.
Args | Description |
---|---|
file_name | Name of the checkpoint file |
tensor_name | Name of the tensor in the checkpoint file |
all_tensors | Boolean indicating whether to print all tensors |
all_tensor_names | Boolean indicating whether to print all tensor names |
from tensorflow.python.tools import inspect_checkpoint as ickpt
ickpt.print_tensors_in_checkpoint_file("model/model.ckpt", tensor_name="v1_name", all_tensors=False)
'''output
tensor_name: v1_name
[ 1. 1. 1.]
'''
ickpt.print_tensors_in_checkpoint_file("model/model.ckpt", tensor_name="v1_name", all_tensors=True)
'''output
tensor_name: v1_name
[ 1. 1. 1.]
tensor_name: v2_name
[-1. -1. -1. -1. -1.]
'''