深度学习(33)——CycleGAN(2)

深度学习(33)——CycleGAN(2)

完整项目在在这里:欢迎造访

文章目录

  • 深度学习(33)——CycleGAN(2)
    • 1. Generator
    • 2. Discriminator
    • 3. fake pool
    • 4. loss定义
    • 5. 模型参数量
    • 6. debug 记录

数据格式:
深度学习(33)——CycleGAN(2)_第1张图片
深度学习(33)——CycleGAN(2)_第2张图片
深度学习(33)——CycleGAN(2)_第3张图片

1. Generator

self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
特征提取部分使用backbone是resnet(可选择的,可以换其他模型做backbone)

  • 上采样一共9个ResNet Block
    深度学习(33)——CycleGAN(2)_第4张图片
  • 下采样部分
    深度学习(33)——CycleGAN(2)_第5张图片

2. Discriminator

self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
在上一节说过,discriminator就是一个辨别真假的二分类模型,输入还是一张三通道的图像,最终判断这张图片是真是假。
深度学习(33)——CycleGAN(2)_第6张图片

3. fake pool

用于保存生成的fake image
self.fake_A_pool = ImagePool(opt.pool_size)

4. loss定义

self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)  # define GAN loss.MSE
self.criterionCycle = torch.nn.L1Loss()
self.criterionIdt = torch.nn.L1Loss()
  • GANLoss 根据gan_mode定义,此处为MSELoss

5. 模型参数量

深度学习(33)——CycleGAN(2)_第7张图片

6. debug 记录

  • set_input(input): 得到real_A,real_B

  • optimize_parameters(): 计算loss做反向传播

    • forward(): 生成fake_A,fake_B,rec_A,rec_B

      • generatorA先根据real_A生成fake_B
      • generatorB使用fake_B生成rec_A
      • generatorB根据real_B生成fake_A
      • generatorA使用fake_A生成rec_B
    • backward_G():反向传播

      • 计算identity_loss:generatorA是输入real_A得到fake_B的,那现在输入real_B是不是也可以生成和real_B差不多,将这个生成的命名为idt_A,idt_A和real_B之间会存在identity_loss,同理idt_B和real_A之间也存在identity_loss
      • 计算generator_loss:generatorA生成的feak_B的loss,我们是希望feak_B是骗过discriminatorA的,所以希望discriminatorA认为是真的A,所以这里将fake_B与True做MSEloss,同理希望discriminatorB认为fake_A是真的B
      • 计算cycle_loss:real_A经过generatorA生成fake_B,fake_B经过经过generatorB返回生成rec_A,计算这样循环生成的A和真实A之间的loss,B也同理。
      • 最终的generator_loss是上面三者的和,因为有AB之分,所以一共有6项self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
    • backward_D_A(): 计算discriminator_A的loss

    • backward_D_B(): 计算discriminator_B的loss

  1. 当optimizer generator的时候discriminator设置为无梯度,不反向传播。 self.set_requires_grad([self.netD_A, self.netD_B], False)

就酱,欢迎提问讨论,886~

你可能感兴趣的:(深度学习,深度学习,人工智能,生成对抗网络)