墙裂推荐:https://cloud.tencent.com/developer/article/1049579
英文版原文:https://machinelearningmastery.com/check-point-deep-learning-models-keras/
keras文档回调函数:http://keras-cn.readthedocs.io/en/latest/other/callbacks/#modelcheckpoint
先看一下ModelCheckpoint的参数:
keras.callbacks.ModelCheckpoint( filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1 ) 1. filename:字符串,保存模型的路径 2. monitor:需要监视的值,val_acc或这val_loss 3. verbose:信息展示模式,0为不打印输出信息,1打印 4. save_best_only:当设置为True时,将只保存在验证集上性能最好的模型 5. mode:‘auto’,‘min’,‘max’之一,在save_best_only=True时决定性能最佳模型的评判准则,例如,当监测值为val_acc时,模式应为max,当检测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。 6. save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等) 7. period:CheckPoint之间的间隔的epoch数
① 从keras.callbacks导入ModelCheckpoint类
from keras.callbacks import ModelCheckpoint
② 在训练阶段的model.compile之后加入下列代码实现每一次epoch(period=1)保存最好的参数
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', save_weights_only=True,verbose=1,save_best_only=True, period=1)提醒:filepath为保存参数的路径,我这里是"logs/000/trained_best_weights.h5"
③ 在训练阶段的model.fit之前加载先前保存的参数
if os.path.exists(filepath): model.load_weights(filepath) # 若成功加载前面保存的参数,输出下列信息 print("checkpoint_loaded")
④ 在model.fit添加callbacks=[checkpoint]实现回调
model.fit_generator(data_generator_wrap(lines[:num_train], batch_size, input_shape, anchors, num_classes), steps_per_epoch=max(1, num_train//batch_size), validation_data=data_generator_wrap(lines[num_train:], batch_size, input_shape, anchors, num_classes), validation_steps=max(1, num_val//batch_size), epochs=3, initial_epoch=0, callbacks=[checkpoint])
① 第一次输出,没有参数可以加载,不会打印“checkpoint_loaded”,输出如下(测试epoch=3)
② 再次执行train.py,利用刚才的代码可以直接在model.fit之前加载保存前一次训练的参数,继续训练(loss的变化)。注意输出了“checkpoint_load”表示成功加载前面保存的参数
提示:参考链接中有简单的测试代码,以上仅在我的训练数据上的所做的测试,更多详细内容请阅读参考链接
The end.