如何使用Tensorflow保存或者加载模型(三) -- Keras

1.背景

Keras是一个非常易于上手,好用的深度学习框架。不仅容易构建模型,而且容易保存模型。目前,Keras已经被纳入到Tensorflow框架中了,因此站长打算在介绍Tensorflow模型保存的时候,可以一并把Keras的模型训练和保存也介绍了。

1.1 模型构建

下面可以以一个简单的demo模型来作为说明,构建模型结构。

# -*- coding: utf-8 -*-
# @Time    : 2019-08-03 17:46
# @Author  : AlexCen
# @Blog    :http://www.alexcen.com/
from __future__ import absolute_import, division, print_function, unicode_literals
from tensorflow import keras
from tensorflow.keras import layers


##1.构建模型##
inputs = keras.Input(shape=(784, ), name='digits')
x = layers.Dense(64, activation='relu', name='d1')(inputs)
x = layers.Dense(64, activation='relu', name='d2')(x)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)

model = keras.Model(inputs=inputs, outputs=outputs, name='mlp')
model.summary()

1.2 数据加载和模型训练

加载mnist的默认数据,然后进行训练

# -*- coding: utf-8 -*-
# @Time    : 2019-08-03 17:46
# @Author  : AlexCen
# @Blog    :http://www.alexcen.com/

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255
model.compile(loss='sparse_categorical_crossentropy',optimizer=keras.optimizers.RMSprop())
history = model.fit(x_train, y_train, batch_size=64, epochs=1)

1.3 模型保存和输出

将模型输出并保存。一行代码就完成了。

# -*- coding: utf-8 -*-
# @Time    : 2019-08-03 17:46
# @Author  : AlexCen
# @Blog    :http://www.alexcen.com/

#模型预测
prediction = model.predict(x_test)

##模型保存成h5文件
model.save('tensorflow_model_save_restore/result/keras_model/test_model.h5')

1.4 模型加载和预测

模型的加载也是非常简单,一行代码即可。

##加载模型
model_new = keras.models.load_model('tensorflow_model_save_restore/result/keras_model/test_model.h5')
prediction_new = model_new.predict(x_test)

1.5 结果示例

array([[2.0833343e-06, 5.7367635e-07, 3.0239375e-04, 8.2307472e-04,
        3.1286825e-07, 1.2593745e-05, 4.4974424e-09, 9.9865258e-01,
        1.1882040e-05, 1.9443125e-04],
       [9.6831842e-05, 1.4330833e-05, 9.9635977e-01, 2.7176763e-03,
        7.9741378e-09, 5.3533819e-04, 2.2054941e-04, 1.6272834e-07,
        5.5394452e-05, 7.3264599e-09],
       [7.0375630e-05, 9.7449875e-01, 8.4965276e-03, 2.2081137e-03,
        4.2712223e-04, 1.0506579e-03, 1.0271738e-03, 5.4973103e-03,
        4.8175761e-03, 1.9064705e-03],
       [9.9902606e-01, 7.6533127e-09, 4.4397468e-05, 7.0306480e-05,
        1.6069802e-06, 7.1709284e-05, 2.7319507e-05, 1.3904357e-05,
        3.0369791e-05, 7.1443513e-04],
       [5.1576592e-04, 8.4059438e-06, 1.3240698e-03, 1.4757921e-04,
        8.7131393e-01, 2.3408027e-04, 5.0422893e-04, 3.5719746e-03,
        6.8508921e-04, 1.2169483e-01]], dtype=float32)

2.总结

本文介绍如何通过keras框架来进行构建模型框架和模型保存,可见,keras是非常容易上手的工具,对于基本调用的使用者来说提供了友好的交互形式,不过,如果需要对模型结果进行特殊的调整的话,可能就不太方便了,建议有这类需求的同学还是去Tensorflow来构建模型。

你可能感兴趣的:(如何使用Tensorflow保存或者加载模型(三) -- Keras)