【学习率调整】学习率衰减之周期余弦退火 (cyclic cosine annealing learning rate schedule)

1. 概述

       在论文《SGDR: Stochastic Gradient Descent with Warm Restarts》中主要介绍了带重启的随机梯度下降算法(SGDR),其中就引入了余弦退火的学习率下降方式。
       当我们使用梯度下降算法来优化目标函数的时候,当越来越接近Loss值的全局最小值时,学习率应该变得更小来使得模型尽可能接近这一点,而余弦退火(cosine annealing)可以通过余弦函数来降低学习率。余弦函数中随着x的增加余弦值首先缓慢下降,然后加速下降,再次缓慢下降。这种下降模式能和学习率配合,以一种十分有效的计算方式来产生很好的效果。
【学习率调整】学习率衰减之周期余弦退火 (cyclic cosine annealing learning rate schedule)_第1张图片

       另外,因为我们的目标优化函数可能是多峰的,除了全局最优解之外还有多个局部最优解。在训练时梯度下降算法可能陷入局部最小值,此时可以通过突然提高学习率,来“跳出”局部最小值并找到通向全局最小值的路径。这种方式称为带重启的随机梯度下降方法。
【学习率调整】学习率衰减之周期余弦退火 (cyclic cosine annealing learning rate schedule)_第2张图片

2. 原理

       论文介绍最简单的热重启的方法。当执行完T_i个epoch之后就会开始热重启(warm restart),而下标i就是指第几次的restart,其中重启并不是重头开始,而是通过增加学习率来模拟,并且重启之后使得旧的x_i作为初始解,这里的x_i就是通过梯度下降求解loss函数的解,也就是神经网络中的权重,因为重启就是为了通过增大学习率来跳过局部最优,所以需要将x_i置为旧值。
       其实现公示如下:

在这里插入图片描述
       表达式中字符的含义:
● i就是第几次run(索引值)。
● η_maxi和η_mini分别表示学习率的最大值和最小值,定义了学习率的范围。论文中提到在每次restart之后,减小η_maxi和η_mini的值,但是为了简单,论文中也保持η_maxi和η_mini在每次restart之后仍然保持不变。
● T_cur则表示当前执行了多少个epoch,但是T_cur是在每个batch运行之后就会更新,而此时一个epoch还没有执行完,所以T_cur的值可以为小数。例如总样本为80,每个batch的大小是16,那么在一个epoch中就会循环5次读入batch,那么在第一个epoch中执行完第一个batch后,T_cur的值就更新为1/5=0.2,以此类推。
● T_i表示第i次run中总的epoch数。当涉及到重启时,论文中提到为了提高性能,开始会初始化一个比较小的T_i,在每次restart后,T_i会以乘以一个T_mult的方式增加。

3. Pytorch代码

import torch
import math
from torch.optim.lr_scheduler import _LRScheduler

class CosineAnnealingLR_with_Restart(_LRScheduler):
    """Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::

        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
        \cos(\frac{T_{cur}}{T_{max}}\pi))

    When last_epoch=-1, sets initial lr as lr.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. The original pytorch
    implementation only implements the cosine annealing part of SGDR,
    I added my own implementation of the restarts part.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        T_mult (float): Increase T_max by a factor of T_mult
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        model (pytorch model): The model to save.
        out_dir (str): Directory to save snapshots
        take_snapshot (bool): Whether to save snapshots at every restart

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_max, T_mult, model, out_dir, take_snapshot, eta_min=0, last_epoch=-1):
        self.T_max = T_max
        self.T_mult = T_mult
        self.Te = self.T_max
        self.eta_min = eta_min
        self.current_epoch = last_epoch

        self.model = model
        self.out_dir = out_dir
        self.take_snapshot = take_snapshot

        self.lr_history = []

        super(CosineAnnealingLR_with_Restart, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        new_lrs = [self.eta_min + (base_lr - self.eta_min) *
                   (1 + math.cos(math.pi * self.current_epoch / self.Te)) / 2
                   for base_lr in self.base_lrs]

        self.lr_history.append(new_lrs)
        return new_lrs

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        self.current_epoch += 1

        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

        ## restart
        if self.current_epoch == self.Te:
            print("restart at epoch {:03d}".format(self.last_epoch + 1))

            if self.take_snapshot:
                torch.save({
                    'epoch': self.T_max,
                    'state_dict': self.model.state_dict()
                }, self.out_dir + "Weight/" + 'snapshot_e_{:03d}.pth.tar'.format(self.T_max))

            ## reset epochs since the last reset
            self.current_epoch = 0

            ## reset the next goal
            self.Te = int(self.Te * self.T_mult)
            self.T_max = self.T_max + self.Te

你可能感兴趣的:(算法,神经网络,深度学习,pytorch,python)