TensorFlow学习——tf.keras接口保存和加载模型

跑了几个深度学习的程序之后,开始静下心来研究算法和代码,因为刚开始接触TensorFlow,第一个拦路虎就是各种API的使用,不想囫囵吞枣,只能在工作之外去学习这些基础知识,各个击破。

今天主要是tf.keras接口的简介,以及最近遇到比较多的模型的保存和加载。

1. tf.keras简介

tf.keras接口是在TensorFlow中封装的Keras接口。Keras接口是一个用Python语言编写的高层神经网路API,可以运行在TensorFlow、CNTK以及Theano上,也就是说它可以以这三个框架作为后台。Keras有如下优点:

(1) 对用户友好,模块化,可扩展;

(2) 可支持卷积神经网络和循环神经网络;

(3) 能够无缝支持CPU和GPU。

TensorFlow中的Keras接口对Keras的实现很全面,只要安装TensorFlow就可以使用Keras的所有接口了。

tf.keras接口包含了很多成熟模型的源代码,包括DenseNet、NASNet、MobileNet等,我们可以很方便地使用这些源代码对自己地数据进行训练;也可以加载训练过的模型文件,使用文件里的参数值对模型源代码中的权重进行赋值,赋值后的模型可以直接用来预测。

GitHub上的Keras页面有很多模型的源码、结构、参数等,需要做模型评估的同学可以通过如下链接查看,比如我们前一段时间需要了解模型的参数量、乘加数等,这个东西就特别有参考意义。

GitHub - keras-team/keras-applications: Reference implementations of popular deep learning models.

2. 模型的保存和加载

模型的保存可以使用如下API:

model.save() 
#或 
tf.keras.models.save_model()

模型的加载:

tf.keras.models.load_model()

模型保存格式可以是TensorFlow的SavedModel格式,或者Keras H5格式。SavedModel格式是model.save()的默认格式。若想保存为H5格式,有两种方法可以实现:

(1) 在调用save()函数时,将save_format='h5'传递进去;

(2) 把以.h5或.keras结尾的文件名传递给save()函数。

2.1 SavedModel模型

SavedModel示例如下,该示例来源于官方教程。

def get_model():
    # Create a simple model.
    inputs = keras.Input(shape=(32,))
    outputs = keras.layers.Dense(1)(inputs)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer="adam", loss="mean_squared_error")
    return model


model = get_model()

# Train the model.
test_input = np.random.random((128, 32))
test_target = np.random.random((128, 1))
model.fit(test_input, test_target)

# Calling `save('my_model')` creates a SavedModel folder `my_model`.
model.save("my_model")

# It can be used to reconstruct the model identically.
reconstructed_model = keras.models.load_model("my_model")

# Let's check:
np.testing.assert_allclose(
    model.predict(test_input), reconstructed_model.predict(test_input)
)

# The reconstructed model is already compiled and has retained the optimizer
# state, so training can resume:
reconstructed_model.fit(test_input, test_target)

 调用 model.save('my_model') 会创建一个名为 my_model 的文件夹,其包含以下内容:

TensorFlow学习——tf.keras接口保存和加载模型_第1张图片

模型架构和训练配置(包括优化器、损失和指标)存储在 saved_model.pb 中,权重保存在 variables/ 目录下。

 2.2 H5模型

Keras 还支持保存单个 HDF5 文件,其中包含模型的架构、权重值和 compile() 信息。它是 SavedModel 的轻量化替代选择。

官方教程示例如下:

model = get_model()

# Train the model.
test_input = np.random.random((128, 32))
test_target = np.random.random((128, 1))
model.fit(test_input, test_target)

# Calling `save('my_model.h5')` creates a h5 file `my_model.h5`.
model.save("my_h5_model.h5")

# It can be used to reconstruct the model identically.
reconstructed_model = keras.models.load_model("my_h5_model.h5")

# Let's check:
np.testing.assert_allclose(
    model.predict(test_input), reconstructed_model.predict(test_input)
)

# The reconstructed model is already compiled and has retained the optimizer
# state, so training can resume:
reconstructed_model.fit(test_input, test_target)

 保存出来的文件就是一个单独的my_h5_model.h5文件。

时间关系,今天先写到这里,后面会再根据学习情况更新更多内容。

你可能感兴趣的:(深度学习,TensorFlow,tensorflow,keras,深度学习)