Keras回调函数Callbacks使用详解及训练过程可视化

Keras回调函数Callbacks使用详解及训练可视化

  • 介绍
  • 功能
    • History(训练可视化)
    • EarlyStopping
    • ModelCheckpoint
    • ReduceLROnPlateau
    • CSVLogger

介绍

内容参考了keras中文文档
回调函数Callbacks
回调函数是一组在训练的特定阶段被调用的函数集,你可以使用回调函数来观察训练过程中网络内部的状态和统计信息。通过传递回调函数列表到模型的.fit()中,即可在给定的训练阶段调用该函数集中的函数。

【Tips】虽然我们称之为回调“函数”,但事实上Keras的回调函数是一个类,回调函数只是习惯性称呼

keras.callbacks.Callback()

这是回调函数的抽象类,定义新的回调函数必须继承自该类

类属性:

  • params:字典,训练参数集(如信息显示方法verbosity,batch大小,epoch数)

  • model:keras.models.Model对象,为正在训练的模型的引用

回调函数以字典logs为参数,该字典包含了一系列与当前batch或epoch相关的信息。

目前,模型的.fit()中有下列参数会被记录到logs中:

  • 在每个epoch的结尾处(on_epoch_end),logs将包含训练的正确率和误差,acc和loss,如果指定了验证集,还会包含验证集正确率和误差val_acc)和val_loss,val_acc还额外需要在.compile中启用metrics=[‘accuracy’]。

  • 在每个batch的开始处(on_batch_begin):logs包含size,即当前batch的样本数

  • 在每个batch的结尾处(on_batch_end):logs包含loss,若启用accuracy则还包含acc

from keras.callbacks import Callback

功能

History(训练可视化)

keras.callbacks.History()

该回调函数在Keras模型上会被自动调用,History对象即为fit方法的返回值,可以使用history中的存储的acc和loss数据对训练过程进行可视化画图,代码样例如下:

history=model.fit(X_train, Y_train, validation_data=(X_test,Y_test),
		batch_size=16, epochs=20)
##或者
#history=model.fit(X_train,y_train,epochs=40,callbacks=callbacks, batch_size=32,validation_data=(X_test,y_test))		
fig1, ax_acc = plt.subplots()
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Model - Accuracy')
plt.legend(['Training', 'Validation'], loc='lower right')
plt.show()

fig2, ax_loss = plt.subplots()
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Model- Loss')
plt.legend(['Training', 'Validation'], loc='upper right')
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.show()

EarlyStopping

keras.callbacks.EarlyStopping(monitor='val_loss', patience=0, verbose=0, mode='auto')

当监测值不再改善时,该回调函数将中止训练
参数

  • monitor:需要监视的量

  • patience:当early stop被激活(如发现loss相比上一个epoch训练没有下降),则经过patience个epoch后停止训练。

  • verbose:信息展示模式
    verbose = 0 为不在标准输出流输出日志信息
    verbose = 1 为输出进度条记录
    verbose = 2 为每个epoch输出一行记录
    默认为 1

  • mode:‘auto’,‘min’,‘max’之一,在min模式下,如果检测值停止下降则中止训练。在max模式下,当检测值不再上升则停止训练。

ModelCheckpoint

该回调函数将在每个epoch后保存模型到filepath

filepath可以是格式化的字符串,里面的占位符将会被epoch值和传入on_epoch_end的logs关键字所填入

例如,filepath若为weights.{epoch:02d-{val_loss:.2f}}.hdf5,则会生成对应epoch和验证集loss的多个文件。

参数

  • filename:字符串,保存模型的路径

  • monitor:需要监视的值

  • verbose:信息展示模式,0或1

  • save_best_only:当设置为True时,将只保存在验证集上性能最好的模型

  • mode:‘auto’,‘min’,‘max’之一,在save_best_only=True时决定性能最佳模型的评判准则,例如,当监测值为val_acc时,模式应为max,当检测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。

  • save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等)

  • period:CheckPoint之间的间隔的epoch数

Callbacks中可以同时使用多个以上两个功能,举例如下

callbacks = [EarlyStopping(monitor='val_loss', patience=8),
             ModelCheckpoint(filepath='best_model.h5', monitor='val_loss', save_best_only=True)]
history=model.fit(X_train, y_train,epochs=40,callbacks=callbacks, batch_size=32,validation_data=(X_test,y_test))

在样例中,EarlyStopping设置衡量标注为val_loss,如果其连续4次没有下降就提前停止 ,ModelCheckpoint设置衡量标准为val_loss,设置只保存最佳模型,保存路径为best——model.h5

ReduceLROnPlateau

keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, verbose=0, mode='auto', epsilon=0.0001, cooldown=0, min_lr=0)

当评价指标不在提升时,减少学习率

当学习停滞时,减少2倍或10倍的学习率常常能获得较好的效果。该回调函数检测指标的情况,如果在patience个epoch中看不到模型性能提升,则减少学习率
参数

  • monitor:被监测的量 factor:每次减少学习率的因子,学习率将以lr = lr*factor的形式被减少
  • patience:当patience个epoch过去而模型性能不提升时,学习率减少的动作会被触发
  • mode:‘auto’,‘min’,‘max’之一,在min模式下,如果检测值触发学习率减少。在max模式下,当检测值不再上升则触发学习率减少。
  • epsilon:阈值,用来确定是否进入检测值的“平原区”
  • cooldown:学习率减少后,会经过cooldown个epoch才重新进行正常操作 min_lr:学习率的下限
    使用样例如下:
callbacks_test = [
  	keras.callbacks.ReduceLROnPlateau(
 	 #以val_loss作为衡量标准
 	 monitor='val_loss',
  	 # 学习率乘以factor
  	 factor=0.1,
 	 # It will get triggered after the validation loss has stopped improving
	 # 当被检测的衡量标准经过几次没有改善后就减小学习率
 	 patience=10,
    )
	]
	model.fit(x, y,epochs=20,batch_size=16,
  			callbacks=callbacks_test,
 			validation_data=(x_val, y_val))

CSVLogger

keras.callbacks.CSVLogger(filename, separator=’,’, append=False)
将epoch的训练结果保存在csv文件中,支持所有可被转换为string的值,包括1D的可迭代数值如np.ndarray.

参数

  • fiename:保存的csv文件名,如run/log.csv
  • separator:字符串,csv分隔符
  • append:默认为False,为True时csv文件如果存在则继续写入,为False时总是覆盖csv文件

你可能感兴趣的:(深度学习)