目录
早停机制EarlyStopping
原理
参数
回调函数ReduceLROnPlateau
原理
参数
具体使用
导入
定义参数
训练模型
将数据分为训练集和验证集,每个epoch结束后(或每N个epoch后): 在验证集上获取测试结果,随着epoch的增加,如果在验证集上发现测试误差上升,则停止训练;将停止之后的权重作为网络的最终参数。
这种做法很符合直观感受,因为精度都不再提高了,在继续训练也是无益的,只会提高训练的时间。在训练的过程中,记录到目前为止最好的验证集精度,当连续n次Epoch(参数设置)没达到最佳精度时,则可以认为精度不再提高了。
monitor:
监控的数据接口,有’acc’,’val_acc’,’loss’,’val_loss’等等。正常情况下如果有验证集,就用’val_acc’或者’val_loss’。
min_delta:
增大或减小的阈值,只有大于这个部分才算作improvement。这个值的大小取决于monitor,也反映了你的容忍程度。如monitor是’acc’,同时其变化范围在70%-90%之间,所以对于小于0.01%的变化不关心。
patience:
能够容忍多少个epoch内都没有improvement。这个设置其实是在抖动和真正的准确率下降之间做tradeoff。如果patience设的大,那么最终得到的准确率要略低于模型可以达到的最高准确率。如果patience设的小,那么模型很可能在前期抖动,还在全图搜索的阶段就停止了,准确率一般很差。patience的大小和learning rate直接相关。在learning rate设定的情况下,前期先训练几次观察抖动的epoch number,比其稍大些设置patience。在learning rate变化的情况下,建议要略小于最大的抖动epoch number。
mode:
就’auto’, ‘min’, ‘,max’三个可能。如果知道是要上升还是下降,建议设置一下。
restore_best_weights :
如果restore_best_weights默认为False,如果是False,则保留最后一次训练时的权重参数,如果设置为True,则保存训练过程中准确率最高或者误差最时的网络权重。
verbose:
日志显示函数,verbose = 0 为不在标准输出流输出日志信息,verbose = 1 为输出进度条记录,verbose = 2 为每一个epoch输出一行记录
定义学习率之后,经过一定epoch迭代之后,模型效果不再提升,该学习率可能已经不再适应该模型。需要在训练过程中缩小学习率,进而提升模型。如何在训练过程中缩小学习率呢?我们可以使用keras中的回调函数ReduceLROnPlateau。与EarlyStopping配合使用,会非常方便。
为什么初始化一个非常小的学习率呢?因为初始的学习率过小,会需要非常多次的迭代才能使模型达到最优状态,训练缓慢。如果训练过程中不断缩小学习率,可以快速又精确的获得最优模型。
monitor:
监测的值,可以是accuracy,val_loss,val_accuracy
factor:
缩放学习率的值,学习率将以lr = lr*factor的形式被减少
patience:
当patience个epoch过去而模型性能不提升时,学习率减少的动作会被触发
mode:
‘auto’,‘min’,‘max’之一 默认‘auto’就行
epsilon:
阈值,用来确定是否进入检测值的“平原区”
cooldown:
学习率减少后,会经过cooldown个epoch才重新进行正常操作
min_lr:
学习率最小值,能缩小到的下限
EarlyStopping的patience要比ReduceLROnPlateau的patience大一些才会有效果。
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.callbacks import ReduceLROnPlateau
learning_rate_reduction = ReduceLROnPlateau(monitor='val_mse', patience=6, verbose=1, factor=0.5, min_lr=0.00001)
earlyStop = EarlyStopping(monitor='val_mse', min_delta=0, patience=30, mode='auto', verbose=1,restore_best_weights=True)
history = model.fit(train_X, train_y, epochs=150, batch_size=30, verbose=2,validation_data=(test_X, test_y),shuffle=False,callbacks=[learning_rate_reduction,earlyStop])