ModelCheckpoint技术

在神经网络的训练学习过程中,常常需要把训练好的模型保存下来,ModelCheckpoint技术就是一种很实用的模型保存与改进方法。

在keras中通过回调API实现Checkpoint功能,本质上是callbacks的一个类。使用前需要从keras库中调用:

from kearas.callbacks import ModelCheckpoint

ModelCheckpoint的一般格式是:

checkpoint = ModelCheckpoint(filename, monitor='loss', verbose=1, save_best_only=True, mode='min', save_weights_only=False, period=5)

filename是保存的文件名(含路径)。
monitor是需要监测的值。
verbose是信息展示模式。
save_best_only, True or False. True, 保存训练集上性能最好的模型。
mode是模型评判准则。
save_weights_only, True or False. True, 只保存模型权重,否则保存整个模型。
period是checkpoint之间间隔的epoch数。

ModelCheckpoint主要有两个功能:
1、利用checkpoint改进模型。

filename = 'improvement-{epoch:02d}-{loss:.2f}.hdf5'
checkpoint = ModelCheckpoint(filename, monitor='loss', verbose=1, save_best_only=True, mode='min', period=5)
model.fit(x_train, y_train, epochs=500, batch_size=100, callbacks=[checkpoint], verbose=1, shuffle=False)

程序执行后,会保存一系列文件,我们可以通过这些文件了解模型的训练过程。
2、利用checkpoint获得最佳模型。

filename = 'improvement-best.hdf5'
checkpoint = ModelCheckpoint(filename, monitor='loss', verbose=1, save_best_only=True, mode='min', period=5)
model.fit(x_train, y_train, epochs=500, batch_size=100, callbacks=[checkpoint], verbose=1, shuffle=False)

程序执行后,文件名相同的文件依次覆盖,最后只有一个文件就是训练效果最好的模型。

那么如何加载一个已经训练好的模型呢?也有两种方法:
1、整模型加载。
如果前面不是只保存了权重的话,在这里是可以加载整个模型的,加载的方法也很简单:

from keras.models import load_model
load_model(filename/path)

2、仅加载权重。
我们也可以只加载权重。

model.load_weights(filename/path)

这样通过ModelCheckpoint技术,我们就可以实现模型的保存与改进。

你可能感兴趣的:(ModelCheckpoint技术)