keras 回调函数Callbacks 断点ModelCheckpoint

整理自keras:https://keras-cn.readthedocs.io/en/latest/other/callbacks/

参考:https://blog.csdn.net/li_haiyu/article/details/80894299

回调函数Callbacks

回调函数是一个函数的合集,会在训练的阶段中所使用。你可以使用回调函数来查看训练模型的内在状态和统计。你可以传递一个列表的回调函数(作为 callbacks 关键字参数)到 Sequential 或 Model 类型的 .fit() 方法。在训练时,相应的回调函数的方法就会被在各自的阶段被调用。

Callback

keras.callbacks.Callback()

这是回调函数的抽象类,定义新的回调函数必须继承自该类

类属性

  • params:字典,训练参数集(如信息显示方法verbosity,batch大小,epoch数) 

  • model:keras.models.Model对象,为正在训练的模型的引用

回调函数以字典logs为参数,该字典包含了一系列与当前batch或epoch相关的信息。

目前,模型的.fit()中有下列参数会被记录到logs中:

  • 在每个epoch的结尾处(on_epoch_end),logs将包含训练的正确率和误差,accloss,如果指定了验证集,还会包含验证集正确率和误差val_acc)val_lossval_acc还额外需要在.compile中启用metrics=['accuracy']

  • 在每个batch的开始处(on_batch_begin):logs包含size,即当前batch的样本数

  • 在每个batch的结尾处(on_batch_end):logs包含loss,若启用accuracy则还包含acc

ModelCheckpoint

keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1)

该回调函数将在每个epoch后保存模型到filepath 

filepath 可以包括命名格式选项,可以由 epoch 的值和 logs 的键(由 on_epoch_end 参数传递)来填充。

参数:

  • filepath: 字符串,保存模型的路径。
  • monitor: 被监测的数据。val_acc或这val_loss
  • verbose: 详细信息模式,0 或者 1 。0为不打印输出信息,1打印
  • save_best_only: 如果 save_best_only=True, 将只保存在验证集上性能最好的模型
  • mode: {auto, min, max} 的其中之一。 如果 save_best_only=True,那么是否覆盖保存文件的决定就取决于被监测数据的最大或者最小值。 对于 val_acc,模式就会是 max,而对于 val_loss,模式就需要是 min,等等。 在 auto 模式中,方向会自动从被监测的数据的名字中判断出来。
  • save_weights_only: 如果 True,那么只有模型的权重会被保存 (model.save_weights(filepath)), 否则的话,整个模型会被保存 (model.save(filepath))。
  • period: 每个检查点之间的间隔(训练轮数)。

代码实现过程:

① 从keras.callbacks导入ModelCheckpoint类

from keras.callbacks import ModelCheckpoint

②  在训练阶段的model.compile之后加入下列代码实现每一次epoch(period=1)保存最好的参数

checkpoint = ModelCheckpoint(filepath,
    monitor='val_loss', save_weights_only=True,verbose=1,save_best_only=True, period=1)

③ 在训练阶段的model.fit之前加载先前保存的参数

if os.path.exists(filepath):
    model.load_weights(filepath)
    # 若成功加载前面保存的参数,输出下列信息
    print("checkpoint_loaded")

④ 在model.fit添加callbacks=[checkpoint]实现回调

model.fit_generator(data_generator_wrap(lines[:num_train], batch_size, input_shape, anchors, num_classes),
        steps_per_epoch=max(1, num_train//batch_size),
        validation_data=data_generator_wrap(lines[num_train:], batch_size, input_shape, anchors, num_classes),
        validation_steps=max(1, num_val//batch_size),
        epochs=3,
        initial_epoch=0,
        callbacks=[checkpoint])

你可能感兴趣的:(keras)