持久化的意义在于:
1. 可以保存训练的中间结果, 下次从断点开始继续训练.
2. 将模型的训练/预测/在线服务部署 独立开来.
如表格所示, 根据 model-api 与 save-format 的不同, 还有 C12∗C12=4 C 2 1 ∗ C 2 1 = 4 种搭配.
model-api | save-format |
---|---|
estimator | checkpoint |
low-level-api | saved_model |
含多个文件, 有 .meta,.data,.index
等多个后缀的文件. 变量与结构分开存储.
官方文档见参考[5].
tf.Saver
类. 用来存储与恢复网络中的变量.
Saver#save(self, sess, save_path, global_step=None, ...)
Args:
global_step
: 影响到 model.ckpt-global_step
.xxx 等文件的命名.
通过实验发现, 多次调用的话, 它会自动删除旧的数据, 只保留最新的5个版本的文件.
Saver#restore(self, sess, save_path)
Estimator
的子类的构造函数中, 有参数 model_dir
, 指定了ckpt文件的存放位置. 首次训练时, 直接创建. 后续训练或预测时, 直接加载已有的信息, 增量训练或预测.
figure estimator保存ckpt时的用法图示
对于训练好的ckpt, 我们可以恢复它的结构与权重, 送入新的数据拿相应的预测结果.
详见[7]
import tensorflow as tf
sess=tf.Session()
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}
#Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated
#using new values of w1 and w2 and saved value of b1.
同时保存变量及 model 的结构.
官方文档见参考[1] .
signatures:一个model就像一个函数一样, 有输入有 输出,
个人喜欢用这种. 代码示例见参考[2].
tf.saved_model.builder.SavedModelBuilder(export_dir)
export_dir
参数对应的目录不能已存在.tensorflow.python.saved_model.builder_impl.SavedModelBuilder#add_meta_graph_and_variables(self, sess, tags, signature_def_map=None,...)
tags=[tag_constants.TRAINING]
. 它就是当前存储的计算图的名字, 后续加载的时候就靠名字来匹配.{str:signature_def}
形式的map.tf.saved_model.signature_def_utils.build_signature_def(inputs=None, outputs=None, method_name=None)
Args:
tensor info
, 可用下面的build_tensor_info()
函数得到.tf.saved_model.utils.build_tensor_info(tensor)
返回的就是TensorInfo proto.
整个构建过程中有多层map.:
signature_def_map = \
{
str: build_signature_def(inputs=
{
str: build_tensor_info(input_tensor)
}
,
outputs=
{
str: build_tensor_info(output_tensor)
}
)
}
递归地查看目录下内容, 是这样的:
.pb 的意思是 protocol buffer 格式.
$ find
.
./saved_model.pb
./variables
./variables/variables.data-00000-of-00001
./variables/variables.index
MetaGraph = MetaGraphDef + signature
一个命令行工具, 用来 inspect 或 execute 你的 saved model, 见参考 [4] .
在 python 环境下, 它的位置为 \site-packages\tensorflow\python\tools\saved_model_cli.py
.
#显示帮助信息和usage
python saved_model_cli.py show -h
# 查看计算图中的所有 tag-sets
python saved_model_cli.py show --dir D:/tmp/model_save_restore/
# 根据上一步显示的tag, 查看该tag对应计算图中所有的 SignatureDef keys
python saved_model_cli.py show --dir D:/tmp/model_save_restore/ --tag_set serve
# 查看 tag 对应计算图中指定signature_def key的签名内容.
#This is very useful when you want to know the tensor key value, dtype and shape of the input tensors for executing the computation graph later.
python saved_model_cli.py show --dir D:/tmp/model_save_restore/ --tag_set serve --signature_def serving_default