TensorFlow 2.0 - tf.saved_model.save 模型导出
使用 SavedModel 格式
TensorFlow 模型导出
为了将训练好的机器学习模型部署到各个目标平台(如服务器、移动端、嵌入式设备和浏览器等),我们的第一步往往是将训练好的整个模型完整导出(序列化)为一系列标准格式的文件。在此基础上,我们才可以在不同的平台上使用相对应的部署工具来部署模型文件。
TensorFlow 提供了统一模型导出格式 SavedModel,使得我们训练好的模型可以以这一格式为中介,在多种不同平台上部署,这是我们在 TensorFlow 2 中主要使用的导出格式。
SavedModel格式,包含了一个 TensorFlow 程序的完整信息:不仅包含参数的权值,还包含计算的流程(即计算图)。**当模型导出为 SavedModel 文件时,无须模型的源代码即可再次运行模型,这使得 SavedModel 尤其适用于模型的分享和部署。**TensorFlow Serving(服务器端部署模型)、TensorFlow Lite(移动端部署模型)以及 TensorFlow.js 都会用到这一格式。
磁盘上的 SavedModel 格式
SavedModel 是一个包含序列化签名和运行这些签名所需的状态的目录,其中包括变量值和词汇表。
.
├── efficientdet-d0_frozen.pb
├── saved_model.pb
└── variables
├── variables.data-00000-of-00001
└── variables.index
saved_model.pb
文件用于存储实际 TensorFlow 程序或模型,以及一组已命名的签名——每个签名标识一个接受张量输入和产生张量输出的函数。variables
目录包含一个标准训练检查点。assets
目录包含 TensorFlow 计算图使用的文件,例如,用于初始化词汇表的文本文件。本例中没有使用这种文件。tf.saved_model.save
模型导出,即模型保存
# 导出模型, 模型目录
tf.saved_model.save(mymodel, "保存的目标文件夹名称")
tf.saved_model.load
模型导入,即模型加载
# 载入模型
mymodel = tf.saved_model.load("保存的目标文件夹名称")
class CustomModule(tf.Module):
def __init__(self):
super(CustomModule, self).__init__()
self.v = tf.Variable(1.)
@tf.function
def __call__(self, x):
print('Tracing with', x)
return x * self.v
@tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
def mutate(self, new_v):
self.v.assign(new_v)
module = CustomModule()
当保存 tf.Module
时,任何 tf.Variable
特性、tf.function
装饰的方法以及通过递归遍历找到的 tf.Module
都会得到保存。所有 Python 特性、函数和数据都会丢失。也就是说,当保存 tf.function
时,不会保存 Python 代码。
简单地说,tf.function
的工作原理是,通过跟踪 Python 代码来生成 ConcreteFunction(一个可调用的 tf.Graph
包装器)。当保存 tf.function
时,实际上保存的是 tf.function
的 ConcreteFunction 缓存。
module_no_signatures_path = os.path.join(tmpdir, 'module_no_signatures')
module(tf.constant(0.))
print('Saving model...')
tf.saved_model.save(module, module_no_signatures_path)
# 输出
Tracing with Tensor("x:0", shape=(), dtype=float32)
Saving model...
Tracing with Tensor("x:0", shape=(), dtype=float32)
INFO:tensorflow:Assets written to: /tmp/tmp0ya0f8kf/module_no_signatures/assets
在 Python 中加载 SavedModel 时,所有 tf.Variable
特性、tf.function
装饰方法和 tf.Module
都会按照与原始保存的 tf.Module
相同对象结构进行恢复。
imported = tf.saved_model.load(module_no_signatures_path)
assert imported(tf.constant(3.)).numpy() == 3
imported.mutate(tf.constant(2.))
assert imported(tf.constant(3.)).numpy() == 6
由于没有保存 Python 代码,所以使用新输入签名调用 tf.function
会失败:
# 输出
imported(tf.constant([3.]))
ValueError: Could not find matching function to call for canonicalized inputs ((,), {}). Only existing signatures are [((TensorSpec(shape=(), dtype=tf.float32, name=u'x'),), {})].
可以使用 tf.saved_model.load
将 SavedModel 加载回 Python。
loaded = tf.saved_model.load(mobilenet_save_path)
print(list(loaded.signatures.keys())) # ["serving_default"]
# 输出
["serving_default"]
导入的签名总是会返回字典。
infer = loaded.signatures["serving_default"]
print(infer.structured_outputs)
# 输出
{'predictions': TensorSpec(shape=(None, 1000), dtype=tf.float32, name='predictions')}
tf.keras.Model
会自动指定服务上线签名,但是,对于自定义模块,我们必须明确声明服务上线签名。
当指定单个签名时,签名键为 'serving_default'
,并将保存为常量 tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
。
默认情况下,自定义 tf.Module
中不会声明签名。
assert len(imported.signatures) == 0
如果导出时指定签名,需要使用 signatures
关键字参数指定 ConcreteFunction。
module_with_signature_path = os.path.join(tmpdir, 'module_with_signature')
call = module.__call__.get_concrete_function(tf.TensorSpec(None, tf.float32))
tf.saved_model.save(module, module_with_signature_path, signatures=call)
# 输出
Tracing with Tensor("x:0", dtype=float32)
Tracing with Tensor("x:0", dtype=float32)
INFO:tensorflow:Assets written to: /tmp/tmp0ya0f8kf/module_with_signature/assets
imported_with_signatures = tf.saved_model.load(module_with_signature_path)
list(imported_with_signatures.signatures.keys())
# 输出
['serving_default']
要导出多个签名,请将签名键的字典传递给 ConcreteFunction。每个签名键对应一个 ConcreteFunction。
module_multiple_signatures_path = os.path.join(tmpdir, 'module_with_multiple_signatures')
signatures = {"serving_default": call,
"array_input": module.__call__.get_concrete_function(tf.TensorSpec([None], tf.float32))}
tf.saved_model.save(module, module_multiple_signatures_path, signatures=signatures)
# 输出
Tracing with Tensor("x:0", shape=(None,), dtype=float32)
Tracing with Tensor("x:0", shape=(None,), dtype=float32)
INFO:tensorflow:Assets written to: /tmp/tmp0ya0f8kf/module_with_multiple_signatures/assets
imported_with_multiple_signatures = tf.saved_model.load(module_multiple_signatures_path)
list(imported_with_multiple_signatures.signatures.keys())
# 输出
['serving_default', 'array_input']
默认情况下,输出张量名称非常通用,如 output_0
。为了控制输出的名称,请修改 tf.function
,以便返回将输出名称映射到输出的字典。输入的名称来自 Python 函数参数名称。
class CustomModuleWithOutputName(tf.Module):
def __init__(self):
super(CustomModuleWithOutputName, self).__init__()
self.v = tf.Variable(1.)
@tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
def __call__(self, x):
return {'custom_output_name': x * self.v}
module_output = CustomModuleWithOutputName()
call_output = module_output.__call__.get_concrete_function(tf.TensorSpec(None, tf.float32))
module_output_path = os.path.join(tmpdir, 'module_with_output_name')
tf.saved_model.save(module_output, module_output_path,
signatures={'serving_default': call_output})
# 输出
INFO:tensorflow:Assets written to: /tmp/tmp0ya0f8kf/module_with_output_name/assets
imported_with_output_name = tf.saved_model.load(module_output_path)
imported_with_output_name.signatures['serving_default'].structured_outputs
# 输出
{'custom_output_name': TensorSpec(shape=(), dtype=tf.float32, name='custom_output_name')}
# 模型导出
model.save('catdog.h5')
# 模型载入
model = tf.keras.models.load_model('catdog.h5')
"""
因为 SavedModel 基于计算图,所以对于使用继承 tf.keras.Model 类建立的 Keras 模型,其需要导出到 SavedModel 格式的方法(比如 call )都需要使用 @tf.function 修饰。
"""
class MLPmodel(tf.keras.Model):
def __init__(self):
super().__init__()
# 除第一维以外的维度展平
self.flatten = tf.keras.layers.Flatten()
self.dense1 = tf.keras.layers.Dense(units=100, activation='relu')
self.dense2 = tf.keras.layers.Dense(units=10)
@tf.function # 计算图模式,导出模型,必须写
def call(self, input):
x = self.flatten(input)
x = self.dense1(x)
x = self.dense2(x)
output = tf.nn.softmax(x)
return output
# 导出模型, 模型目录
tf.saved_model.save(mymodel, "./my_model_path")
# 载入模型
mymodel = tf.saved_model.load('./my_model_path')
# 对于使用继承 tf.keras.Model 类建立的 Keras 模型 model ,使用 SavedModel 载入后将无法使用 evaluate,predict 直接进行推理,而需要使用 model.call() 。
res = mymodel.call(data_loader.test_data)
print(res)