我们在保存模型的时候既可以保存权重+图结构,也可以只保存权重
1. 保存 权重+图结构:
有四种方式,一种是直接调用saved_model下的save API;另一种是添加callbacks参数,后者要注意save_weights_only 参数要设为False;还有一种就是直接使用model.save API;最后一种就是使用tf.keras.models下的save_model API。
tf.saved_model.save,保存的文件格式是.pb,传入的参数是文件所在文件夹的路径
tf.saved_model.save(model,
os.path.join("graph_def_and_weights"))
设置callbacks参数
logdir =os.path.join( 'graph_def_and_weights')
if not os.path.exists(logdir):
os.mkdir(logdir)
output_model_file = os.path.join(logdir,
"fashion_mnist_weights.h5")
callbacks = [
keras.callbacks.ModelCheckpoint(output_model_file,
save_best_only = True,
save_weights_only = False),
]
history = model.fit(x_train_scaled, y_train, epochs=4,
validation_data=(x_valid_scaled, y_valid),
callbacks = callbacks)
model.save
output_model_file = os.path.join(logdir,
"fashion_mnist_weights.h5")
model.save(output_model_file )
tf.keras.models.save_model
#model:模型
#filepath:有两种形式,如果保存的是.pb文件,那么该参数就是文件所在文
#件夹路径,比如上面的logdir;如果是.h5文件,那么该参数就是上面的
#output_model_file
#overwrite:如果有已经存在的model,是否覆盖它
#include_optimizer:是否将优化器optimizer的状态一起保存到模型中
#save_format:是保存成"tf"格式,还是"h5"格式;在2.x中默认是"tf"
#"tf"格式也就是.pb格式的文件
tf.keras.models.save_model(
model, filepath, overwrite=True, include_optimizer=True,
save_format=None,signatures=None, options=None
)
2. 只保存权重weights:
也是有两种方式,一种是直接调用model.save_weights API;另一种就是仍然使用callbacks,但是参数save_weights_only 要设置成True
logdir =os.path.join( 'saved__weights')
output_model_file = os.path.join(logdir,
"fashion_mnist_weights_1.h5")
model.save_weights(output_model_file )
logdir =os.path.join( 'graph_def_and_weights')
if not os.path.exists(logdir):
os.mkdir(logdir)
output_model_file = os.path.join(logdir,
"fashion_mnist_weights_2.h5")
callbacks = [
keras.callbacks.ModelCheckpoint(output_model_file,
save_best_only = True,
save_weights_only = True),
]
history = model.fit(x_train_scaled, y_train, epochs=4,
validation_data=(x_valid_scaled, y_valid),
callbacks = callbacks)
模型的加载也有需要注意的地方。他主要有三种方式。第一种是直接调用saved_model下的load API,但是这个要注意一点,就是它只能加载.pb文件,并且提供的路径参数是.pb文件所在的文件夹路径;第二种是调用model.load_weights API,它只适合加载只保存了权重的文件;第三种就是调用tf.keras.models.load_model
tf.saved_model.load
loaded_saved_model =
tf.saved_model.load('./graph_def_and_weights')
model.load_weights
model.load_weights(output_model_file)
tf.keras.models.load_model,这种方式既可以加载.pb文件,也可以加载.h5文件,如下所示:
tf.keras.models.load_model(
os.path.join('./graph_def_and_weights'),
custom_objects=None, compile=True)
tf.keras.models.load_model(
os.path.join(output_model_file),
custom_objects=None, compile=True)