整个模型可以保存到一个文件中,其中包含权重值、模型配置乃至优化器配置。这样,可以为模型设置检查点,并稍后从完全相同的状态继续训练,而无需访问原始代码。
在Keras中保存完全可以正常使用的模型非常有用,您可以在TensorFlow.js中加载它们,然后在网络浏览器中训练和运行它们。
Keras使用HDF5标准提供基本的保存格式。
保存模型
# 保存模型,参数为保存路径
model.save('less_model.h5')
使用保存好的模型
# 使用保存好的模型
new_model = tf.keras.models.load_model('less_model.h5')
此方法保存以下内容:
仅保存模型的架构,不保存权重或优化器。
# 保存模型架构
json_config = model.to_json()
# 重建模型架构
reinitialized_model = tf.keras.models.model_from_json(json_config)
获取模型权重
# 获取模型权重
weights = model.get_weights()
为模型设置权重,使模型加载训练好的权重
# 为模型设置权重
reinitialized_model.set_weights(weights)
将模型权重保存到磁盘
# 将模型权重保存到磁盘
model.save_weights('less_weights.h5')
从磁盘中加载保存好的权重
# 从磁盘中加载保存好的权重
reinitialized_model.load_weights('less_weights.h5')
回调函数:tf.keras.callbacks.ModelCheckpoint
保存检查点
# 配置保存检查点的回调函数
cp_callback = tf.keras.callbacks.ModelCheckpoint('training_cp/cp.ckpt',save_weights_only=True)
# 训练模型,指定保存检查点的回调函数
model.fit(train_image,train_label,epochs=3,callbacks=[cp_callback])
使用保存好的检查点
# 加载保存好的检查点
model.load_weights('training_cp/cp.ckpt')
保存检查点
# 指定检查点保存位置
cp_dir = './customtrain_cp'
cp_prefix = os.path.join(cp_dir,'ckpt')
# 定义检查点,保存优化器和模型架构
checkpoint = tf.train.Checkpoint(optimizer=optimizer,model=model)
def train():
for epoch in range(5):
for (batch,(images,labels)) in enumerate(dataset):
train_step(model,images,labels)
print('Epoch{} loss is {}'.format(epoch,train_loss.result()))
print('Epoch{} Accuracy is {}'.format(epoch,train_accuracy.result()))
train_loss.reset_states()
train_accuracy.reset_states()
if(epoch + 1)%2 == 0:
# 每两个epoch保存检查点
checkpoint.save(file_prefix = cp_prefix)
恢复检查点
# 取出检查点目录下最新的检查点
tf.train.latest_checkpoint(cp_dir)
# 从指定检查点恢复检查点
checkpoint.restore(tf.train.latest_checkpoint(cp_dir))