TensorFlow中的saved_model模块用于生成冻结图文件,并且saved_model模块封装了平常用的Saver类。与Saver类不同的是,saved_model模块生成的模型文件集成了打标签的操作,可以更方便地部署在生产环境中。
关于为什么要用saved_model模块,这篇文章讲得挺好的。请点击这里
一个saved_model对象可以存储一个或多个MetaGraphDef。那什么时候需要多个MetaGraphDef呢?也许你想同时保存模型的CPU版本和GPU版本,或者你想同时保存模型的开发版本和生产版本。这个时候你就可以用tag(标签)来区分它们了。在加载模型的时候能根据tag标签来加载不同的MetaGraphDef。
TensorFlow中的saved_model模块可以给MetaGraphDef添加多个签名(signature)。每个签名的的结构都由输入节点、输出节点、名字3部分组成。并且,输入节点,输出节点的名字可以任意指定。
假设之前训练了一个模型,让模型在一组混乱的数据中找到y≈2x的规律。其中
(1)用saved_model模块的builder.SavedModelBuilder类实例化一个builder对象。
(2)构建签名的输入节点inputs。该输入节点的名字为“input_x”。该名字是模型文件中输入节点的名字(可以任意取名)。
(3)构建标签的输出节点outputs。该输出节点的名字为“output”。
(4)调用build_signature_def函数,并将输入节点、输出节点和名字(sig_name)传入,生成一个签名对象。
(5)用builder对象的add_meta_graph_and_variables方法将签名加入到模型中。
(6)调用builder对象的save方法导出带有签名的模型文件。
代码如下:
from tensorflow.python.saved_model import tag_constants
#saveddir+"tfservingmodel"为模型的保存路径
builder = tf.saved_model.builder.SavedModelBuilder(savedir+'tfservingmodel')
#定义输入签名,X为输入tensor
inputs = {'input_x': tf.saved_model.utils.build_tensor_info(X)}
#定义输出签名, z为最终需要的输出结果tensor
outputs = {'output' : tf.saved_model.utils.build_tensor_info(z)}
#调用build_signature_def()函数,并将输入节点、输出节点和名字(sig_name)传入,生成具体的签名对象
signature = tf.saved_model.signature_def_utils.build_signature_def(inputs, outputs, 'sig_name')
#将节点的定义和值加到builder中,同时还加入了tag标签(tag_constants.SERVING), 还可以使用TRAINING、GPU或自定义
builder.add_meta_graph_and_variables(sess, [tag_constants.SERVING], {'my_signature':signature})
builder.save()
运行后,会生成如下图所示文件
其中variables文件里的内容如下所示
从第一张图可以看出,tfservingmodel文件夹包含了一个文件和一个文件夹,文件save_model.pb是模型的定义文件,文件夹variables中放置了具体的模型文件。
从第二张图可以看出,variables文件夹包含了两个模型文件,variables.data-00000-of-00001文件保存了模型中参数的值,variables.index文件保存了模型中节点符号的定义。
我们可以看下saved_model.pb文件中保存的张量名字和属性
import tensorflow as tf
from tensorflow import saved_model as sm
model_path = "log/tfservingmodel"
with tf.Session()as sess:
meta_graph_def = tf.saved_model.loader.load(sess,[sm.tag_constants.SERVING],model_path)
op_list = sess.graph.get_operations() #load完后可以直接从sess.graph中获取所有节点
with open("operations.txt",'a+')as f:
for index,op in enumerate(op_list):
f.write(str(op.name)+"\n")
f.write(str(op.values())+"\n")
导入刚刚保存的模型
(1)用saved_model模块中的loader.load方法根据tag标签导入对应的模型文件。
(2)用signature_def方法从导入的模型中提取签名。
(3)以字典取值的方式取出输入、输出节点。
(4)向模型注入数据,并输出结果。
代码如下:
from tensorflow.python.saved_model import tag_constants
with tf.Session() as sess:
#根据tag_constants.SERVING标签找到对应的计算图
meta_graph_def = tf.saved_model.loader.load(sess, [tag_constants.SERVING], savedir+'tfservingmodel')
# 从meta_graph_def中取出SignatureDef对象
signature = meta_graph_def.signature_def
# 从signature中找出具体输入输出的tensor name
x = signature['my_signature'].inputs['input_x'].name
result = signature['my_signature'].outputs['output'].name
y = sess.run(result, feed_dict={x: 5})#传入5,进行预测
print(y)
在命令行中,用saved_model_cli工具查看和使用生成的saved_model模型。具体内容如下 :
(1)找出tag标签对应的MetaGraphDef。
(2)找出MetaGraphDef中的signature、输入、输出节点等相关信息。
(3)以命令行的方式向模型输入数据,使其运行并输出结果。
saved_model_cli工具工具共有两个主要的参数:
saved_model_cli show --dir log/tfservingmodel
运行结果:
我们可以看到输出结果为serve,表明SavedModel对象里面只有一个MetaGraphDef,这个serve对应于tag_constants.SERVING。
saved_model_cli show --dir log/tfservingmodel --tag_set serve
运行结果:
我们可以看到输出结果为SignatureDef Key:“my_signature”,表明serve对应的MetaGraphDef中有一个签名为"my_signature",与上面1中生成带有签名时的一致。
saved_model_cli show --dir log/tfservingmodel --tag_set serve --signature_def my_signature
运行结果:
我们可以看到,模型的输入节点的张量为input_x,输出节点的张量为output。
上面的内容可以用saved_model_cli工具中的“–all”参数查看模型文件中的全部信息。
saved_model_cli show --dir log/tfservingmodel --all
用saved_model_cli 工具的run参数时,需要先指定好模型的路径、tag(标签)及signature(签名),再往模型里面输入数据,并运行。
在输入数据部分,可以用参数来指定不同的输入方式。
以“–input_exprs”为例,具体命令如下:
saved_model_cli run --dir log/tfservingmodel --tag_set serve --signature_def my_signature --input_exprs"input_x=4.2"
参考书籍:《深度学习之TensorFlow工程化项目实战》 李金洪 编著