LR_scheduler及warmup底层原理和代码分析

LR_scheduler

LR_scheduler是用于调节学习率lr的,在代码中,我们经常看到这样的一行代码

 scheduler.step()

通过这行代码来实现lr的更新的,那么其中的底层原理是什么呢?我们就进去看看


在pytorch代码中,各种类型scheduler大多基于_LRScheduler

我们就看看这个类的step()函数到底干了什么

     def step(self, epoch=None):
         # Raise a warning if old pattern is detected
         # https://github.com/pytorch/pytorch/issues/20124
         if self._step_count == 1:
             if not hasattr(self.optimizer.step, "_with_counter"):
                 warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                               "initialization. Please, make sure to call `optimizer.step()` before "
                               "`lr_scheduler.step()`. See more details at "
                               "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
 ​
             # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
             elif self.optimizer._step_count < 1:
                 warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                               "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                               "`optimizer.step()` before `lr_scheduler.step()`.  Failure to do this "
                               "will result in PyTorch skipping the first value of the learning rate schedule. "
                               "See more details at "
                               "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
         self._step_count += 1
 ​
         class _enable_get_lr_call:
 ​
             def __init__(self, o):
                 self.o = o
 ​
             def __enter__(self):
                 self.o._get_lr_called_within_step = True
                 return self
 ​
             def __exit__(self, type, value, traceback):
                 self.o._get_lr_called_within_step = False
 ​
         with _enable_get_lr_call(self):
             if epoch is None:
                 self.last_epoch += 1  # 表示上一个epoch
                 values = self.get_lr()  # 计算学习率lr
             else:
                 warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
                 self.last_epoch = epoch  # 直接跳转到参数epoch
                 if hasattr(self, "_get_closed_form_lr"):
                     values = self._get_closed_form_lr()
                 else:
                     values = self.get_lr()
 ​
          # 对所有参数权重对应的lr进行修改
         for i, data in enumerate(zip(self.optimizer.param_groups, values)):
             param_group, lr = data
             param_group['lr'] = lr # 修改学习率
             self.print_lr(self.verbose, i, lr, epoch)
 ​
         self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
 ​

由上代码可知,step()的目的是计算计算新的学习率并对旧学习率进行修改,其中最重要的函数是get_lr(),我们接下来对这个函数进行分析

     def get_lr(self):
         # Compute learning rate using chainable form of the scheduler
         raise NotImplementedError

由于_LRScheduler类是一个基类,不表示任何学习率策略,我们选择最简单的StepLR学习策略(学习率阶梯式下降)来分析

     def get_lr(self):
         if not self._get_lr_called_within_step:
             warnings.warn("To get the last learning rate computed by the scheduler, "
                           "please use `get_last_lr()`.", UserWarning)
 ​
         if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): # 表示在一个阶梯上,不改变学习率
             return [group['lr'] for group in self.optimizer.param_groups]
         return [group['lr'] * self.gamma # 对所有学习率乘以一个小于1的小数,减小学习率
                 for group in self.optimizer.param_groups]
 ​

如果step()函数中有epoch参数,需要直接跳转到指定epoch,那么直接乘以固定的小数就不对了,这时候就需要函数_get_closed_form_lr()

     def _get_closed_form_lr(self):
         return [base_lr * self.gamma ** (self.last_epoch // self.step_size)  
                 for base_lr in self.base_lrs]

其中self.last_epoch之前在基类_LRScheduler中已经被赋值了self.last_epoch = epoch ,所以直接根据学习率变化公式计算处理

由上可知,get_lr()和_get_closed_form_lr()就是具体的学习率计算方法

这样,我们就可以根据不同的学习率计算方式设计自己的scheduler类了。


warmup

初始训练阶段,直接使用较大学习率会导致权重变化较大,出现振荡现象,使得模型不稳定,加大训练难度。而使用Warmup预热学习率,在开始的几个epoch,逐步增大学习率,如下图所示,使得模型逐渐趋于稳定,等模型相对稳定后再选择预先设置的基础学习率进行训练,使得模型收敛速度变得更快,模型效果更佳

warmup step

上图中的0-10epoch阶段就是一个warmup操作,学习率缓慢增加,10之后就是常规的学习率递减算法

原理上很简单,接下来从代码上进行分析,warmup可以有两种构成方式:

  1. 对已有的scheduler类进行包装重构

  2. 直接编写新的类

对于第一种情况,我们以CosineAnnealingLR类为例

         scheduler = CosineAnnealingLR( # pytorch自带的类
             optimizer=optimizer,
             eta_min=0.000001,
             T_max=(epochs - warmup_epoch) * n_iter_per_epoch)  
             
         scheduler = GradualWarmupScheduler( # 重构的类
             optimizer,
             multiplier=args.warmup_multiplier,
             after_scheduler=scheduler,
             warmup_epoch=warmup_epoch * n_iter_per_epoch)
 ​

其中,GradualWarmupScheduler就是基于CosineAnnealingLR重构的类,我们首先查看类中step()函数

     def step(self, epoch=None):
         if epoch is None:
             epoch = self.last_epoch + 1
         self.last_epoch = epoch
         if epoch > self.warmup_epoch: # 超过warmup范围,使用自带的类,也就是CosineAnnealingLR
             self.after_scheduler.step(epoch - self.warmup_epoch) # 注意CosineAnnealingLR要从0epoch开始,所以需要减去
         else:
             super(GradualWarmupScheduler, self).step(epoch)  # warmup范围,使用当前重构类的()
 ​

对于超过warmup范围,直接使用CosineAnnealingLR类,比较简单

对于warmup范围类,使用当前重构类的step()函数,因为也是继承于_LRScheduler类,所以step()同样是运用到get_lr()

     def get_lr(self):
         if self.last_epoch > self.warmup_epoch:  # 超过warmup范围,使用CosineAnnealingLR类的get_lr()
             return self.after_scheduler.get_lr()
         else:  # warmup范围,编写线性变化,也就是上图中0-10区间内的直线
             return [base_lr / self.multiplier * ((self.multiplier - 1.) * self.last_epoch / self.warmup_epoch + 1.)
                     for base_lr in self.base_lrs]
 ​

对于第二种情况,step()无需构造,直接继承_LRScheduler,需要构造的是get_lr()函数,其中warmup范围外的代码与自带的CosineAnnealingLR类中get_lr()代码一样。

 

你可能感兴趣的:(pytorch,深度学习,人工智能)