tensorflow2.x学习笔记十一:tensorflow2.x如何保存和加载模型

一、模型的保存

我们在保存模型的时候既可以保存权重+图结构,也可以只保存权重

1. 保存 权重+图结构

有四种方式,一种是直接调用saved_model下的save API;另一种是添加callbacks参数,后者要注意save_weights_only 参数要设为False;还有一种就是直接使用model.save API;最后一种就是使用tf.keras.models下的save_model API。

tf.saved_model.save,保存的文件格式是.pb,传入的参数是文件所在文件夹的路径

tf.saved_model.save(model, 
                    os.path.join("graph_def_and_weights"))

设置callbacks参数

logdir =os.path.join( 'graph_def_and_weights')
if not os.path.exists(logdir):
       os.mkdir(logdir)
output_model_file = os.path.join(logdir,
                                "fashion_mnist_weights.h5")

callbacks = [
    keras.callbacks.ModelCheckpoint(output_model_file,
                                    save_best_only = True,
                                    save_weights_only = False),
]
history = model.fit(x_train_scaled, y_train, epochs=4,
                    validation_data=(x_valid_scaled, y_valid),
                    callbacks = callbacks)

model.save

output_model_file = os.path.join(logdir,
                                "fashion_mnist_weights.h5")
model.save(output_model_file )

tf.keras.models.save_model

#model:模型

#filepath:有两种形式,如果保存的是.pb文件,那么该参数就是文件所在文
#件夹路径,比如上面的logdir;如果是.h5文件,那么该参数就是上面的 
#output_model_file 

#overwrite:如果有已经存在的model,是否覆盖它

#include_optimizer:是否将优化器optimizer的状态一起保存到模型中

#save_format:是保存成"tf"格式,还是"h5"格式;在2.x中默认是"tf"
#"tf"格式也就是.pb格式的文件
tf.keras.models.save_model(
    model, filepath, overwrite=True, include_optimizer=True, 
    save_format=None,signatures=None, options=None
)

2. 只保存权重weights

也是有两种方式,一种是直接调用model.save_weights API;另一种就是仍然使用callbacks,但是参数save_weights_only 要设置成True

logdir =os.path.join( 'saved__weights')
output_model_file = os.path.join(logdir,
                                "fashion_mnist_weights_1.h5")
model.save_weights(output_model_file )
logdir =os.path.join( 'graph_def_and_weights')
if not os.path.exists(logdir):
       os.mkdir(logdir)
output_model_file = os.path.join(logdir,
                                "fashion_mnist_weights_2.h5")

callbacks = [
    keras.callbacks.ModelCheckpoint(output_model_file,
                                    save_best_only = True,
                                    save_weights_only = True),
]
history = model.fit(x_train_scaled, y_train, epochs=4,
                    validation_data=(x_valid_scaled, y_valid),
                    callbacks = callbacks)

二、模型的加载

模型的加载也有需要注意的地方。他主要有三种方式。第一种是直接调用saved_model下的load API,但是这个要注意一点,就是它只能加载.pb文件,并且提供的路径参数是.pb文件所在的文件夹路径;第二种是调用model.load_weights API,它只适合加载只保存了权重的文件;第三种就是调用tf.keras.models.load_model

tf.saved_model.load

loaded_saved_model = 
            tf.saved_model.load('./graph_def_and_weights')

model.load_weights


model.load_weights(output_model_file)

tf.keras.models.load_model,这种方式既可以加载.pb文件,也可以加载.h5文件,如下所示:

tf.keras.models.load_model(
           os.path.join('./graph_def_and_weights'), 
           custom_objects=None, compile=True)

tf.keras.models.load_model(
           os.path.join(output_model_file), 
           custom_objects=None, compile=True)

你可能感兴趣的:(tensorflow,深度学习)