PyTorch learning: building a learning-rate scheduler with warm-up to train neural network parameters

Table of Contents

  • Code
  • Usage

Code

  • Build an optimizer scheduler that adjusts the training learning rate according to the current epoch; a quick sanity check of the resulting schedule follows the class definition below.

class ScheduledOptim:
    '''A simple wrapper class for learning rate scheduling'''

    def __init__(self
                 , optimizer
                 , n_warmup_epochs=6
                 , sustain_epochs=0
                 , lr_max=1e-3
                 , lr_min=1e-5
                 , lr_exp_decay=0.4):

        self._optimizer = optimizer
        self.n_warmup_epochs = n_warmup_epochs
        self.sustain_epochs = sustain_epochs
        self.init_lr = lr_min
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_exp_decay = lr_exp_decay

    def step_and_update_lr(self, epoch):
        "Step with the inner optimizer"
        self._update_learning_rate(epoch)
        self._optimizer.step()

    def zero_grad(self):
        "Zero out the gradients by the inner optimizer"
        self._optimizer.zero_grad()

    def _update_learning_rate(self, epoch):
        ''' Learning rate scheduling per epoch '''

        if epoch < self.n_warmup_epochs:
            # linear warm-up from lr_min towards lr_max
            lr = (self.lr_max - self.lr_min) / self.n_warmup_epochs * epoch + self.init_lr
        elif epoch < self.n_warmup_epochs + self.sustain_epochs:
            # hold at lr_max during the sustain phase
            lr = self.lr_max
        else:
            # exponential decay back towards lr_min
            lr = (self.lr_max - self.lr_min) \
                 * self.lr_exp_decay ** (epoch - self.n_warmup_epochs - self.sustain_epochs) \
                 + self.lr_min

        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr
        return lr  # returned so that draw() can record the schedule

    def draw(self, epochs):
        """Plot how the learning rate evolves over the given number of epochs."""
        lrs = []
        for i in range(epochs):
            lrs.append(self._update_learning_rate(i))
        import matplotlib.pyplot as plt
        plt.plot(lrs)
        plt.show()
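
As a quick sanity check of the schedule itself, the sketch below wraps a dummy SGD optimizer (the single zero tensor and torch.optim.SGD here are placeholders used purely for illustration) and prints the learning rate for the first few epochs. With the default hyper-parameters the rate ramps up linearly from lr_min over the 6 warm-up epochs and then decays exponentially from lr_max back towards lr_min.

import torch

# dummy parameter so that an optimizer can be constructed (illustration only)
dummy_param = torch.nn.Parameter(torch.zeros(1))
sgd = torch.optim.SGD([dummy_param], lr=1e-5)

schedule = ScheduledOptim(sgd)  # default warm-up / sustain / decay settings
for epoch in range(10):
    lr = schedule._update_learning_rate(epoch)
    print(f"epoch {epoch:2d}: lr = {lr:.6f}")
# epochs 0-5: linear ramp from 1e-5 to ~8.4e-4; from epoch 6: 1e-3 decaying towards 1e-5
# schedule.draw(30)  # or plot the whole curve (requires matplotlib)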

Usage

from torch.optim import Adam

optimizer = Adam(model.parameters()
                 , lr=5e-4
                 , eps=1e-16
                 , betas=(0.9, 0.999)
                 )
optim_schedule = ScheduledOptim(optimizer)

for epoch in range(epochs):
    # ... other training code: forward pass, compute loss, etc. ...

    # update the lr for this epoch through the scheduler, then step the inner optimizer
    optim_schedule.zero_grad()
    loss.backward()
    optim_schedule.step_and_update_lr(epoch)
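
Because ScheduledOptim only wraps the inner optimizer's step() and zero_grad(), you can also preview the full warm-up/decay curve before committing to a run. A minimal sketch, assuming epochs is the same value used in the training loop above:

# plot the warm-up / decay curve before training starts (requires matplotlib)
optim_schedule.draw(epochs)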
