今天在跟随沐神的课看见了以前没见过SGD参数传入方式(才学没多久,见识浅陋):
trainer = torch.optim.SGD([{'params': params_1x},
{'params': net.fc.parameters(), 'lr': learning_rate * 10}],
lr=learning_rate, weight_decay=0.001)
传入了一个列表中包含了两个字典,作用是为了做到在最后一层中的学习率与其他层不一样。这是怎么做到的呢?
于是我看了看SGD类__init__函数源码:
def __init__(self, params, lr=required, momentum=0, dampening=0,
weight_decay=0, nesterov=False, *, maximize=False, foreach: Optional[bool] = None):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
weight_decay=weight_decay, nesterov=nesterov,
maximize=maximize, foreach=foreach)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super(SGD, self).__init__(params, defaults)
我们可以看到在这个函数params只在最后的父类初始化函数中用到了,我继续查看SGD父类Optimizer类的__init__函数源码:
def __init__(self, params, defaults):
torch._C._log_api_usage_once("python.optimizer")
self.defaults = defaults
self._hook_for_profile()
if isinstance(params, torch.Tensor):
raise TypeError("params argument given to the optimizer should be "
"an iterable of Tensors or dicts, but got " +
torch.typename(params))
self.state = defaultdict(dict)
self.param_groups = []
param_groups = list(params)
if len(param_groups) == 0:
raise ValueError("optimizer got an empty parameter list")
if not isinstance(param_groups[0], dict):
param_groups = [{'params': param_groups}]
for param_group in param_groups:
self.add_param_group(param_group)
self._warned_capturable_if_run_uncaptured = True
在第14行的if语句中它判断了传入的param_groups是否为一个字典,如果不为字典就将它变成键为"params",值为param_groups的字典。
并且在接下来的for loop中对param_groups进行迭代,然后调用函数add_param_group将每一个param_groups中的pram_group传入其中。所以先要弄清楚这个parameter怎么用还得查看add_param_group函数。这个函数有点长,涉及到本问题的核心代码就是:
for name, default in self.defaults.items():
if default is required and name not in param_group:
raise ValueError("parameter group didn't specify a value of required optimization parameter " +
name)
else:
param_group.setdefault(name, default)
这是一个遍历self.defaults.items()的一个for loop,从之前的代码可知defaults里面放的也是一些字典,存了一些关于优化器的超参数。for loop中的if语句是用于判断default中是否有与param_group相同的键,如果没有的话就将此时这个default中的item添加到param_group中去。
回看我们最初的问题,SGD初始化时传入的那个含有两个字典的列表,只要列表里的字典中含有与超参数名字相同的键,不就能够做到改变这层的超参数与其他的曾不一样了吗?这里第二个字典中有{'lr':learning_rate * 10}这个键值对就是这个作用!!
总结:其实这是一个比较简单的问题,但是在这个问题的解决过程当中,通过不断追踪参数在源码中的用法,我很好地锻炼了我的解决问题的能力,也增加了我对pytorch库的一些了解。
注:我是刚学深度学习没多久的菜鸟,若有大佬发现不足之处,还请多多指正!