摘要:文献【1】中除了权重衰减还利用了余弦退火(Cosine Annealing)以及Warm Restart,本文介绍这两种方法的原理及numpy和Keras的实现方法,其中Keras实现中继承回调函数Callbacks。
【1】DECOUPLED WEIGHT DECAY REGULARIZATION
论文中对学习率规划原理描述如下,公式(5)表明学习率随迭代次数的变化规律。
变量名称解释如下:
各上标、下标 i i i为run的序号,意思是第 i i i次的restart。
η m i n i \eta^i_{min} ηmini以及 η m a x i \eta^i_{max} ηmaxi为学习率的范围。
T c u r T_{cur} Tcur虽然写着是计算epoch的,但后面一句又说是随着iteration变化的。
T i T_i Ti是当前run总共的epoch数目。
对 T i T_i Ti和 T c u r T_{cur} Tcur,文中显示 T c u r T_{cur} Tcur可以是离散的小数,这里进行如下修改便于实现:
T c u r T i = T c u r × n b a t c h e s T i × n b a t c h e s = i t e r a t i o n T o t a l I t e r a t i o n s \frac{T_{cur}}{T_i}=\frac{T_{cur} \times n_{batches}}{T_i \times n_{batches}}=\frac{iteration}{TotalIterations} TiTcur=Ti×nbatchesTcur×nbatches=TotalIterationsiteration
这样就变成了当前的iteration数目的计量。
具体实现时,在训练过程中,轮到相应的epoch时重新设置优化器的 T o t a l I t e r a t i o n s TotalIterations TotalIterations并初始化 T c u r T_{cur} Tcur。
下面先用numpy简单可视化一下这样操作后学习率的变化。实际使用中一般继承Tensorflow(Keras)或者Pytorch等框架自带的学习率规划类。
利用Numpy可视化余弦退火和Warm Restart之后的学习率变化情况。
import numpy as np
import matplotlib.pyplot as plt
def compute_eta_t(eta_min, eta_max, T_cur, Ti):
'''Equation (5).
# Arguments
eta_min,eta_max,T_cur,Ti are same as equation.
# Returns
eta_t
'''
pi = np.pi
eta_t = eta_min + 0.5 * (eta_max - eta_min) * (np.cos(pi * T_cur / Ti) + 1)
return eta_t
# 每Ti个epoch进行一次restart。
Ti = [20, 40, 80, 160]
n_batches = 200
eta_ts = []
for ti in Ti:
T_cur = np.arange(0, ti, 1 / n_batches)
for t_cur in T_cur:
eta_ts.append(compute_eta_t(0, 1, t_cur, ti))
n_iterations = sum(Ti) * n_batches
epoch = np.arange(0, n_iterations) / n_batches
plt.plot(epoch, eta_ts)
下面进行余弦退火和warm restart的框架实现。
使用Keras框架,继承Callback实现余弦退火,warm restart可以手动循环实现。。。
首先定义余弦退火类,在每个batch结束后计算eta。
from keras import *
class CosineAnnealing(callbacks.Callback):
"""Cosine annealing according to DECOUPLED WEIGHT DECAY REGULARIZATION.
# Arguments
eta_max: float, eta_max in eq(5).
eta_min: float, eta_min in eq(5).
total_iteration: int, Ti in eq(5).
iteration: int, T_cur in eq(5).
verbose: 0 or 1.
"""
def __init__(self, eta_max=1, eta_min=0, total_iteration=0, iteration=0, verbose=0, **kwargs):
super(CosineAnnealing, self).__init__()
global lr_list
lr_list = []
self.eta_max = eta_max
self.eta_min = eta_min
self.verbose = verbose
self.total_iteration = total_iteration
self.iteration = iteration
def on_train_begin(self, logs=None):
self.lr = K.get_value(self.model.optimizer.lr)
def on_train_end(self, logs=None):
K.set_value(self.model.optimizer.lr, self.lr)
def on_batch_end(self, epoch, logs=None):
self.iteration += 1
logs = logs or {}
logs['lr'] = K.get_value(self.model.optimizer.lr)
eta_t = self.eta_min + (self.eta_max - self.eta_min) * 0.5 * (1 + np.cos(np.pi * self.iteration / self.total_iteration))
new_lr = self.lr * eta_t
K.set_value(self.model.optimizer.lr, new_lr)
if self.verbose > 0:
print('\nEpoch %05d: CosineAnnealing '
'learning rate to %s.' % (epoch + 1, new_lr))
lr_list.append(logs['lr'])
下面是数据及模型的创建。
import keras
from keras import layers
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import keras.backend as K
# 准备数据集
num_train, num_test = 2000, 100
num_features = 200
true_w, true_b = np.ones((num_features, 1)) * 0.01, 0.05
features = np.random.normal(0, 1, (num_train + num_test, num_features))
noises = np.random.normal(0, 1, (num_train + num_test, 1)) * 0.01
labels = np.dot(features, true_w) + true_b + noises
train_data, test_data = features[:num_train, :], features[num_train:, :]
train_labels, test_labels = labels[:num_train], labels[num_train:]
# 选择模型
model = keras.models.Sequential([
layers.Dense(units=128, activation='relu', input_dim=200),
layers.Dense(128, activation='relu', kernel_regularizer=keras.regularizers.l2(0.00)),
layers.Dense(1)
])
model.summary()
model.compile(optimizer='adam',
loss='mse',
metrics=['mse'])
最后是训练过程,在这里加上前面定义的回调函数,并手动实现WarmRestart。
epochs = [2, 4, 8, 16, 32]
lr_lists = []
for Ti in epochs:
reduce_lr = CosineAnnealing(eta_max=1, eta_min=0, total_iteration=Ti * (2000 // 16), iteration=0, verbose=0)
hist1 = model.fit(train_data, train_labels, batch_size=16, epochs=Ti, validation_data=[test_data, test_labels], callbacks=[reduce_lr])
lr_lists += lr_list
plt.plot(lr_lists)