TensorFlow——模型保存

模型保存

五种模型保存方法

  1. 模型整体的保存
  2. 模型框架的保存
  3. 模型权重的保存
  4. 使用回调函数对模型进行保存
  5. 对自定义训练模型的保存

一、模型整体的保存

整个模型可以保存到一个文件中,其中包含权重值、模型配置乃至优化器配置。这样,可以为模型设置检查点,并稍后从完全相同的状态继续训练,而无需访问原始代码。

在Keras中保存完全可以正常使用的模型非常有用,您可以在TensorFlow.js中加载它们,然后在网络浏览器中训练和运行它们。

Keras使用HDF5标准提供基本的保存格式。
保存模型

# 保存模型,参数为保存路径
model.save('less_model.h5')

使用保存好的模型

# 使用保存好的模型
new_model = tf.keras.models.load_model('less_model.h5')

此方法保存以下内容:

  1. 权重值
  2. 模型配置(架构)
  3. 优化器配置

二、仅保存架构

仅保存模型的架构,不保存权重或优化器。

# 保存模型架构
json_config = model.to_json()

TensorFlow——模型保存_第1张图片
重建模型架构

# 重建模型架构
reinitialized_model = tf.keras.models.model_from_json(json_config)

三、仅保存权重

获取模型权重

# 获取模型权重
weights = model.get_weights()

为模型设置权重,使模型加载训练好的权重

# 为模型设置权重
reinitialized_model.set_weights(weights)

将模型权重保存到磁盘

# 将模型权重保存到磁盘
model.save_weights('less_weights.h5')

从磁盘中加载保存好的权重

# 从磁盘中加载保存好的权重
reinitialized_model.load_weights('less_weights.h5')

四、在训练期间保存检查点

回调函数:tf.keras.callbacks.ModelCheckpoint
保存检查点

# 配置保存检查点的回调函数
cp_callback = tf.keras.callbacks.ModelCheckpoint('training_cp/cp.ckpt',save_weights_only=True)

# 训练模型,指定保存检查点的回调函数
model.fit(train_image,train_label,epochs=3,callbacks=[cp_callback])

使用保存好的检查点

# 加载保存好的检查点
model.load_weights('training_cp/cp.ckpt')

五、自定义训练中保存检查点

保存检查点

# 指定检查点保存位置
cp_dir = './customtrain_cp'
cp_prefix = os.path.join(cp_dir,'ckpt')

# 定义检查点,保存优化器和模型架构
checkpoint = tf.train.Checkpoint(optimizer=optimizer,model=model)
def train():
    for epoch in range(5):
        for (batch,(images,labels)) in enumerate(dataset):
            train_step(model,images,labels)
        print('Epoch{} loss is {}'.format(epoch,train_loss.result()))
        print('Epoch{} Accuracy is {}'.format(epoch,train_accuracy.result()))
        train_loss.reset_states()
        train_accuracy.reset_states()
        if(epoch + 1)%2 == 0:
            # 每两个epoch保存检查点
            checkpoint.save(file_prefix = cp_prefix)

恢复检查点

# 取出检查点目录下最新的检查点
tf.train.latest_checkpoint(cp_dir)

# 从指定检查点恢复检查点
checkpoint.restore(tf.train.latest_checkpoint(cp_dir))

你可能感兴趣的:(#,TensorFlow,深度学习,tensorflow,python)