The torch.optim.lr_scheduler module provides several methods for adjusting the learning rate (lr) according to the number of epochs. To help the loss converge, it usually works better to decrease the lr as the number of iterations grows. torch.optim.lr_scheduler.ReduceLROnPlateau, by contrast, lowers the learning rate dynamically based on some quantity measured during training (for example, a validation loss).
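As a minimal self-contained sketch of ReduceLROnPlateau (a dummy parameter and a constant stand-in for the validation loss; in real code you would pass your actual validation metric):

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

params = [torch.nn.Parameter(torch.randn(2, 2))]   # dummy parameter
optimizer = torch.optim.SGD(params, lr=0.1)
# Multiply the lr by `factor` once the monitored metric has stopped
# improving for more than `patience` consecutive epochs.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

for epoch in range(10):
    # A constant value stands in for a validation loss that has stopped improving.
    val_loss = 1.0
    optimizer.step()
    scheduler.step(val_loss)            # pass the monitored value to step()
    print(epoch, optimizer.param_groups[0]['lr'])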
The learning rate update should come after the optimizer update. A demo:
import torch
from torch.nn import Parameter
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR

# dataset, loss_fn and the call model(input) are placeholders here;
# the point of the demo is the order of the two step() calls.
model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        # 1. update the parameters
        optimizer.step()
    # 2. update the learning rate
    scheduler.step()
# Note: in a lot of legacy code, scheduler.step() may still be called before optimizer.step().
# If your PyTorch version is 1.1.0 or later, call scheduler.step() after optimizer.step().
Before PyTorch 1.1.0, the learning rate update scheduler.step() was expected before optimizer.step(); v1.1.0 changed this behavior. If scheduler.step() is now called before optimizer.step(), the first value of the learning rate schedule is skipped. If your results changed after upgrading to v1.1.0, check whether the call order is wrong here.
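A quick way to check that the schedule behaves as expected is to print the lr after each epoch. A minimal self-contained sketch (a single dummy parameter, no data):

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR

params = [torch.nn.Parameter(torch.randn(2, 2))]
optimizer = SGD(params, lr=0.1)
scheduler = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(3):
    optimizer.step()    # parameter update first (v1.1.0+ order)
    scheduler.step()    # lr update second
    # the lr decays by a factor of 0.9 each epoch: ~0.09, ~0.081, ~0.0729
    print(epoch, scheduler.get_last_lr())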
To explain the lr_scheduler mechanism further, we first need to look at the structure of an optimizer, taking Adam() as an example (all optimizers inherit from the torch.optim.Optimizer class).
For class torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False):
- params (iterable): the network parameters to optimize; this argument must be an iterable.
  - To optimize a single network, you normally pass net.parameters() (a generator); all of its parameters then end up in a single parameter group.
  - To optimize several networks at once, there are two approaches:
    - merge all parameters into one group, e.g. [*net_1.parameters(), *net_2.parameters()] or itertools.chain(net_1.parameters(), net_2.parameters());
    - keep one parameter group per network, so that each network can use its own learning rate, e.g. [{'params': net_1.parameters()}, {'params': net_2.parameters()}] (see the sketch after this list).
- lr (float, optional): learning rate;
- betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999));
- eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8);
- weight_decay (float, optional) – weight decay (L2 penalty) (default: 0);
- amsgrad (boolean, optional) – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond (default: False).
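A minimal sketch of both approaches for multiple networks, with a per-network learning rate in the second case (net_1 and net_2 are just placeholder modules, not code from later in this article):

import itertools
import torch
import torch.nn as nn

net_1 = nn.Linear(10, 10)   # placeholder networks
net_2 = nn.Linear(10, 10)

# One shared group: a single lr for everything.
opt_merged = torch.optim.Adam(
    itertools.chain(net_1.parameters(), net_2.parameters()), lr=1e-3)

# One group per network; a group without its own 'lr' falls back to the
# default lr passed to the optimizer.
opt_grouped = torch.optim.Adam(
    [{'params': net_1.parameters(), 'lr': 1e-3},
     {'params': net_2.parameters(), 'lr': 1e-4}],
    lr=1e-3)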
optimizer.defaults: a dict holding the optimizer's initial hyperparameters, with the keys 'lr', 'betas', 'eps', 'weight_decay', 'amsgrad'.
optimizer.param_groups: a list in which every element is a dict with the keys 'params', 'lr', 'betas', 'eps', 'weight_decay', 'amsgrad'; the 'params' entry holds the parameters of that group (the parameters of all networks passed into the group are kept together).

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR
import itertools
initial_lr = 0.1
class model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)

    def forward(self, x):
        pass
net_1 = model()
net_2 = model()
optimizer_1 = torch.optim.Adam(net_1.parameters(), lr = initial_lr)
print("******************optimizer_1*********************")
print("optimizer_1.defaults:", optimizer_1.defaults)
print("optimizer_1.param_groups长度:", len(optimizer_1.param_groups))
print("optimizer_1.param_groups一个元素包含的键:", optimizer_1.param_groups[0].keys())
print()
####################################################################################
******************optimizer_1*********************
optimizer_1.defaults: {'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}
optimizer_1.param_groups length: 1
keys in one element of optimizer_1.param_groups: dict_keys(['params', 'lr', 'betas', 'eps', 'weight_decay', 'amsgrad'])
####################################################################################
optimizer_2 = torch.optim.Adam([*net_1.parameters(), *net_2.parameters()], lr = initial_lr)
# optimizer_2 = torch.optim.Adam(itertools.chain(net_1.parameters(), net_2.parameters()), lr = initial_lr)  # same effect as the line above
print("******************optimizer_2*********************")
print("optimizer_2.defaults:", optimizer_2.defaults)
print("optimizer_2.param_groups长度:", len(optimizer_2.param_groups))
print("optimizer_2.param_groups一个元素包含的键:", optimizer_2.param_groups[0].keys())
print()
####################################################################################
******************optimizer_2*********************
optimizer_2.defaults: {'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}
optimizer_2.param_groups length: 1
keys in one element of optimizer_2.param_groups: dict_keys(['params', 'lr', 'betas', 'eps', 'weight_decay', 'amsgrad'])
####################################################################################
optimizer_3 = torch.optim.Adam([{"params": net_1.parameters()}, {"params": net_2.parameters()}], lr = initial_lr)
print("******************optimizer_3*********************")
print("optimizer_3.defaults:", optimizer_3.defaults)
print("optimizer_3.param_groups长度:", len(optimizer_3.param_groups))
print("optimizer_3.param_groups一个元素包含的键:", optimizer_3.param_groups[0].keys())
####################################################################################
******************optimizer_3*********************
optimizer_3.defaults: {'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}
optimizer_3.param_groups length: 2
keys in one element of optimizer_3.param_groups: dict_keys(['params', 'lr', 'betas', 'eps', 'weight_decay', 'amsgrad'])
####################################################################################
When lr_scheduler updates the optimizer's lr, it modifies optimizer.param_groups[n]['lr'], not optimizer.defaults['lr'].
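A minimal sketch that verifies this, using a dummy parameter and a simple exponential-decay lambda (hypothetical values, not from this article):

import torch
from torch.optim.lr_scheduler import LambdaLR

params = [torch.nn.Parameter(torch.randn(2, 2))]
optimizer = torch.optim.Adam(params, lr=0.1)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.5 ** epoch)

for epoch in range(3):
    optimizer.step()
    scheduler.step()

print(optimizer.defaults['lr'])         # still 0.1
print(optimizer.param_groups[0]['lr'])  # 0.1 * 0.5 ** 3 = 0.0125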
Take LambdaLR as an example:
CLASS torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1, verbose=False)
- optimizer (Optimizer) – Wrapped optimizer.
- lr_lambda (function or list) – A function which computes a multiplicative factor given an integer parameter epoch, or a list of such functions, one for each group in optimizer.param_groups.
- last_epoch (int) – The index of last epoch. Default: -1.
- verbose (bool) – If True, prints a message to stdout for each update. Default: False.
Demo:
# Assuming optimizer has two groups.
lambda1 = lambda epoch: epoch // 30
lambda2 = lambda epoch: 0.95 ** epoch
scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
for epoch in range(100):
    train(...)
    validate(...)
    scheduler.step()
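In this demo, each group's lr is its initial lr multiplied by its own lambda evaluated at the current epoch: the first group gets lr * (epoch // 30) (so it stays at 0 until epoch 30), the second gets lr * 0.95 ** epoch. A self-contained variant with dummy parameter groups makes this visible:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

w1 = torch.nn.Parameter(torch.randn(2, 2))   # dummy parameters,
w2 = torch.nn.Parameter(torch.randn(2, 2))   # one per group
optimizer = SGD([{'params': [w1]}, {'params': [w2]}], lr=0.1)
scheduler = LambdaLR(optimizer,
                     lr_lambda=[lambda e: e // 30, lambda e: 0.95 ** e])

for epoch in range(3):
    optimizer.step()
    scheduler.step()
    # group 0: 0.1 * ((epoch + 1) // 30) -> 0.0 for the first 30 epochs
    # group 1: 0.1 * 0.95 ** (epoch + 1)
    print(scheduler.get_last_lr())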
Instance methods of CLASS torch.optim.lr_scheduler.LambdaLR:
- get_last_lr()
  Returns the last learning rate computed by the scheduler.
- print_lr(is_verbose, group, lr, epoch=None)
  Prints the current learning rate.
- state_dict()
Returns the state of the scheduler as a dict.
It contains an entry for every variable in self.__dict__ which is not the optimizer. The learning rate lambda functions will only be saved if they are callable objects and not if they are functions or lambdas.
When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
- load_state_dict(state_dict)
Loads the schedulers state.
When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
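A minimal checkpointing sketch that follows this advice, reusing the optimizer and scheduler from the demo above (the file name is arbitrary; because lambda1/lambda2 are plain lambdas, only the scheduler's counters are stored, not the lambdas themselves):

import torch

# Save the optimizer and scheduler states together.
torch.save({'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()},
           'checkpoint.pth')

# Later: re-create the same optimizer and scheduler (including the lambdas),
# then restore both states.
state = torch.load('checkpoint.pth')
optimizer.load_state_dict(state['optimizer'])
scheduler.load_state_dict(state['scheduler'])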
A complete training script will be written up later as a reference for subsequent projects.
PyTorch official documentation
A detailed CSDN post on how to use lr_scheduler