Custom learning rate schedules

Warmup + a custom learning-rate decay policy (e.g., cosine). The CustomScheduler base class below implements the warmup phase and a default linear decay; PolyScheduler and CosineScheduler only override func() to swap in polynomial and cosine decay curves.
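Written out, with s = (step - warmup_steps) / (max_steps - warmup_steps) as the normalized post-warmup progress, the code implements the piecewise schedule:

lr(step) = warmup_lr_init + (base_lr - warmup_lr_init) * step / warmup_steps,  while step < warmup_steps
lr(step) = eta_min + (base_lr - eta_min) * alpha(s),                           afterwards

where alpha(s) is 1 - s (linear base class), (1 - s)^power (PolyScheduler), or cos(pi/2 * s) (CosineScheduler); alpha always runs from 1 down to 0.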

from math import cos, pi
from torch.optim.lr_scheduler import _LRScheduler


class CustomScheduler(_LRScheduler):
    """Linear warmup to base_lr, then a decay defined by func()."""

    def __init__(self, optimizer, base_lr, max_steps, warmup_steps, eta_min=0, last_epoch=-1):
        self.base_lr = base_lr             # peak lr, reached at the end of warmup
        self.warmup_lr_init = 1e-4         # lr at the very first warmup step
        self.max_steps: int = max_steps    # total number of scheduler steps
        self.warmup_steps: int = warmup_steps
        self.power = 2                     # exponent used by PolyScheduler
        self.eta_min = eta_min             # lr floor reached at max_steps
        # Set all attributes before calling the base __init__, which invokes
        # self.get_lr() once via step(); pass last_epoch through directly
        # instead of patching self.last_epoch afterwards.
        super().__init__(optimizer, last_epoch)

    def get_warmup_lr(self):
        # Ramp linearly from warmup_lr_init up to base_lr over warmup_steps.
        alpha = float(self.last_epoch) / float(self.warmup_steps)
        lr = self.warmup_lr_init + (self.base_lr - self.warmup_lr_init) * alpha
        return [lr for _ in self.optimizer.param_groups]

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            return self.get_warmup_lr()
        # After warmup: func() decays from 1 to 0, so the lr moves
        # from base_lr down to eta_min.
        alpha = self.func()
        return [(self.base_lr - self.eta_min) * alpha + self.eta_min
                for _ in self.optimizer.param_groups]

    def func(self):
        # Linear decay: alpha goes from 1 to 0 over the post-warmup steps.
        progress = (float(self.last_epoch - self.warmup_steps)
                    / float(self.max_steps - self.warmup_steps))
        return 1 - progress


class PolyScheduler(CustomScheduler):
    """Polynomial decay: alpha = (1 - progress) ** power."""

    def func(self):
        progress = (float(self.last_epoch - self.warmup_steps)
                    / float(self.max_steps - self.warmup_steps))
        return pow(1 - progress, self.power)


class CosineScheduler(CustomScheduler):
    """Quarter-cosine decay: alpha = cos(pi/2 * progress), from 1 down to 0."""

    def func(self):
        progress = (float(self.last_epoch - self.warmup_steps)
                    / float(self.max_steps - self.warmup_steps))
        return cos(pi / 2 * progress)
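
A minimal usage sketch (the model, optimizer, and step counts here are placeholders, not from the original post). The scheduler is stepped once per iteration, so warmup_steps and max_steps are counted in iterations, not epochs:

import torch

model = torch.nn.Linear(10, 2)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = CosineScheduler(optimizer, base_lr=0.1,
                            max_steps=1000, warmup_steps=100, eta_min=1e-5)

for step in range(1000):
    optimizer.step()    # training step (loss/backward omitted)
    scheduler.step()    # advance the schedule once per iteration

Calling optimizer.step() before scheduler.step() matters: recent PyTorch versions warn if the order is reversed.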
