torch.optim.lr_scheduler源码和cosine学习率策略学习

torch.optim.lr_scheduler是PyTorch中负责调整学习率的模块,常和torch.optim.Optimizer配合使用。
optimizer模块的源码学习可参见:torch.optim.optimizer源码阅读和灵活使用

class _LRScheduler(object):
    """Base class of the learning-rate schedulers (annotated excerpt).

    Only the constructor is shown here. It binds the scheduler to an
    optimizer and records each param group's initial learning rate. It then
    instruments ``optimizer.step`` with a call counter and runs one
    ``self.step()``.
    """

    def __init__(self, optimizer, last_epoch=-1):
        """Attach the scheduler to ``optimizer``.

        Args:
            optimizer: the ``Optimizer`` whose param groups' ``'lr'`` values
                this scheduler drives.
            last_epoch: index of the last finished epoch. ``-1`` means a
                fresh run; any other value means training is being resumed.

        Raises:
            TypeError: ``optimizer`` is not an ``Optimizer`` instance.
            KeyError: resuming (``last_epoch != -1``) while some param group
                has no ``'initial_lr'`` entry.
        """
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        param_groups = optimizer.param_groups
        if last_epoch == -1:
            # Fresh run: remember the starting lr of every group
            # (an 'initial_lr' that is already present is kept as-is).
            for group in param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            # Resuming: every group must already carry its original lr.
            for index, group in enumerate(param_groups):
                if 'initial_lr' in group:
                    continue
                raise KeyError("param 'initial_lr' is not specified in param_groups[{}] when resuming an optimizer".format(index))

        # Starting learning rates that subclasses compute the schedule from.
        self.base_lrs = [group['initial_lr'] for group in param_groups]
        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124:
        # count optimizer.step() calls so that step() can warn when
        # lr_scheduler.step() is called before optimizer.step().
        def with_counter(method):
            """Wrap bound ``method`` so each call bumps ``_step_count`` on its owner."""
            if getattr(method, '_with_counter', False):
                # Already instrumented (e.g. by an earlier scheduler).
                return method

            # Reach the owner only through a weak reference. The wrapper is
            # stored on that same object, so a strong reference here would
            # form a reference cycle.
            owner_ref = weakref.ref(method.__self__)
            # Keep the plain function and the owner's class rather than the
            # bound method, which would also pin the owner.
            unbound = method.__func__
            owner_cls = owner_ref().__class__
            del method

            @wraps(unbound)
            def wrapper(*args, **kwargs):
                owner = owner_ref()
                owner._step_count += 1
                bound = unbound.__get__(owner, owner_cls)
                return bound(*args, **kwargs)

            # The wrapper is a plain function (no __func__/__self__), so mark
            # it explicitly; this makes re-wrapping a no-op.
            wrapper._with_counter = True
            return wrapper

        self.optimizer.step = with_counter(self.optimizer.step)
        self.optimizer._step_count = 0
        self._step_count = 0

        # Initial scheduler step.
        self.step()

lr_scheduler的构造函数主要完成以下工作:检查并保存optimizer,记录每组参数的初始学习率(initial_lr,即base_lrs),为optimizer.step添加调用计数功能,最后调用一次self.step()更新学习率。

弱引用相关:Python弱引用的使用

__func__相关:Python(类)实例方法的特殊属性

装饰器和functools.wraps相关:探究functools模块wraps装饰器的用途

step函数:

def step(self, epoch=None):
    """Advance the schedule by one step and push the new rates into the optimizer.

    Args:
        epoch: optional explicit epoch index (deprecated). When given, a
            ``EPOCH_DEPRECATION_WARNING`` is issued, ``last_epoch`` is set to
            it, and ``_get_closed_form_lr`` is used if the scheduler defines
            one. When omitted, ``last_epoch`` is incremented and ``get_lr``
            is used.

    Side effects:
        Updates ``self._step_count``, ``self.last_epoch``, each
        ``param_group['lr']`` of the wrapped optimizer, and ``self._last_lr``.
    """
    # The constructor already performed one step. A count of 1 therefore
    # means this is the first call from user code, so check that
    # optimizer.step() ran before it.
    if self._step_count == 1:
        optimizer = self.optimizer
        if not hasattr(optimizer.step, "_with_counter"):
            warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler initialization. Please, make sure to call `optimizer.step()` before `lr_scheduler.step()`. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
        elif optimizer._step_count < 1:
            warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

    self._step_count += 1

    # get_lr() can read this flag to tell calls made from inside step() apart
    # from direct user calls. try/finally resets it even if get_lr() raises.
    self._get_lr_called_within_step = True
    try:
        if epoch is not None:
            warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
            self.last_epoch = epoch
            if hasattr(self, "_get_closed_form_lr"):
                new_lrs = self._get_closed_form_lr()
            else:
                new_lrs = self.get_lr()
        else:
            # get_lr() sees the already-advanced epoch index.
            self.last_epoch += 1
            new_lrs = self.get_lr()
    finally:
        self._get_lr_called_within_step = False

    # Write the new rates back into the optimizer, then cache what it now holds.
    param_groups = self.optimizer.param_groups
    for group, lr in zip(param_groups, new_lrs):
        group['lr'] = lr

    self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

step函数主要进行lr的实时计算以及相关参数的更新,包括epoch、lr和optimizer中保存的实时lr。

