CycleGAN: Loss Analysis, Modification, and Experiments

CycleGAN (Part 5): Loss Analysis, Modification, and Experiments

April 1, 2019, 11:25:05  Xing Xiangrui (邢翔瑞)

Copyright notice: when reposting, please credit the source: Xing Xiangrui's technical blog
https://blog.csdn.net/weixin_36474809
https://blog.csdn.net/weixin_36474809/article/details/88895136

Goal: understand where the losses are defined and how to modify them.

Contents

1. Loss definitions and meanings in the paper
1.1 Losses in the paper
1.2 adversarial loss
1.3 cycle consistency loss
1.4 Overall loss
1.5 idt loss
2. Loss definitions in the code
2.1 Discriminator (D) loss
2.2 Generator (G) loss
2.3 Idt loss
2.4 Summary of where the losses are defined
3. Modification and experiments
3.1 Where the weights are defined and modified
3.2 Parameter values printed at run time
4. Typical loss behavior during training
4.1 Typical loss values
4.2 Where the losses are computed and stored


1. Loss definitions and meanings in the paper

A detailed walkthrough of the CycleGAN paper: Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks

1.1 Losses in the paper

The approach involves two kinds of losses:

  • adversarial losses: make the data produced by the generators as close as possible to the real data distribution.
  • cycle consistency losses: keep the generators G and F from contradicting each other, i.e. after one generator translates an image, the other should be able to translate it back, roughly X -> Y -> X.

1.2 adversarial loss

Its goal is to make the data produced by the generator as close as possible to the real data distribution:

As in a standard GAN, the generator G learns the mapping X -> Y; training should make G(X) as close as possible to the distribution of Y, while the discriminator D_Y judges whether samples are real or fake. The formula is the same as in a GAN, shown below.

Similarly, F learns the mapping Y -> X with its own discriminator D_X and an analogous adversarial term.
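For reference, the adversarial loss as written in the paper (Eqn. (1)) is:

$$\mathcal{L}_{\mathrm{GAN}}(G, D_Y, X, Y) = \mathbb{E}_{y \sim p_{\mathrm{data}}(y)}\big[\log D_Y(y)\big] + \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log\big(1 - D_Y(G(x))\big)\big]$$

and symmetrically $\mathcal{L}_{\mathrm{GAN}}(F, D_X, Y, X)$ for F: Y -> X.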

1.3 cycle consistency loss

This loss keeps the samples produced by the two generators from contradicting each other.

The adversarial loss above only guarantees that the generated samples follow the same distribution as the real samples; we additionally want images in the two domains to correspond one to one, i.e. A -> B -> A should be able to map an image back again.

We want x -> G(x) -> F(G(x)) ≈ x, which is called forward cycle consistency.

Similarly, y -> F(y) -> G(F(y)) ≈ y is called backward cycle consistency.

To enforce this consistency as far as possible, a corresponding loss is defined (see below).
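The cycle consistency loss (Eqn. (2) in the paper) penalizes both reconstruction directions with an L1 norm:

$$\mathcal{L}_{\mathrm{cyc}}(G, F) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\lVert F(G(x)) - x\rVert_1\big] + \mathbb{E}_{y \sim p_{\mathrm{data}}(y)}\big[\lVert G(F(y)) - y\rVert_1\big]$$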

1.4 Overall loss

That is, generator G should perform the X -> Y translation and generator F the Y -> X translation, while the two generators should be approximate inverses of each other, so that iterating them returns to the original image. (In the training details the authors set λ = 10.) The combined objective is given below.
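The full objective (Eqn. (3) in the paper) combines the two adversarial losses with the weighted cycle consistency loss:

$$\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{\mathrm{GAN}}(G, D_Y, X, Y) + \mathcal{L}_{\mathrm{GAN}}(F, D_X, Y, X) + \lambda\,\mathcal{L}_{\mathrm{cyc}}(G, F)$$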

1.5 idt loss

There is one loss that is not mentioned in the main part of the paper but appears in the applications section and in the code: the idt loss.

In cycle_gan_model.py it is defined through the following option:

parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')


The idt loss is defined in the applications section of the paper; it prevents the color composition from shifting too much between input and output. Overall, the GAN loss and the cycle consistency loss together form the reconstruction objective: the GAN loss drives the translation into the target domain, while the cycle consistency loss preserves enough of the original image that it can be translated back. The idt loss is a more direct constraint that keeps the network from translating "too much".
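In the paper's notation, this identity loss is:

$$\mathcal{L}_{\mathrm{identity}}(G, F) = \mathbb{E}_{y \sim p_{\mathrm{data}}(y)}\big[\lVert G(y) - y\rVert_1\big] + \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\lVert F(x) - x\rVert_1\big]$$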

2. Loss definitions in the code

models/cycle_gan_model.py

The main text of the paper does not explain the meaning or role of idt_A and idt_B; they are covered in Section 2.3 below.

2.1 Discriminator (D) loss

D is trained using real samples as positive examples (label True) and samples generated by G as negative examples (label False):


    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator

        Parameters:
            netD (network)      -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator

        Return the discriminator loss.
        We also call loss_D.backward() to calculate the gradients.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D
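In equation form, with the default gan_mode of lsgan (criterionGAN is then a least-squares loss against target labels 1 and 0, as seen in the options printed in Section 4.1), the discriminator objective for D_Y is:

$$\mathcal{L}_{D_Y} = \tfrac{1}{2}\,\mathbb{E}_{y}\big[(D_Y(y) - 1)^2\big] + \tfrac{1}{2}\,\mathbb{E}_{x}\big[D_Y(G(x))^2\big]$$

The factor 0.5 is applied explicitly in the code above, and fake.detach() keeps gradients from flowing back into the generator during the discriminator update.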

2.2 Generator (G) loss

The total generator loss is the sum: self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B


    # GAN loss D_A(G_A(A))
    self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
    # GAN loss D_B(G_B(B))
    self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
    # Forward cycle loss || G_B(G_A(A)) - A||
    self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
    # Backward cycle loss || G_A(G_B(B)) - B||
    self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
    # combined loss and calculate gradients
    self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
    self.loss_G.backward()

It is easy to see that loss_G_A is the generator term of the corresponding GAN loss, and loss_cycle_A is the corresponding term of the cycle consistency loss.

They are a GAN loss and an L1 loss respectively, defined as follows:


    # define loss functions
    self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)  # define GAN loss.
    self.criterionCycle = torch.nn.L1Loss()
    self.criterionIdt = torch.nn.L1Loss()
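networks.GANLoss wraps an underlying criterion chosen by opt.gan_mode and builds a target-label tensor with the same shape as the discriminator output. As a minimal sketch of the idea for the default gan_mode='lsgan' (a least-squares GAN, i.e. MSE against labels 1.0/0.0), not the repository's exact implementation:

    import torch
    import torch.nn as nn

    class LSGANLossSketch(nn.Module):
        """Illustrative lsgan-style criterion: MSE against constant real/fake labels."""

        def __init__(self, target_real_label=1.0, target_fake_label=0.0):
            super().__init__()
            self.register_buffer('real_label', torch.tensor(target_real_label))
            self.register_buffer('fake_label', torch.tensor(target_fake_label))
            self.loss = nn.MSELoss()  # 'lsgan' replaces the log loss with a least-squares loss

        def forward(self, prediction, target_is_real):
            # Expand the scalar label to the shape of the discriminator output (e.g. a PatchGAN map)
            target = self.real_label if target_is_real else self.fake_label
            return self.loss(prediction, target.expand_as(prediction))

    # Usage with hypothetical tensors:
    # criterion = LSGANLossSketch()
    # loss_real = criterion(netD(real), True)
    # loss_fake = criterion(netD(fake.detach()), False)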

2.3 Idt loss

What the idt loss is does not appear in the paper's main framework; in cycle_gan_model.py it is controlled by the same --lambda_identity option quoted in Section 1.5:

parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')


As explained in Section 1.5, the idt loss is introduced in the applications section of the paper to keep the color composition of the output close to that of the input: when a real image from the generator's target domain is fed in, the generator should return it nearly unchanged, which discourages translating "too much".

2.4 Summary of where the losses are defined

  • The GAN loss terms (loss_G_A, loss_G_B) carry no weighting coefficient.
  • Each idt loss term carries two coefficients: loss_idt_A is scaled by lambda_B * lambda_idt, and loss_idt_B by lambda_A * lambda_idt.
  • Each cycle loss term carries one coefficient: loss_cycle_A is scaled by lambda_A, and loss_cycle_B by lambda_B.

(See the combined expression after the code block below.)

    if lambda_idt > 0:
        # G_A should be identity if real_B is fed: ||G_A(B) - B||
        self.idt_A = self.netG_A(self.real_B)
        self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
        # G_B should be identity if real_A is fed: ||G_B(A) - A||
        self.idt_B = self.netG_B(self.real_A)
        self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
    else:
        self.loss_idt_A = 0
        self.loss_idt_B = 0
    # GAN loss D_A(G_A(A))
    self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
    # GAN loss D_B(G_B(B))
    self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
    # Forward cycle loss || G_B(G_A(A)) - A||
    self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
    # Backward cycle loss || G_A(G_B(B)) - B||
    self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
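Putting the coefficients together, the generator objective that the code actually optimizes is:

$$\mathcal{L}_G = \mathcal{L}_{\mathrm{GAN}}(G_A) + \mathcal{L}_{\mathrm{GAN}}(G_B) + \lambda_A\,\lVert G_B(G_A(A)) - A\rVert_1 + \lambda_B\,\lVert G_A(G_B(B)) - B\rVert_1 + \lambda_{\mathrm{idt}}\big(\lambda_B\,\lVert G_A(B) - B\rVert_1 + \lambda_A\,\lVert G_B(A) - A\rVert_1\big)$$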

3. Modification and experiments

3.1 Where the weights are defined and modified

They are defined, and can be changed, in cycle_gan_model.py, inside:


    class CycleGANModel(BaseModel):

        @staticmethod
        def modify_commandline_options(parser, is_train=True):

            For CycleGAN, in addition to GAN losses, we introduce lambda_A, lambda_B, and lambda_identity for the following losses.
            A (source domain), B (target domain).
            Generators: G_A: A -> B; G_B: B -> A.
            Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A.
            Forward cycle loss: lambda_A * ||G_B(G_A(A)) - A|| (Eqn. (2) in the paper)
            Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B|| (Eqn. (2) in the paper)
            Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A) (Sec 5.2 "Photo generation from paintings" in the paper)
            Dropout is not used in the original CycleGAN paper.
            """
            parser.set_defaults(no_dropout=True)  # default CycleGAN did not use dropout
            if is_train:
                parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
                parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
                parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')

            return parser

To set these values, you can either edit the default values here directly, or override them on the command line when launching training by appending, for example:

--lambda_A 10 --lambda_B 10
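For example, using the dataset path and experiment name from the training command in Section 4.1, a full command for the lambda = 40 experiment described below might look like this (the three flags are the ones defined in the parser above; adjust the paths to your own setup):

    python train.py --dataroot ./datasets/norText_2_cotton \
        --name norText_2_cotton_cyclegan --model cycle_gan \
        --lambda_A 40 --lambda_B 40 --lambda_identity 0.5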

3.2 Parameter values printed at run time

When training (or testing) starts, all options, including lambda_A, lambda_B and lambda_identity, are printed in the "Options" block, as can be seen at the top of the log in Section 4.1.

4. Typical loss behavior during training

About loss curve

Unfortunately, the loss curve does not reveal much information in training GANs, and CycleGAN is no exception. To check whether the training has converged or not, we recommend periodically generating a few samples and looking at them.

In other words, the authors point out that the loss values do not directly tell you how good the results are: the generator and the discriminator are adversaries whose losses pull against each other, so the loss curves do not really reflect the model's performance.

4.1 Typical loss values

In practice the loss values can still serve as a rough indication of how training is going; for the terms below, smaller is generally better.

env/bin/python /home/xingxiangrui/pytorch-CycleGAN-and-pix2pix/train.py --dataroot /home/xingxiangrui/pytorch-CycleGAN-and-pix2pix/datasets/norText_2_cotton --name norText_2_cotton_cyclegan --model cycle_gan --no_html

A successful run looks like this:


    [xingxiangrui@yq01-gpu-yq-face-21-5 ~]$ env/bin/python /home/xingxiangrui/pytorch-CycleGAN-and-pix2pix/train.py --dataroot /home/xingxiangrui/pytorch-CycleGAN-and-pix2pix/datasets/norText_2_cotton --name norText_2_cotton_cyclegan --model cycle_gan
    ----------------- Options ---------------
    batch_size: 1
    beta1: 0.5
    checkpoints_dir: ./checkpoints
    continue_train: False
    crop_size: 256
    dataroot: /home/xingxiangrui/pytorch-CycleGAN-and-pix2pix/datasets/norText_2_cotton [default: None]
    dataset_mode: unaligned
    direction: AtoB
    display_env: main
    display_freq: 400
    display_id: 1
    display_ncols: 4
    display_port: 8097
    display_server: http://localhost
    display_winsize: 256
    epoch: latest
    epoch_count: 1
    gan_mode: lsgan
    gpu_ids: 0
    init_gain: 0.02
    init_type: normal
    input_nc: 3
    isTrain: True [default: None]
    lambda_A: 10.0
    lambda_B: 10.0
    lambda_identity: 0.5
    load_iter: 0 [default: 0]
    load_size: 286
    lr: 0.0002
    lr_decay_iters: 50
    lr_policy: linear
    max_dataset_size: inf
    model: cycle_gan
    n_layers_D: 3
    name: norText_2_cotton_cyclegan [default: experiment_name]
    ndf: 64
    netD: basic
    netG: resnet_9blocks
    ngf: 64
    niter: 100
    niter_decay: 100
    no_dropout: True
    no_flip: False
    no_html: False
    norm: instance
    num_threads: 4
    output_nc: 3
    phase: train
    pool_size: 50
    preprocess: resize_and_crop
    print_freq: 100
    save_by_iter: False
    save_epoch_freq: 5
    save_latest_freq: 5000
    serial_batches: False
    suffix:
    update_html_freq: 1000
    verbose: False
    ----------------- End -------------------
    dataset [UnalignedDataset] was created
    The number of training images = 100
    initialize network with normal
    initialize network with normal
    initialize network with normal
    initialize network with normal
    model [CycleGANModel] was created
    ---------- Networks initialized -------------
    [Network G_A] Total number of parameters : 11.378 M
    [Network G_B] Total number of parameters : 11.378 M
    [Network D_A] Total number of parameters : 2.765 M
    [Network D_B] Total number of parameters : 2.765 M
    -----------------------------------------------
    ...
    create web directory ./checkpoints/norText_2_cotton_cyclegan/web...
    (epoch: 1, iters: 100, time: 0.896, data: 1.052) D_A: 0.384 G_A: 0.232 cycle_A: 1.791 idt_A: 0.739 D_B: 0.572 G_B: 0.620 cycle_B: 2.002 idt_B: 0.851
    End of epoch 1 / 200 Time Taken: 92 sec
    learning rate = 0.0002000
    (epoch: 2, iters: 100, time: 0.865, data: 0.214) D_A: 0.219 G_A: 0.304 cycle_A: 1.499 idt_A: 0.597 D_B: 0.305 G_B: 0.699 cycle_B: 1.118 idt_B: 0.711
    End of epoch 2 / 200 Time Taken: 87 sec
    learning rate = 0.0002000
    ...
    (epoch: 195, iters: 100, time: 0.865, data: 0.209) D_A: 0.016 G_A: 0.836 cycle_A: 0.584 idt_A: 0.168 D_B: 0.124 G_B: 0.346 cycle_B: 0.530 idt_B: 0.181
    saving the model at the end of epoch 195, iters 19500
    End of epoch 195 / 200 Time Taken: 88 sec
    learning rate = 0.0000119
    (epoch: 196, iters: 100, time: 1.197, data: 0.237) D_A: 0.036 G_A: 0.737 cycle_A: 0.590 idt_A: 0.133 D_B: 0.014 G_B: 0.271 cycle_B: 0.425 idt_B: 0.186
    End of epoch 196 / 200 Time Taken: 87 sec
    learning rate = 0.0000099
    (epoch: 197, iters: 100, time: 0.871, data: 0.218) D_A: 0.031 G_A: 0.756 cycle_A: 0.533 idt_A: 0.113 D_B: 0.037 G_B: 0.511 cycle_B: 0.370 idt_B: 0.160
    End of epoch 197 / 200 Time Taken: 86 sec
    learning rate = 0.0000079
    (epoch: 198, iters: 100, time: 0.856, data: 0.217) D_A: 0.063 G_A: 0.509 cycle_A: 0.634 idt_A: 0.123 D_B: 0.308 G_B: 0.492 cycle_B: 0.478 idt_B: 0.222
    End of epoch 198 / 200 Time Taken: 86 sec
    learning rate = 0.0000059
    (epoch: 199, iters: 100, time: 0.903, data: 0.203) D_A: 0.033 G_A: 0.981 cycle_A: 0.515 idt_A: 0.110 D_B: 0.111 G_B: 0.531 cycle_B: 0.381 idt_B: 0.167
    End of epoch 199 / 200 Time Taken: 86 sec
    learning rate = 0.0000040
    (epoch: 200, iters: 100, time: 1.200, data: 0.219) D_A: 0.017 G_A: 1.035 cycle_A: 0.613 idt_A: 0.106 D_B: 0.030 G_B: 0.726 cycle_B: 0.384 idt_B: 0.200
    saving the latest model (epoch 200, total_iters 20000)
    saving the model at the end of epoch 200, iters 20000
    End of epoch 200 / 200 Time Taken: 89 sec
    learning rate = 0.0000020

Smaller loss values indicate that training is going well. In the run above, D_A eventually settles around 0.02 to 0.06; as a rule of thumb, the smaller the discriminator losses, the better the training result tends to look.

After increasing the lambda weights to 40, the losses are:


    ================ Training Loss (Mon Apr 1 11:46:05 2019) ================
    (epoch: 1, iters: 100, time: 0.600, data: 0.176) D_A: 0.224 G_A: 0.485 cycle_A: 5.948 idt_A: 5.620 D_B: 0.410 G_B: 0.693 cycle_B: 9.939 idt_B: 2.715
    (epoch: 2, iters: 100, time: 0.604, data: 0.169) D_A: 0.251 G_A: 0.731 cycle_A: 5.778 idt_A: 4.086 D_B: 0.497 G_B: 1.078 cycle_B: 6.170 idt_B: 2.712
    (epoch: 3, iters: 100, time: 0.597, data: 0.163) D_A: 0.201 G_A: 0.586 cycle_A: 5.058 idt_A: 6.942 D_B: 0.219 G_B: 0.741 cycle_B: 13.409 idt_B: 2.129
    (epoch: 4, iters: 100, time: 0.838, data: 0.185) D_A: 0.123 G_A: 0.200 cycle_A: 5.216 idt_A: 1.311 D_B: 0.128 G_B: 0.740 cycle_B: 2.606 idt_B: 2.367
    (epoch: 5, iters: 100, time: 0.597, data: 0.146) D_A: 0.113 G_A: 0.472 cycle_A: 6.259 idt_A: 1.344 D_B: 0.258 G_B: 0.829 cycle_B: 3.239 idt_B: 2.951
    (epoch: 6, iters: 100, time: 0.598, data: 0.192) D_A: 0.088 G_A: 0.731 cycle_A: 3.720 idt_A: 1.364 D_B: 0.142 G_B: 2.097 cycle_B: 3.516 idt_B: 1.719
    (epoch: 7, iters: 100, time: 0.601, data: 0.170) D_A: 0.261 G_A: 0.691 cycle_A: 5.233 idt_A: 2.105 D_B: 0.788 G_B: 1.213 cycle_B: 4.088 idt_B: 2.316
    (epoch: 8, iters: 100, time: 0.986, data: 0.156) D_A: 0.108 G_A: 0.938 cycle_A: 4.672 idt_A: 3.141 D_B: 0.062 G_B: 0.983 cycle_B: 5.727 idt_B: 2.015
    (epoch: 9, iters: 100, time: 0.599, data: 0.175) D_A: 0.072 G_A: 0.883 cycle_A: 4.078 idt_A: 0.846 D_B: 0.057 G_B: 0.977 cycle_B: 2.139 idt_B: 1.904
    (epoch: 10, iters: 100, time: 0.599, data: 0.159) D_A: 0.142 G_A: 0.473 cycle_A: 4.346 idt_A: 1.358 D_B: 0.077 G_B: 0.753 cycle_B: 3.725 idt_B: 2.112
    ...
    (epoch: 195, iters: 100, time: 0.602, data: 0.164) D_A: 0.020 G_A: 0.723 cycle_A: 1.798 idt_A: 0.337 D_B: 0.034 G_B: 0.554 cycle_B: 1.263 idt_B: 0.681
    (epoch: 196, iters: 100, time: 0.919, data: 0.162) D_A: 0.009 G_A: 1.077 cycle_A: 1.783 idt_A: 0.317 D_B: 0.055 G_B: 0.737 cycle_B: 1.220 idt_B: 0.642
    (epoch: 197, iters: 100, time: 0.602, data: 0.155) D_A: 0.008 G_A: 1.063 cycle_A: 1.974 idt_A: 0.321 D_B: 0.149 G_B: 0.443 cycle_B: 1.102 idt_B: 0.725
    (epoch: 198, iters: 100, time: 0.599, data: 0.172) D_A: 0.007 G_A: 0.800 cycle_A: 1.763 idt_A: 0.450 D_B: 0.225 G_B: 0.888 cycle_B: 1.459 idt_B: 0.811
    (epoch: 199, iters: 100, time: 0.599, data: 0.153) D_A: 0.009 G_A: 1.097 cycle_A: 1.814 idt_A: 0.325 D_B: 0.082 G_B: 0.636 cycle_B: 1.103 idt_B: 0.709
    (epoch: 200, iters: 100, time: 0.946, data: 0.170) D_A: 0.009 G_A: 0.976 cycle_A: 1.923 idt_A: 0.319 D_B: 0.151 G_B: 0.608 cycle_B: 1.062 idt_B: 0.922

4.2 Where the losses are computed and stored

The loss values are computed and printed during training, and are also appended to loss_log.txt in the model's folder under the checkpoints directory; you can view them with cat loss_log.txt.

Note that these printed loss values already include the weighting coefficients, i.e. they are the losses after multiplication by the corresponding lambda factors.
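The repository itself plots losses live with Visdom rather than from this file, but loss_log.txt can also be parsed offline. A small sketch based on the log format shown above (the file path and experiment name are assumptions; adjust them to your run):

    import re
    from collections import defaultdict

    import matplotlib.pyplot as plt

    LOG_PATH = './checkpoints/norText_2_cotton_cyclegan/loss_log.txt'  # hypothetical path

    losses = defaultdict(list)
    entry = re.compile(r'\(epoch: \d+, iters: \d+.*?\) (.*)')
    with open(LOG_PATH) as f:
        for line in f:
            m = entry.search(line)
            if not m:
                continue  # skip headers such as "================ Training Loss ... ================"
            # The tail looks like "D_A: 0.384 G_A: 0.232 cycle_A: 1.791 idt_A: 0.739 ..."
            for name, value in re.findall(r'(\w+): ([\d.]+)', m.group(1)):
                losses[name].append(float(value))

    for name, values in losses.items():
        plt.plot(values, label=name)
    plt.xlabel('logged iteration')
    plt.ylabel('weighted loss value')
    plt.legend()
    plt.show()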
