Tensorflow2模型的保存与加载

模型的保存与加载

  • 保存和加载整个模型
      • 模型的保存:
      • 模型的加载
  • 加载模型的参数

通过本篇blog,你将会学到

  • 将所有内容以 TensorFlow SavedModel 格式(或较早的 Keras H5 格式)保存到单个归档。这是标准做法。
  • 仅保存架构/配置,通常保存为 JSON 文件。
  • 仅保存权重值。通常在训练模型时使用。
    参考链接Tensorflow官方

保存和加载整个模型

模型的保存:

API:

model.save() 或 tf.keras.models.save_model()

参考此API,你将保存完整的模型架构、训练权重、优化器及其状态等各种信息

同时模型的保存有两种格式:

  1. TensorFlow SavedModel 格式(推荐使用/默认格式)
  2. Keras H5 格式(较早格式)

您可以通过以下方式切换到 H5 格式。
3. 将 save_format=‘h5’ 传递给 save()。
4. 将以 .h5 或 .keras 结尾的文件名传递给 save()。

模型的加载

API:

tf.keras.models.load_model()

举个例子:
如下所示,加载器动态地创建了一个与原始模型行为类似的新模型。
modelloaded为两个Model,一个是先前的自己写的模型。一个是加载保存后的模型。

class CustomModel(keras.Model):
    def __init__(self, hidden_units):
        super(CustomModel, self).__init__()
        self.dense_layers = [keras.layers.Dense(u) for u in hidden_units]

    def call(self, inputs):
        x = inputs
        for layer in self.dense_layers:
            x = layer(x)
        return x


model = CustomModel([16, 16, 10])
# Build the model by calling it
input_arr = tf.random.uniform((1, 5))
outputs = model(input_arr)
model.save("my_model")

# Delete the custom-defined model class to ensure that the loader does not have
# access to it.
del CustomModel

loaded = keras.models.load_model("my_model")
np.testing.assert_allclose(loaded(input_arr), outputs)

print("Original model:", model)
print("Loaded model:", loaded)

加载模型的参数

我们日常训练时,可能自己训练了老久的参数不想放弃,下次训练时接着上次的参数训练下去。
因此下面将讲述如何加载模型的参数,来实现接着上次训练结束后的参数训练。(可以节约很多时间)

  1. 模型的保存–参考上面模型的保存
  2. 权重值的加载(未看)
    目前本人采取的方法: 采用if判断,如果有保存的模型,则加载之前保存的模型(即不再通过代码的模型),此时训练的参数一并会被加载。
    如果没有保存的模型,则重新开始训练。
    未完待续……

你可能感兴趣的:(#,Tensorflow2.0,学习专栏,机器学习,深度学习,tensorflow,机器学习)