TF2.x的keras模型保存与加载

TF2.x的keras模型保存与加载

  • Keras模型的保存与读取
    • 整个模型的保存与加载
      • APIs
      • 使用`model.save()`或者`tf.keras.models.save_model()`
        • `SavedModel`格式
        • `Keras H5`格式
    • 保存模型结构
      • APIs
      • `Sequential`模型
      • `Functional `模型
    • 保存模型权重
      • APIs
  • 奇怪的现象

传送门:官方文档

Keras模型包含多个组件:

  • 模型的结构或者配置文件,表明模型包含哪些网络层以及各层之间的连接方式。
  • 当前状态模型的参数。
  • 模型的optimizer,在complie里进行定义的。
  • 模型的损失函数和度量函数(在complile函数中定义的或者通过add_loss()add_metric()函数添加)。

通过Keras的API可以将上述的所有组件保存成一个文件或者选择性的保存其中某些组件:

  • Tensorflow SavedModel格式或者Keras H5格式将整个模型保存为一个文件
  • JSON文件形式保存模型的结构或者配置
  • 只保留模型的权重,通常在训练模型的过程中使用。

Keras模型的保存与读取

保存 加载
model.save() tf.keras.models.load_model()
tf.keras.models.save_model() tf.keras.models.load_model()
model.save_weights() model.load_weights()
tf.saved_model.save() tf.saved_model.load()

整个模型的保存与加载

  • 模型的结构和配置
  • 通过训练学习到的权重
  • 模型的编译信息(如果保存前有调用model.compile

APIs

  • model.save()或者tf.keras.models.save_model()
  • tf.keras.models.load_model()

使用model.save()或者tf.keras.models.save_model()

此种方式可以Keras H5格式或者Tensorflow SavedModel格式保存整个模型,在TF2.x版本中默认以SavedModel格式保存,如果想要使用Keras H5格式,可以通过以下形式进行保存:

  • model.save()函数中传递参数saved_format='h5'
  • model.save()函数传递文件名参数时以.h5或者.keras结尾。

SavedModel格式

import tensorflow as tf

model = tf.keras.applications.ResNet50()

# 保存模型, 创建`saved_model`文件夹并以`SavedModel`格式保存模型
model.save("saved_model")
# `saved_model`文件夹包含以下文件(tf2.3没有metadata文件):
# assets  keras_metadata.pb  saved_model.pb  variables

# 加载模型
model = tf.keras.models.load_model('saved_model')

Keras H5格式

import tensorflow as tf

model = tf.keras.applications.ResNet50()

# 以h5的形式保存整个模型,如果只是想部署,可以设置include_optimizer=False,减少模型体积
model.save("my_h5_model.hdf5", include_optimizer=False, save_format="h5")

# 加载模型
model = tf.keras.models.load_model("my_h5_model.h5")

保存模型结构

模型的配置或者结构表明模型包含哪些网络层,以及这些网络层的连接方式,通过模型的配置文件或者结构文件可以生成一个新的具有初始权重的模型,但不包含编译信息,例如损失函数、度量函数或者优化器函数。

APIs

  • get_config()from_config()
  • tf.keras.models.model_to_json()tf.keras.models.model_from_json()

调用config=model.get_config()将会以Python Dict的形式返回一个模型的配置,通过调用Sequential.from_config(config)(Sequential模型)或者Model.from_config(config)(Functional API模型)生成一个新的模型。

Sequential模型

import tensorflow as tf

model = tf.keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])

# 以config形式保存与加载
config = model.get_config()
new_model = tf.keras.Sequential.from_config(config)

# 以json形式保存与加载
json_config = model.to_json()
new_model = keras.models.model_from_json(json_config)

Functional模型

import tensorflow as tf

inputs = tf.keras.Input((32,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)

# 以config形式保存与加载
config = model.get_config()
new_model = tf.keras.Model.from_config(config)

# 以json形式保存与加载
json_config = model.to_json()
new_model = keras.models.model_from_json(json_config)

保存模型权重

适合只使用模型进行推理后者进行迁移学习,可保存为Tensorflow Checkpoint或者Keras H5格式,默认格式为H5

APIs

  • model.save_weights()
import tensorflow as tf

model = tf.keras.applications.ResNet50()
# 保存成h5格式
model.save_weights("weights.h5", save_format="h5")
# 加载h5格式权重
model.load_weights("weights.h5") 

# 保存为tf.train.Checkpoint格式
model.save_weights("model.ckpt",  save_format="tf")
# 加载Checkpoint格式权重
model.load_weights("model.ckpt") 

奇怪的现象

keras模型保存成SavedModel格式,虽可以使用keras.models.load_model或者tf.saved_model.load进行加载,但都会报类似下面的warning,并且会严重拖慢模型加载的速度,但h5格式的模型通过keras.models.load_model不会出现此类问题,很奇怪。

WARNING:tensorflow:Importing a function (__inference_block2a_expand_activation_layer_call_and_return_conditional_losses_xxxxx) with ops with custom gradients. Will likely fail if a gradient is requested.

你可能感兴趣的:(tensorflow,keras,tensorflow,深度学习)