Pytorch学习(二十三)---- 不同layer用不同学习率(高级版本)

简易版本,
pytorch文档上,不同层用不同学习率,可以用这个。

optim.SGD([
                {'params': model.base.parameters()},
                {'params': model.classifier.parameters(), 'lr': 1e-3}
            ], lr=1e-2, momentum=0.9)

然而,这个太简单了,比如,我有一个特殊的层,这里取名为Adaptive Sigmoid的层,形式:

import torch
import torch.nn as nn

class Adaptive_Sigmoid(nn.Module):
    def __init__(self, alpha=1., beta=100.):
        super(Adaptive_Sigmoid, self).__init__()
        self.alpha = alpha
        self.beta = beta

        self.learning_alpha = nn.Parameter(torch.FloatTensor(1).fill_(self.alpha))
        self.learning_beta = nn.Parameter(torch.FloatTensor(1).fill_(self.beta))

    def forward(self, x):
        return 1.0/(1+torch.exp(-self.learning_alpha*(x - self.learning_beta)))

    def __repr__(self):
        return self.__class__.__name__+ '(' \
               'alpha ' + str(self.learning_alpha.data) \
               + ' ,beta: ' + str(self.learning_beta.data) + ')'

然后,在model中,我们用类似:

self.ada_sig = Adaptive_Sigmoid(opt.alpha, opt.beta)
self.base_layer_1 = ...
self.base_layer_2 = ...
...

这时候我们需要,对于base layers采用base_lr, 对于alpha的学习率,采用10*base_lr; beta的学习率,采用100*base_lr.
这时候,采用基本doc的方法就很麻烦。难道我对于所有base layers都要一层层写?而且,对于这种在在model里面的一个类中,有2个参数,用不同学习率的就也是很麻烦。

"""
这里'module'是因为,网络包了一层DataParallel
其实可以用
for n, p in self.netG.named_parameters():
    print(n, p.size())
来看所有参数层的名字和size
"""
alpha_list = ['module.ada_sig.learning_alpha']
beta_list = ['module.ada_sig.learning_beta']
alpha_params = list(map(lambda x: x[1],list(filter(lambda kv: kv[0] in alpha_list, model.named_parameters()))))
beta_params = list(map(lambda x: x[1],list(filter(lambda kv: kv[0] in beta_list, model.named_parameters()))))
base_params = list(map(lambda x: x[1],list(filter(lambda kv: kv[0] not in (alpha_list + beta_list), model.named_parameters()))))
self.optimizer_G = torch.optim.SGD([
     {'params':base_params},
     {'params':alpha_params, 'lr':10*opt.lr},
     {'params':beta_params, 'lr':100*opt.lr},
     ],
lr = opt.lr, momentum=0.95, weight_decay=opt.weight_decay)

这种写法的简明之处,在于对于除了Adaptive Sigmoid层之外的所有层,都设置不同的学习率。

你可能感兴趣的:(PyTorch)