tensorflow 模型的持久化

简介

持久化的意义在于:
1. 可以保存训练的中间结果, 下次从断点开始继续训练.
2. 将模型的训练/预测/在线服务部署 独立开来.

如表格所示, 根据 model-api 与 save-format 的不同, 还有 C12C12=4 C 2 1 ∗ C 2 1 = 4 种搭配.

model-api save-format
estimator checkpoint
low-level-api saved_model

checkpoints

含多个文件, 有 .meta,.data,.index等多个后缀的文件. 变量与结构分开存储.

with low-level-api

官方文档见参考[5].

  • tf.Saver
    类. 用来存储与恢复网络中的变量.

  • Saver#save(self, sess, save_path, global_step=None, ...)
    Args:
    global_step: 影响到 model.ckpt-global_step.xxx 等文件的命名.
    通过实验发现, 多次调用的话, 它会自动删除旧的数据, 只保留最新的5个版本的文件.

  • Saver#restore(self, sess, save_path)

with estimator

Estimator的子类的构造函数中, 有参数 model_dir, 指定了ckpt文件的存放位置. 首次训练时, 直接创建. 后续训练或预测时, 直接加载已有的信息, 增量训练或预测.

figure estimator保存ckpt时的用法图示

tensorflow 模型的持久化_第1张图片
figure 本地实验, 得到的目录下内容

use it for prediction

对于训练好的ckpt, 我们可以恢复它的结构与权重, 送入新的数据拿相应的预测结果.
详见[7]

import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated 
#using new values of w1 and w2 and saved value of b1. 

SavedModel

同时保存变量及 model 的结构.
官方文档见参考[1] .
signatures:一个model就像一个函数一样, 有输入有 输出,

with low-level-api

个人喜欢用这种. 代码示例见参考[2].

  • tf.saved_model.builder.SavedModelBuilder(export_dir)
    得到 builder 对象, 做后续的构建.export_dir参数对应的目录不能已存在.
  • tensorflow.python.saved_model.builder_impl.SavedModelBuilder#add_meta_graph_and_variables(self, sess, tags, signature_def_map=None,...)
    Args:
    • tags
      传的是一个集合, 如 tags=[tag_constants.TRAINING]. 它就是当前存储的计算图的名字, 后续加载的时候就靠名字来匹配.
    • signature_def_map
      计算图的签名. 拿来做预测时, 这个计算图就像是一个函数, 有输入有输出, 所以它也要有相应的 signature. 这是一个 {str:signature_def}形式的map.

signature_def 相关函数

  • tf.saved_model.signature_def_utils.build_signature_def(inputs=None, outputs=None, method_name=None)
    Args:

    • inputs
      a proto map of string to tensor info, 可用下面的build_tensor_info()函数得到.
    • outputs
      与上面类似.
  • tf.saved_model.utils.build_tensor_info(tensor)
    返回的就是TensorInfo proto.

整个构建过程中有多层map.:

signature_def_map = \
    {
        str: build_signature_def(inputs=
                                    {
                                        str: build_tensor_info(input_tensor)
                                    }
                             ,
                                outputs=
                                    {
                                     str: build_tensor_info(output_tensor)
                                    }
        )
    }

目录结构

递归地查看目录下内容, 是这样的:
.pb 的意思是 protocol buffer 格式.

$ find
.
./saved_model.pb
./variables
./variables/variables.data-00000-of-00001
./variables/variables.index

with Estimators

MetaGraph = MetaGraphDef + signature

saved_model_cli

一个命令行工具, 用来 inspect 或 execute 你的 saved model, 见参考 [4] .
在 python 环境下, 它的位置为 \site-packages\tensorflow\python\tools\saved_model_cli.py .

常用命令

#显示帮助信息和usage
python saved_model_cli.py show -h 

# 查看计算图中的所有 tag-sets
python saved_model_cli.py show --dir D:/tmp/model_save_restore/

# 根据上一步显示的tag, 查看该tag对应计算图中所有的 SignatureDef keys 
python saved_model_cli.py show --dir D:/tmp/model_save_restore/ --tag_set serve

# 查看 tag 对应计算图中指定signature_def key的签名内容.
#This is very useful when you want to know the tensor key value, dtype and shape of the input tensors for executing the computation graph later.
python saved_model_cli.py show --dir D:/tmp/model_save_restore/ --tag_set serve --signature_def serving_default

tensorflow 模型的持久化_第2张图片
figure saved_model_cli 运行截图

参考

  1. 官方guide, Saving and Restoring
  2. 官方代码, mnist_saved_model.py
  3. 官方文档,using_savedmodel_with_estimators
  4. saved_model_cli 工具说明, cli_to_inspect_and_execute_savedmodel
  5. ckpt with low-level-api,saving_and_restoring_variables
  6. checkpoints
  7. save-restore-tensorflow-models-quick-complete-tutorial

你可能感兴趣的:(tensorflow 模型的持久化)