pytorch-lightning 多个优化器的使用

一般在GAN或者类似的模型里,我们会有多个网络模型,每个网络模型都需要定义各自的优化器,如下所示:

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []

在调用的时候,pytorch lightning 会根据返回参数,生成optimizer_idx参数,在训练的过程中,调用如下:

def training_step(self, batch, batch_idx, optimizer_idx):
        
       
        # train generator
        if optimizer_idx == 0:
            loss = 
            return loss

        # train discriminator
        if optimizer_idx == 1:
            loss = 
            return loss

training_step每次运行都会选择不同的优化器,在上述例子中,第一轮训练optimizer_idx==0,训练generator网络;第二轮训练optimizer_idx==1,训练discriminator网络;第三轮训练optimizer_idx==0。。。依次调用,直到完成所有训练

你可能感兴趣的:(pytorch,lightning,多个优化器,pytorch,神经网络,机器学习)