PyTorch学习率衰减函数

在网络训练过程中,经常要根据实际情况改变学习率以适应当前阶段的学习。

PyTorch中给出的lr_scheduler模块就可以实现多种学习率衰减。

1 导入模块

from torch.optim import lr_scheduler

2 在训练代码中optimizer定义后,规定衰减策略

衰减策略具体解释参考链接
https://www.jianshu.com/p/9643cba47655

  • LambdaLR
    PyTorch学习率衰减函数_第1张图片
scheduler = lr_scheduler.LambdaLR(optimizer,lr_lambda = lambda1)
  • StepLR
    PyTorch学习率衰减函数_第2张图片
scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)
  • MultiStepLR
    PyTorch学习率衰减函数_第3张图片
torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1)
  • ExponentialLR

PyTorch学习率衰减函数_第4张图片

torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1)
  • CosineAnnealingLR
    PyTorch学习率衰减函数_第5张图片
torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=0, last_epoch=-1)
  • ReduceLROnPlateau
torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)

3 在训练过程中,需要衰减的节点加上这一句:

# 使得学习率按照制定策略衰减一次
scheduler.step()

4 另:查看当前optimizer的学习率

print(optimizer.state_dict()['param_groups'][0]['lr'])

你可能感兴趣的:(机器学习)