keras ModelCheckpoint 实现断点续训功能

参考链接:

墙裂推荐:https://cloud.tencent.com/developer/article/1049579

英文版原文:https://machinelearningmastery.com/check-point-deep-learning-models-keras/

keras文档回调函数:http://keras-cn.readthedocs.io/en/latest/other/callbacks/#modelcheckpoint

先看一下ModelCheckpoint的参数:

keras.callbacks.ModelCheckpoint(
    filepath,
    monitor='val_loss',
    verbose=0,
    save_best_only=False,
    save_weights_only=False,
    mode='auto',
    period=1
)

1. filename:字符串,保存模型的路径
2. monitor:需要监视的值,val_acc或这val_loss
3. verbose:信息展示模式,0为不打印输出信息,1打印
4. save_best_only:当设置为True时,将只保存在验证集上性能最好的模型
5. mode‘auto’minmax之一,在save_best_only=True时决定性能最佳模型的评判准则,例如,当监测值为val_acc时,模式应为max,当检测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。
6. save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等)
7. periodCheckPoint之间的间隔的epoch

代码实现过程:

① 从keras.callbacks导入ModelCheckpoint类

from keras.callbacks import ModelCheckpoint

②  在训练阶段的model.compile之后加入下列代码实现每一次epoch(period=1)保存最好的参数

checkpoint = ModelCheckpoint(filepath,
    monitor='val_loss', save_weights_only=True,verbose=1,save_best_only=True, period=1)
提醒:filepath为保存参数的路径,我这里是"logs/000/trained_best_weights.h5"

③ 在训练阶段的model.fit之前加载先前保存的参数

if os.path.exists(filepath):
    model.load_weights(filepath)
    # 若成功加载前面保存的参数,输出下列信息
    print("checkpoint_loaded")

④ 在model.fit添加callbacks=[checkpoint]实现回调

model.fit_generator(data_generator_wrap(lines[:num_train], batch_size, input_shape, anchors, num_classes),
        steps_per_epoch=max(1, num_train//batch_size),
        validation_data=data_generator_wrap(lines[num_train:], batch_size, input_shape, anchors, num_classes),
        validation_steps=max(1, num_val//batch_size),
        epochs=3,
        initial_epoch=0,
        callbacks=[checkpoint])

测试输出:

① 第一次输出,没有参数可以加载,不会打印“checkpoint_loaded”,输出如下(测试epoch=3)

keras ModelCheckpoint 实现断点续训功能_第1张图片

keras ModelCheckpoint 实现断点续训功能_第2张图片

② 再次执行train.py,利用刚才的代码可以直接在model.fit之前加载保存前一次训练的参数,继续训练(loss的变化)。注意输出了“checkpoint_load”表示成功加载前面保存的参数

keras ModelCheckpoint 实现断点续训功能_第3张图片

 
  

keras ModelCheckpoint 实现断点续训功能_第4张图片

提示:参考链接中有简单的测试代码,以上仅在我的训练数据上的所做的测试,更多详细内容请阅读参考链接

The end.

你可能感兴趣的:(Easy,Keras)