LR_scheduler是用于调节学习率lr的,在代码中,我们经常看到这样的一行代码
scheduler.step()
通过这行代码来实现lr的更新的,那么其中的底层原理是什么呢?我们就进去看看
在pytorch代码中,各种类型scheduler大多基于_LRScheduler类
我们就看看这个类的step()函数到底干了什么
def step(self, epoch=None): # Raise a warning if old pattern is detected # https://github.com/pytorch/pytorch/issues/20124 if self._step_count == 1: if not hasattr(self.optimizer.step, "_with_counter"): warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler " "initialization. Please, make sure to call `optimizer.step()` before " "`lr_scheduler.step()`. See more details at " "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) # Just check if there were two first lr_scheduler.step() calls before optimizer.step() elif self.optimizer._step_count < 1: warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. " "In PyTorch 1.1.0 and later, you should call them in the opposite order: " "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this " "will result in PyTorch skipping the first value of the learning rate schedule. " "See more details at " "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) self._step_count += 1 class _enable_get_lr_call: def __init__(self, o): self.o = o def __enter__(self): self.o._get_lr_called_within_step = True return self def __exit__(self, type, value, traceback): self.o._get_lr_called_within_step = False with _enable_get_lr_call(self): if epoch is None: self.last_epoch += 1 # 表示上一个epoch values = self.get_lr() # 计算学习率lr else: warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) self.last_epoch = epoch # 直接跳转到参数epoch if hasattr(self, "_get_closed_form_lr"): values = self._get_closed_form_lr() else: values = self.get_lr() # 对所有参数权重对应的lr进行修改 for i, data in enumerate(zip(self.optimizer.param_groups, values)): param_group, lr = data param_group['lr'] = lr # 修改学习率 self.print_lr(self.verbose, i, lr, epoch) self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
由上代码可知,step()的目的是计算计算新的学习率并对旧学习率进行修改,其中最重要的函数是get_lr(),我们接下来对这个函数进行分析
def get_lr(self): # Compute learning rate using chainable form of the scheduler raise NotImplementedError
由于_LRScheduler类是一个基类,不表示任何学习率策略,我们选择最简单的StepLR学习策略(学习率阶梯式下降)来分析
def get_lr(self): if not self._get_lr_called_within_step: warnings.warn("To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning) if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): # 表示在一个阶梯上,不改变学习率 return [group['lr'] for group in self.optimizer.param_groups] return [group['lr'] * self.gamma # 对所有学习率乘以一个小于1的小数,减小学习率 for group in self.optimizer.param_groups]
如果step()函数中有epoch参数,需要直接跳转到指定epoch,那么直接乘以固定的小数就不对了,这时候就需要函数_get_closed_form_lr()
def _get_closed_form_lr(self): return [base_lr * self.gamma ** (self.last_epoch // self.step_size) for base_lr in self.base_lrs]
其中self.last_epoch之前在基类_LRScheduler中已经被赋值了self.last_epoch = epoch ,所以直接根据学习率变化公式计算处理
由上可知,get_lr()和_get_closed_form_lr()就是具体的学习率计算方法
这样,我们就可以根据不同的学习率计算方式设计自己的scheduler类了。
初始训练阶段,直接使用较大学习率会导致权重变化较大,出现振荡现象,使得模型不稳定,加大训练难度。而使用Warmup预热学习率,在开始的几个epoch,逐步增大学习率,如下图所示,使得模型逐渐趋于稳定,等模型相对稳定后再选择预先设置的基础学习率进行训练,使得模型收敛速度变得更快,模型效果更佳
上图中的0-10epoch阶段就是一个warmup操作,学习率缓慢增加,10之后就是常规的学习率递减算法
原理上很简单,接下来从代码上进行分析,warmup可以有两种构成方式:
对已有的scheduler类进行包装重构
直接编写新的类
对于第一种情况,我们以CosineAnnealingLR类为例
scheduler = CosineAnnealingLR( # pytorch自带的类 optimizer=optimizer, eta_min=0.000001, T_max=(epochs - warmup_epoch) * n_iter_per_epoch) scheduler = GradualWarmupScheduler( # 重构的类 optimizer, multiplier=args.warmup_multiplier, after_scheduler=scheduler, warmup_epoch=warmup_epoch * n_iter_per_epoch)
其中,GradualWarmupScheduler就是基于CosineAnnealingLR重构的类,我们首先查看类中step()函数
def step(self, epoch=None): if epoch is None: epoch = self.last_epoch + 1 self.last_epoch = epoch if epoch > self.warmup_epoch: # 超过warmup范围,使用自带的类,也就是CosineAnnealingLR self.after_scheduler.step(epoch - self.warmup_epoch) # 注意CosineAnnealingLR要从0epoch开始,所以需要减去 else: super(GradualWarmupScheduler, self).step(epoch) # warmup范围,使用当前重构类的()
对于超过warmup范围,直接使用CosineAnnealingLR类,比较简单
对于warmup范围类,使用当前重构类的step()函数,因为也是继承于_LRScheduler类,所以step()同样是运用到get_lr()
def get_lr(self): if self.last_epoch > self.warmup_epoch: # 超过warmup范围,使用CosineAnnealingLR类的get_lr() return self.after_scheduler.get_lr() else: # warmup范围,编写线性变化,也就是上图中0-10区间内的直线 return [base_lr / self.multiplier * ((self.multiplier - 1.) * self.last_epoch / self.warmup_epoch + 1.) for base_lr in self.base_lrs]
对于第二种情况,step()无需构造,直接继承_LRScheduler,需要构造的是get_lr()函数,其中warmup范围外的代码与自带的CosineAnnealingLR类中get_lr()代码一样。