回调函数是一组在训练的特定阶段被调用的函数集,你可以使用回调函数来观察训练过程中网络内部的状态和统计信息。通过传递回调函数列表到模型的.fit()中,即可在给定的训练阶段调用该函数集中的函数。
虽然我们称之为回调“函数”,但事实上Keras的回调函数是一个类
keras.callbacks.Callback()
是回调函数的抽象类,定义新的回调函数必须继承自该类
类属性
params:字典,训练参数集(如信息显示方法verbosity,batch大小,epoch数)
model:keras.models.Model对象,为正在训练的模型的引用
回调函数以字典logs为参数,该字典包含了一系列与当前batch或epoch相关的信息。
目前,模型的.fit()中有下列参数会被记录到logs中:
在每个epoch的结尾处(on_epoch_end),logs将包含训练的正确率和误差,acc和loss,如果指定了验证集,还会包含验证集正确率和误差val_acc)和val_loss,val_acc还额外需要在.compile中启用metrics=[‘accuracy’]。
在每个batch的开始处(on_batch_begin):logs包含size,即当前batch的样本数
在每个batch的结尾处(on_batch_end):logs包含loss,若启用accuracy则还包含acc
from keras.callbacks import Callback
最开始训练过程是先训练一遍,然后得到一个验证集的正确率变化趋势,从而知道最佳的epoch,设置最佳epoch,再训练一遍得到最终结果,十分浪费时间!!!
节省时间的一个办法是在验证集准确率不再上升的时候,终止训练。keras中的callback可以帮我们做到这一点。
callback是一个obj类型的,他可以让模型去拟合,也常在各个点被调用。它存储模型的状态,能够采取措施打断训练,保存模型,加载不同的权重,或者替代模型状态。
callbacks可以用来做这些事情:
monitor为选择的检测指标,这里选择检测’acc’识别率为指标,
filepath为我们存储的位置和模型名称,以.h5为后缀,
import keras
# Callbacks are passed to the model fit the `callbacks` argument in `fit`,
# which takes a list of callbacks. You can pass any number of callbacks.
callbacks_list = [
# This callback will interrupt training when we have stopped improving
keras.callbacks.EarlyStopping(
# This callback will monitor the validation accuracy of the model
monitor='acc',
# Training will be interrupted when the accuracy
# has stopped improving for *more* than 1 epochs (i.e. 2 epochs)
patience=1,
),
# This callback will save the current weights after every epoch
keras.callbacks.ModelCheckpoint(
filepath='my_model.h5', # Path to the destination model file
# The two arguments below mean that we will not overwrite the
# model file unless `val_loss` has improved, which
# allows us to keep the best model every seen during training.
monitor='val_loss',
save_best_only=True,
)
]
# Since we monitor `acc`, it should be part of the metrics of the model.
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])
# Note that since the callback will be monitor validation accuracy,
# we need to pass some `validation_data` to our call to `fit`.
model.fit(x, y,
epochs=10,
batch_size=32,
callbacks=callbacks_list,
validation_data=(x_val, y_val))
callbacks_list = [
keras.callbacks.ReduceLROnPlateau(
# This callback will monitor the validation loss of the model
monitor='val_loss',
# It will divide the learning by 10 when it gets triggered
factor=0.1,
# It will get triggered after the validation loss has stopped improving
# for at least 10 epochs
patience=10,
)
]# Note that since the callback will be monitor validation loss,
# we need to pass some `validation_data` to our call to `fit`.
model.fit(x, y,
epochs=30,
batch_size=32,
callbacks=callbacks_list,
validation_data=(x_val, y_val))
keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0, write_graph=True, write_images=False, embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None)
TensorBoard是TensorFlow提供的可视化工具,该回调函数将日志信息写入TensorBorad,使得你可以动态的观察训练和测试指标的图像以及不同层的激活值直方图。
如果已经通过pip安装了TensorFlow,我们可通过下面的命令启动TensorBoard:
tensorboard --logdir=/full_path_to_your_logs
keras.callbacks.LambdaCallback(on_epoch_begin=None, on_epoch_end=None, on_batch_begin=None, on_batch_end=None, on_train_begin=None, on_train_end=None)
用于创建简单的callback的callback类
该callback的匿名函数将会在适当的时候调用,注意,该回调函数假定了一些位置参数:
参数
# Print the batch number at the beginning of every batch.
batch_print_callback = LambdaCallback(
on_batch_begin=lambda batch,logs: print(batch))
# Plot the loss after every epoch.
import numpy as np
import matplotlib.pyplot as plt
plot_loss_callback = LambdaCallback(
on_epoch_end=lambda epoch, logs: plt.plot(np.arange(epoch),
logs['loss']))
# Terminate some processes after having finished model training.
processes = ...
cleanup_callback = LambdaCallback(
on_train_end=lambda logs: [
p.terminate() for p in processes if p.is_alive()])
model.fit(...,
callbacks=[batch_print_callback,
plot_loss_callback,
cleanup_callback])
如果内置的callback操作还满足不了需求,可以通过继承keras.callbacks.Callback编写自己的回调函数,回调函数通过类成员self.model访问模型,该成员是模型的一个引用。
简单的保存每个batch的loss的回调函数
from keras.callbacks import Callback
class LossHistory(Callback):
def on_train_begin(self, logs={}):
self.losses = []
def on_batch_end(self, batch, logs={}):
self.losses.append(logs.get('loss))
将激活值以数组的形式存进磁盘
class ActivationLogger(keras.callbacks.Callback):
def set_model(self, model):
# This method is called by the parent model before training,
# to inform the callback of what model will be calling it
self.model = model
layer_outputs = [layer.output for layer in model.layers]
# This is a model instance that returns the activations of every layer
self.activations_model = keras.models.Model(model.input, layer_outputs)
def on_epoch_end(self, epoch, logs=None):
if self.validation_data is None:
raise RuntimeError('Requires validation_data.')
# Obtain first input sample of the validation data
validation_sample = self.validation_data[0][0:1]
activations = self.activations_model.predict(validation_sample)
# Save arrays to disk
f = open('activations_at_epoch_' + str(epoch) + '.npz', 'w')
np.savez(f, activations)
f.close()
参考文档:
https://keras-cn.readthedocs.io/en/latest/other/callbacks/