【深度学习 走进tensorflow2.0】训练的模型保存方式

无意中发现了一个巨牛的人工智能教程,忍不住分享一下给大家。教程不仅是零基础,通俗易懂,而且非常风趣幽默,像看小说一样!觉得太牛了,所以分享给大家。点这里可以跳转到教程。人工智能教程

在我们训练深度学习网络的时候,如何保存模型,并提供给客户端使用,是常见的问题。保存Tensorflow的模型有很多方法-具体而言是您使用的API。本指南使用tf.keras。

方式1:在训练期间保存模型(以checkpoints形式保存)

您可以使用训练好的模型而无需从头开始重新训练,或在您打断的地方开始训练,以防止训练过程没有保存。tf.keras.callbacks.ModelCheckpoint允许在训练的过程中和结束时保存的模型。

简单:


checkpoint_path = "./model/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# 创建一个保存模型权重的回调
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,save_weights_only=True,verbose=1)

# 使用新的回调训练模型
model.fit(x_train, y_train,  epochs=10,validation_data=(x_test,y_test),callbacks=[cp_callback])

# 加载权重
model.load_weights(checkpoint_path)

# 重新评估模型
loss,acc = model.evaluate(x_test,  y_test, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

复杂:

# 在文件名中包含 epoch (使用 `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# 创建一个回调,每 5 个 epochs 保存模型的权重
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, verbose=1, save_weights_only=True,period=5)


# 使用 `checkpoint_path` 格式保存权重
model.save_weights(checkpoint_path.format(epoch=0))


# 使用新的回调*训练*模型
model.fit(x_train, y_train,epochs=10, callbacks=[cp_callback],validation_data=(x_test,y_test),verbose=1)


# 加载以前保存的权重
model.load_weights(latest)

# 重新评估模型
loss, acc = model.evaluate(test_images,  test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

【深度学习 走进tensorflow2.0】训练的模型保存方式_第1张图片

这些文件是什么?
上述代码将权重存储到检查点 -格式化文件的集合中,这些文件仅包含二进制格式的训练权重。检查点包含:*一个或多个包含模型权重的分片。*索引文件,指示其中权重存储在那个分片中。如果你只在一台机器上训练一个模型,你将有一个带有后缀的碎片: .data-00000-of-00001

方式2:保存整个模型,将模型保存为HDF5文件

# 创建一个新的模型实例
model = create_model()

# 训练模型
model.fit(train_images, train_labels, epochs=5)

# 将整个模型保存为HDF5文件
model.save('my_model.h5')

# 重新创建完全相同的模型,包括其权重和优化程序
new_model = keras.models.load_model('my_model.h5')

# 显示网络结构
new_model.summary()

loss, acc = new_model.evaluate(test_images,  test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

方式3:以checkpoints形式自动保存最佳模型h5格式


model_name = 'model_ex-{epoch:03d}_acc-{val_accuracy:03f}.h5'

trained_model_dir='./model/'
model_path = os.path.join(trained_model_dir, model_name)


checkpoint = tf.keras.callbacks.ModelCheckpoint(
             filepath=model_path,
             monitor='val_accuracy',
            verbose=1,
            save_weights_only=True,
            save_best_only=True,
            mode='max',
            period=1)


model.fit(x_train, y_train,epochs=10, callbacks=[cp_callback],validation_data=(x_test,y_test),verbose=1)

你可能感兴趣的:(【深度学习 走进tensorflow2.0】训练的模型保存方式)