torch.optim.lr_scheduler是PyTorch中负责调整学习率的模块,常和torch.optim.Optimizer配合使用。
optimizer模块的源码学习可参见:torch.optim.optimizer源码阅读和灵活使用
class _LRScheduler(object):
    """Base class for learning-rate schedulers that drive one optimizer.

    Subclasses provide ``get_lr()``; this class stores the optimizer, the
    per-group base learning rates and the step/epoch counters.
    """

    def __init__(self, optimizer, last_epoch=-1):
        """Attach to ``optimizer``, record base learning rates and step once.

        Args:
            optimizer: The optimizer whose ``param_groups`` are scheduled.
            last_epoch: Index of the previous epoch. ``-1`` means training
                starts from scratch; any other value means training is
                being resumed.

        Raises:
            TypeError: If ``optimizer`` is not an ``Optimizer`` instance.
            KeyError: If resuming (``last_epoch != -1``) and a param group
                has no ``'initial_lr'`` entry.
        """
        # Attach the optimizer.
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        if last_epoch == -1:
            # Fresh run: the current lr of each group becomes its initial lr
            # (unless one has already been recorded).
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            # Resumed run: every group must already carry its initial lr.
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified in param_groups[{}] when resuming an optimizer".format(i))

        # Initial learning rate of every parameter group.
        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `lr_scheduler.step()` is called after
        # `optimizer.step()`
        def with_counter(method):
            """Wrap bound ``method`` so each call bumps the owner's ``_step_count``."""
            if getattr(method, '_with_counter', False):
                # `optimizer.step()` has already been replaced, return.
                return method

            # Hold only a weak reference to the instance the method is bound
            # to. A weak reference does not increase the reference count, so
            # the wrapper stored on the instance does not create a reference
            # cycle that keeps the instance alive.
            instance_ref = weakref.ref(method.__self__)
            # Get the unbound method for the same purpose: `__func__` is the
            # underlying function, not tied to any particular instance.
            func = method.__func__
            cls = instance_ref().__class__  # class that owns the method
            del method

            @wraps(func)
            def wrapper(*args, **kwargs):
                instance = instance_ref()
                instance._step_count += 1
                # Re-bind the plain function to the live instance, then call it.
                wrapped = func.__get__(instance, cls)
                return wrapped(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True
            return wrapper

        # Add call counting to `optimizer.step` and initialise both counters.
        self.optimizer.step = with_counter(self.optimizer.step)
        self.optimizer._step_count = 0
        self._step_count = 0

        # Perform the initial learning-rate update.
        self.step()
lr_scheduler在构造函数中主要是获取optimizer并向其添加step计数功能,然后更新一次学习率。
弱引用相关:Python弱引用的使用
__func__相关:Python(类)实例方法的特殊属性
装饰器和functools.wraps相关:探究functools模块wraps装饰器的用途
step函数:
def step(self, epoch=None):
    """Advance the schedule once and write the new rates into the optimizer.

    Args:
        epoch: Optional explicit epoch index. When ``None`` the internal
            ``last_epoch`` counter is incremented by one. Passing a value
            emits ``EPOCH_DEPRECATION_WARNING`` and sets ``last_epoch``
            to that value.

    Side effects:
        Increments ``self._step_count``, updates ``self.last_epoch``, sets
        ``param_group['lr']`` for each optimizer param group and stores the
        resulting rates in ``self._last_lr``.
    """
    # The constructor already called step() once, so `_step_count == 1`
    # identifies the first call made after construction. At that point
    # `optimizer.step()` should already have run at least once.
    if self._step_count == 1:
        if not hasattr(self.optimizer.step, "_with_counter"):
            warnings.warn(
                "Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                "initialization. Please, make sure to call `optimizer.step()` before "
                "`lr_scheduler.step()`. See more details at "
                "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
                UserWarning)
        elif self.optimizer._step_count < 1:
            warnings.warn(
                "Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
                "will result in PyTorch skipping the first value of the learning rate schedule. "
                "See more details at "
                "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
                UserWarning)
    self._step_count += 1  # the scheduler's own step counter

    class _enable_get_lr_call:
        """Context manager that flags ``get_lr()`` calls made from within ``step()``."""

        def __init__(self, o):
            self.o = o

        def __enter__(self):
            self.o._get_lr_called_within_step = True
            return self

        def __exit__(self, type, value, traceback):
            self.o._get_lr_called_within_step = False

    with _enable_get_lr_call(self):
        if epoch is None:
            # No explicit epoch: move to the next one, then compute the new
            # rates (how they are computed depends on the scheduler type).
            self.last_epoch += 1
            values = self.get_lr()
        else:
            # Explicit epoch: EPOCH_DEPRECATION_WARNING tells the caller
            # that the `epoch` argument is about to be removed.
            warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
            self.last_epoch = epoch
            if hasattr(self, "_get_closed_form_lr"):
                # Use the closed-form computation when the scheduler
                # provides one.
                values = self._get_closed_form_lr()
            else:
                values = self.get_lr()

    # Write the new learning rates into the optimizer.
    for param_group, lr in zip(self.optimizer.param_groups, values):
        param_group['lr'] = lr

    # Remember the rates that were just applied.
    self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
step函数主要进行lr的实时计算以及相关参数的更新,包括epoch、lr和optimizer中保存的实时lr。
上下文管理器相关:Python中的__enter__和__exit__
学习率计算:get_lr
def get_lr(self):
    """Compute the learning rates for the current update step.

    The result depends on the concrete scheduling policy, so subclasses
    must override this method.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    # Compute learning rate using chainable form of the scheduler
    raise NotImplementedError
获得上一轮次训练的lr值:
def get_last_lr(self):
    """Return last computed learning rate by current scheduler.

    Returns:
        The object stored in ``self._last_lr`` (returned as-is, not copied).
    """
    return self._last_lr
获取lr_scheduler的相关参数:
def state_dict(self):
    """Returns the state of the scheduler as a :class:`dict`.

    It contains an entry for every variable in self.__dict__ which is not
    the optimizer.
    """
    # The scheduler keeps a reference to its optimizer in `self.optimizer`,
    # but that entry is deliberately left out of the returned state.
    return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
需要注意的是,lr_scheduler的state_dict返回的是scheduler除optimizer之外的所有属性,所以不同的scheduler返回的参数各不相同。
加载已有的lr_scheduler参数:
def load_state_dict(self, state_dict):
    """Loads the schedulers state.

    Arguments:
        state_dict (dict): scheduler state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
    # Overwrite/add the scheduler's attributes with the entries of
    # `state_dict`; attributes not present in it are left untouched.
    self.__dict__.update(state_dict)
__dict__相关:Python __dict__与dir()区别
以cosine学习率调整策略来具体学习lr_scheduler:
lr计算公式:
$$\eta_{t+1} = \eta_{min} + (\eta_{t} - \eta_{min})\,\frac{1 + \cos\left(\frac{T_{cur}+1}{T_{max}}\pi\right)}{1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)}, \quad T_{cur} \neq (2k+1)T_{max}$$
$$\eta_{t+1} = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), \quad T_{cur} = (2k+1)T_{max}$$
由于上述公式是递归式,所以lr可以在get_lr之外被修改;若lr仅在get_lr中计算,则公式可统一为:
$$\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)$$
class CosineAnnealingLR(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR
    (SGDR: Stochastic Gradient Descent with Warm Restarts).

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
        # Set the hyper-parameters before delegating to the base constructor.
        self.T_max = T_max
        self.eta_min = eta_min
        super(CosineAnnealingLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the next learning rate of every param group (recursive form).

        Each new rate is derived from the rate currently stored in the
        optimizer's param group, not from the closed-form expression.
        """
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        # In `step()`, `last_epoch` is incremented before `get_lr()` is
        # called, so `last_epoch` here is the index of the current update.
        if self.last_epoch == 0:
            # Current update index is 0: use the initial learning rates.
            return self.base_lrs
        elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            # T_{cur} = (2k+1)T_{max}:
            # `base_lr` is the initial rate, `group['lr']` the previous one.
            return [group['lr'] + (base_lr - self.eta_min) *
                    (1 - math.cos(math.pi / self.T_max)) / 2
                    for base_lr, group in
                    zip(self.base_lrs, self.optimizer.param_groups)]
        # T_{cur} \neq (2k+1)T_{max}:
        return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) /
                (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *
                (group['lr'] - self.eta_min) + self.eta_min
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        """Closed-form rates at ``last_epoch``; used by ``step()`` when an epoch is given."""
        return [self.eta_min + (base_lr - self.eta_min) *
                (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2
                for base_lr in self.base_lrs]