自定义lr代码

自定义学习率代码可以加快训练的速度,下面是代码的内容和学习率的截图。

import matplotlib.pyplot as plt
def lrfn(epoch):
    LR_START = 0.00001
    LR_MAX = 0.00005 *8#* strategy.num_replicas_in_sync
    LR_MIN = 0.00001
    LR_RAMPUP_EPOCHS = 5
    LR_SUSTAIN_EPOCHS = 5
    LR_EXP_DECAY = .95
    
    if epoch < LR_RAMPUP_EPOCHS:
        lr = (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * epoch + LR_START
    elif epoch < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS:
        lr = LR_MAX
    else:
        lr = (LR_MAX - LR_MIN) * LR_EXP_DECAY**(epoch - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS) + LR_MIN
    return lr
lr_callback = tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=True)
EPOCHS=50
rng = [i for i in range(EPOCHS)]
y = [lrfn(x) for x in rng]
plt.plot(rng, y)
print("Learning rate schedule: {:.3g} to {:.3g} to {:.3g}".format(y[0], max(y), y[-1]))

自定义lr代码_第1张图片

你可能感兴趣的:(keras)