Pytorch中optimizer类初始化传入参数分析(分析源码)

今天在跟随沐神的课看见了以前没见过SGD参数传入方式(才学没多久,见识浅陋):

trainer = torch.optim.SGD([{'params': params_1x}, 
                           {'params': net.fc.parameters(), 'lr': learning_rate * 10}],
                          lr=learning_rate, weight_decay=0.001) 

传入了一个列表中包含了两个字典,作用是为了做到在最后一层中的学习率与其他层不一样。这是怎么做到的呢?

于是我看了看SGD类__init__函数源码:

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, *, maximize=False, foreach: Optional[bool] = None):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov,
                        maximize=maximize, foreach=foreach)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)

我们可以看到在这个函数params只在最后的父类初始化函数中用到了,我继续查看SGD父类Optimizer类的__init__函数源码:

    def __init__(self, params, defaults):
        torch._C._log_api_usage_once("python.optimizer")
        self.defaults = defaults
        self._hook_for_profile()
        if isinstance(params, torch.Tensor):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Tensors or dicts, but got " +
                            torch.typename(params))
        self.state = defaultdict(dict)
        self.param_groups = []
        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]
        for param_group in param_groups:
            self.add_param_group(param_group)
        self._warned_capturable_if_run_uncaptured = True

在第14行的if语句中它判断了传入的param_groups是否为一个字典,如果不为字典就将它变成键为"params",值为param_groups的字典。

并且在接下来的for loop中对param_groups进行迭代,然后调用函数add_param_group将每一个param_groups中的pram_group传入其中。所以先要弄清楚这个parameter怎么用还得查看add_param_group函数。这个函数有点长,涉及到本问题的核心代码就是:

        for name, default in self.defaults.items():
            if default is required and name not in param_group:
                raise ValueError("parameter group didn't specify a value of required optimization parameter " +
                                 name)
            else:
                param_group.setdefault(name, default)

这是一个遍历self.defaults.items()的一个for loop,从之前的代码可知defaults里面放的也是一些字典,存了一些关于优化器的超参数。for loop中的if语句是用于判断default中是否有与param_group相同的键,如果没有的话就将此时这个default中的item添加到param_group中去。

回看我们最初的问题,SGD初始化时传入的那个含有两个字典的列表,只要列表里的字典中含有与超参数名字相同的键,不就能够做到改变这层的超参数与其他的曾不一样了吗?这里第二个字典中有{'lr':learning_rate * 10}这个键值对就是这个作用!!

总结:其实这是一个比较简单的问题,但是在这个问题的解决过程当中,通过不断追踪参数在源码中的用法,我很好地锻炼了我的解决问题的能力,也增加了我对pytorch库的一些了解。

注:我是刚学深度学习没多久的菜鸟,若有大佬发现不足之处,还请多多指正!

你可能感兴趣的:(pytorch,深度学习,人工智能)