上下文管理器相关:python中的__enter__和__exit__

学习率计算:get_lr

# Compute the learning rate for the current step; the actual policy is
# supplied by each concrete scheduler subclass.
def get_lr(self):
    """Return the new learning rate of every param group.

    Abstract hook: this base implementation always raises
    ``NotImplementedError``, so concrete schedulers must override it.
    """
    # Compute learning rate using chainable form of the scheduler
    raise NotImplementedError()

获得上一轮次训练的lr值:

def get_last_lr(self):
    """Return the learning rates most recently computed by this scheduler.

    This is the list cached in ``self._last_lr``. The same list object is
    returned, not a copy.
    """
    last_lrs = self._last_lr
    return last_lrs

获取lr_scheduler的相关参数:

def state_dict(self):
    """Return the scheduler's state as a ``dict``.

    Every instance attribute is included except ``optimizer``, because the
    optimizer's state is saved separately. The copy is shallow: values are
    the scheduler's own objects.
    """
    snapshot = {}
    for name, value in self.__dict__.items():
        if name == 'optimizer':
            continue
        snapshot[name] = value
    return snapshot

需要注意的是,lr_scheduler的state_dict返回的是scheduler中除optimizer以外的所有属性,所以不同的scheduler返回的参数各不相同。
加载已有的lr_scheduler参数:

def load_state_dict(self, state_dict):
    """Restore scheduler state produced by :meth:`state_dict`.

    Args:
        state_dict (dict): mapping of attribute name to value. Each entry
            overwrites the matching instance attribute; attributes absent
            from the mapping are left untouched.
    """
    attributes = self.__dict__
    attributes.update(state_dict)

__dict__相关:Python __dict__与dir()区别

以cosine学习率调整策略来具体学习lr_scheduler:

lr计算公式(其中$\eta_{max}$为初始学习率base_lr,$\eta_{min}$即eta_min,$T_{cur}$为自上次重启以来的epoch数(本实现中即last_epoch),$k$为整数):

$$\eta_{t+1} = \eta_{min} + (\eta_t - \eta_{min})\,\frac{1 + \cos\left(\frac{T_{cur}+1}{T_{max}}\pi\right)}{1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)}, \quad T_{cur} \neq (2k+1)T_{max}$$

(此式由下文闭式公式中相邻两轮$\eta_{t+1}-\eta_{min}$与$\eta_t-\eta_{min}$之比得到,因此不含系数$\frac{1}{2}$,与get_lr中的实现一致。)

$$\eta_{t+1} = \eta_t + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), \quad T_{cur} = (2k+1)T_{max}$$

由于上述公式是递归定义的(依赖上一轮的学习率$\eta_t$),学习率可以同时被该调度器之外的其他操作修改;若学习率仅由该调度器计算,则公式可统一为如下闭式(对应代码中的_get_closed_form_lr):

$$\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)$$

class CosineAnnealingLR(_LRScheduler):
    r"""Anneal each param group's learning rate along a cosine curve.

    :math:`\eta_{max}` is each group's initial lr (``base_lrs``) and
    :math:`\eta_{min}` is ``eta_min``. With :math:`T_{cur}` = ``last_epoch``,
    the closed form is

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})
                 \left(1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)

    ``get_lr`` applies the equivalent recursive update to the lr currently
    stored in each param group, so outside changes to that lr carry forward.
    Past ``T_max`` the cosine keeps going, and the lr rises back toward
    :math:`\eta_{max}` over the following ``T_max`` epochs. The schedule comes
    from SGDR: Stochastic Gradient Descent with Warm Restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations (half-period of the cosine).
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
        # Set these before the base constructor runs: it calls step(), which
        # calls get_lr(), and get_lr() may read them.
        self.T_max = T_max
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the next learning rate of every param group (recursive form).

        ``step`` advances ``last_epoch`` before calling this, so
        ``last_epoch`` is the index of the epoch being entered. Each group's
        current ``'lr'`` is treated as the previous epoch's value.
        """
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        epoch = self.last_epoch
        # Epoch 0 is the first step: use the initial learning rates unchanged.
        if epoch == 0:
            return self.base_lrs

        if (epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            # The previous epoch sat at the trough, T_cur = (2k+1) * T_max.
            # There the ratio below would divide by 1 + cos(pi) == 0, so
            # step upward additively from the current lr instead.
            rise = 1 - math.cos(math.pi / self.T_max)
            return [group['lr'] + (base_lr - self.eta_min) * rise / 2
                    for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)]

        # Otherwise, scale the distance above eta_min by the ratio of the
        # closed-form cosine factors of this epoch and the previous one.
        scale = ((1 + math.cos(math.pi * epoch / self.T_max)) /
                 (1 + math.cos(math.pi * (epoch - 1) / self.T_max)))
        return [scale * (group['lr'] - self.eta_min) + self.eta_min
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        """Return the closed-form cosine lr of every group for ``last_epoch``.

        ``step`` uses this instead of ``get_lr`` when an explicit ``epoch``
        is passed.
        """
        cosine = 1 + math.cos(math.pi * self.last_epoch / self.T_max)
        return [self.eta_min + (base_lr - self.eta_min) * cosine / 2
                for base_lr in self.base_lrs]

你可能感兴趣的:(PyTorch)