【Date】2019.12.06
【Title】How to use a Keras Callback to evaluate the validation set once before training
In Keras, the keras.callbacks.ModelCheckpoint() callback can be configured to save the best weights, for example:
```python
checkpoint = ModelCheckpoint(filepath=best_filepath, monitor="val_acc", verbose=1,
                             save_best_only=True, mode="max", save_weights_only=True)
```
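However, ModelCheckpoint initializes its best value to -np.inf when mode="max" (for example, when monitoring val_acc). As a result, the first epoch's result always counts as an improvement. If you load existing weights and continue training, the first epoch will always save, and it can overwrite a better weight file. That defeats the purpose. One fix is to evaluate the validation set once before training starts, but Keras does not seem to provide a built-in function for this.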
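The problem can be reproduced without Keras. The class below is a simplified, hypothetical stand-in for the comparison ModelCheckpoint makes with save_best_only=True and mode="max": it compares each new value against self.best, which starts at -inf, and saves whenever the value improves. SimpleBestSaver and the accuracy values are illustrative only. This is not the real Keras class.

```python
import math


class SimpleBestSaver:
    """Hypothetical, simplified save_best_only logic for mode="max"."""

    def __init__(self):
        # ModelCheckpoint starts from -inf when mode="max"
        self.best = -math.inf
        self.saved_epochs = []

    def on_epoch_end(self, epoch, val_acc):
        # Save only when the monitored value beats the best seen so far
        if val_acc > self.best:
            self.best = val_acc
            self.saved_epochs.append(epoch)


# Suppose the loaded weights already reach val_acc = 0.80,
# but the first resumed epochs only reach 0.75 and 0.78.
saver = SimpleBestSaver()
for epoch, val_acc in enumerate([0.75, 0.78, 0.82]):
    saver.on_epoch_end(epoch, val_acc)

print(saver.saved_epochs)  # [0, 1, 2]
```

Epochs 0 and 1 both overwrite a weight file that scored 0.80 with worse weights.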
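Reading the source of keras.callbacks.ModelCheckpoint() shows that it decides whether to save by comparing the monitored value with checkpoint.best. The fix can therefore be implemented as follows:
1. Use a custom Callback to evaluate the validation set once before training starts and obtain the monitored metric, such as val_acc.
2. Assign that value to checkpoint.best.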
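The two steps can be sketched with a simplified, hypothetical saver that mimics ModelCheckpoint's mode="max" comparison. The value 0.80 is a made-up number standing in for what model.evaluate() would return for the loaded weights. Once best is seeded with that value, only genuine improvements trigger a save.

```python
import math


class SimpleBestSaver:
    """Hypothetical, simplified save_best_only logic for mode="max"."""

    def __init__(self):
        self.best = -math.inf
        self.saved_epochs = []

    def on_epoch_end(self, epoch, val_acc):
        # Save only when the monitored value beats the current best
        if val_acc > self.best:
            self.best = val_acc
            self.saved_epochs.append(epoch)


saver = SimpleBestSaver()
# Step 1: evaluate the loaded weights before training (made-up result)
initial_val_acc = 0.80
# Step 2: seed the saver's best value with that result
saver.best = initial_val_acc

for epoch, val_acc in enumerate([0.75, 0.78, 0.82]):
    saver.on_epoch_end(epoch, val_acc)

print(saver.saved_epochs)  # [2]
```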
A custom callback that evaluates the validation set before training and assigns the monitored metric to checkpoint.best
Custom callback code:
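```python
import keras


## EvaluateBeforeTrain
class EvaluateBeforeTrain(keras.callbacks.Callback):
    """Evaluate the validation set before the first epoch and use the result
    as the starting value of checkpoint.best."""

    def __init__(self, checkpoint):
        # pass this class (not Callback) to super() so Callback.__init__ runs
        super(EvaluateBeforeTrain, self).__init__()
        self.checkpoint = checkpoint  # the ModelCheckpoint instance to seed

    def on_epoch_begin(self, epoch, logs=None):
        # evaluate the validation data only before the first epoch
        if epoch == 0:
            X = self.validation_data[0]
            Y = self.validation_data[1]
            result = self.model.evaluate(X, Y)  # [loss, acc] with metrics=["accuracy"]
            loss = result[0]
            acc = result[1]
            self.checkpoint.best = acc  # matches monitor="val_acc", mode="max"
            print('first_val_acc is:', acc)
```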
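Using result[1] assumes that accuracy is the first metric after the loss. A somewhat more robust option is to look up the monitored name in model.metrics_names, which lists the names of the values that model.evaluate() returns. This also covers monitor="val_loss" with mode="min". The sketch below implements that lookup. The function monitored_value is a hypothetical helper, not part of Keras, and the numbers are made up.

```python
def monitored_value(metrics_names, results, monitor):
    """Hypothetical helper (not a Keras API): pick the value that matches a
    ModelCheckpoint monitor string out of model.evaluate() results.

    metrics_names: e.g. model.metrics_names, such as ['loss', 'acc']
    results: what model.evaluate() returned (a scalar if there are no metrics)
    monitor: e.g. 'val_acc' or 'val_loss'
    """
    # evaluate() returns a bare scalar when the model has no extra metrics
    if not isinstance(results, (list, tuple)):
        results = [results]
    # names from evaluate() lack the 'val_' prefix used in fit() logs
    name = monitor[len('val_'):] if monitor.startswith('val_') else monitor
    if name not in metrics_names:
        raise KeyError('%r is not in %r' % (name, list(metrics_names)))
    return results[list(metrics_names).index(name)]


# usage with made-up numbers
print(monitored_value(['loss', 'acc'], [1.25, 0.61], 'val_acc'))   # 0.61
print(monitored_value(['loss', 'acc'], [1.25, 0.61], 'val_loss'))  # 1.25
```

Inside the callback, the assignment would then become `self.checkpoint.best = monitored_value(self.model.metrics_names, result, self.checkpoint.monitor)`.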
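Analysis:
1. Inside a custom Callback, self.validation_data holds the validation_data passed to fit(), and self.model holds the model being trained.
2. on_epoch_begin() runs at the start of every epoch. The epoch == 0 check limits the evaluation to a single run before the first epoch.
PS: on_train_begin() is not used because self.validation_data is not available inside it (I don't know why). A likely explanation is that in the Keras 2.x training loop, fit() attaches validation_data to the callbacks only after on_train_begin() has been called.
3. The constructor argument checkpoint (stored as self.checkpoint) is the ModelCheckpoint instance. It is passed in so that checkpoint.best can be overwritten.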
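There is one caveat with the epoch == 0 check. When training is resumed with fit(..., initial_epoch=k), Keras numbers the epochs starting from k, so epoch 0 never occurs and the evaluation is skipped. The pure-Python simulation below compares that check with a run-once flag. simulate_fit, EpochZeroCheck and RunOnceFlag are hypothetical stand-ins, not Keras code.

```python
def simulate_fit(callback, initial_epoch, epochs):
    # fit() iterates epochs from initial_epoch up to epochs - 1
    for epoch in range(initial_epoch, epochs):
        callback.on_epoch_begin(epoch)


class EpochZeroCheck:
    """Mirrors the `if epoch == 0:` check."""

    def __init__(self):
        self.evaluations = 0

    def on_epoch_begin(self, epoch, logs=None):
        if epoch == 0:
            self.evaluations += 1


class RunOnceFlag:
    """Evaluates on the first epoch of this fit() call, whatever its index."""

    def __init__(self):
        self.evaluations = 0
        self.done = False

    def on_epoch_begin(self, epoch, logs=None):
        if not self.done:
            self.evaluations += 1
            self.done = True


a, b = EpochZeroCheck(), RunOnceFlag()
simulate_fit(a, initial_epoch=5, epochs=10)
simulate_fit(b, initial_epoch=5, epochs=10)
print(a.evaluations, b.evaluations)  # 0 1
```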
Complete example (CIFAR-10 with a fully connected network):
```python
import os

from sklearn.preprocessing import LabelBinarizer
from keras.optimizers import SGD
from keras.datasets import cifar10
from keras.layers import Dense
from keras.models import Sequential
from keras.callbacks import ModelCheckpoint, Callback

# EvaluateBeforeTrain is the custom callback defined above.

best_filepath = 'best_weights.h5'     # hypothetical path (left empty in the original)
final_save_path = 'final_weights.h5'  # hypothetical path (not defined in the original)

# load CIFAR-10 and scale pixel values to [0, 1]
((trainX, trainY), (testX, testY)) = cifar10.load_data()
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0
# flatten each 32x32x3 image into a 3072-dimensional vector
trainX = trainX.reshape((trainX.shape[0], 3072))
testX = testX.reshape((testX.shape[0], 3072))

# convert the labels from integers to one-hot vectors
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)  # reuse the encoding fitted on the training labels

## DNN
model = Sequential()
model.add(Dense(1024, input_shape=(3072,), activation="relu"))
model.add(Dense(512, activation="relu"))
model.add(Dense(10, activation="softmax"))

# resume from previously saved best weights, if any
if os.path.exists(best_filepath):
    model.load_weights(best_filepath)
    print("have load weight")

# Keras >= 2.3 logs metrics=["accuracy"] as "val_accuracy" instead of "val_acc"
checkpoint = ModelCheckpoint(filepath=best_filepath, monitor="val_acc", verbose=1,
                             save_best_only=True, mode="max", save_weights_only=True)
evaluateBeforeTrain = EvaluateBeforeTrain(checkpoint)
callbacks_list = [checkpoint, evaluateBeforeTrain]

opt = SGD(lr=0.01)  # assumption: the original did not show how opt was defined
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
H = model.fit(trainX, trainY, validation_data=(testX, testY), epochs=100, batch_size=64,
              callbacks=callbacks_list, verbose=1)
model.save_weights(final_save_path)
```
Run results: