paper: https://arxiv.org/pdf/1703.10593.pdf
github: https://github.com/aitorzip/PyTorch-CycleGAN
6个生成器损失,2个判别器损失
1)6个生成器损失:
2) 2个判别器损失
对抗损失:
循环一致损失,即 X 经过生成器G_x后 得到Y,Y再过F_Y生成X,使得前后生成的X距离最小。
1) 前向一致损失
即从x 经过网络后还原为x的过程
X − > G ( x ) − > F ( G ( x ) ) = X X -> G(x) -> F(G(x)) =X X−>G(x)−>F(G(x))=X
2)反向一致损失
即y从经过网络后还原为y的过程
Y − > F ( y ) − > G ( F ( y ) ) = Y Y -> F(y) -> G(F(y)) =Y Y−>F(y)−>G(F(y))=Y
self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
self.fake_B = self.netG_A(self.real_A) # G_A(A)
self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
self.fake_A = self.netG_B(self.real_B) # G_B(B)
self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
数据B经过生成器G_A,后生成的B,与原始B距离最小。
B-> G_A->B’ : 使得B 与B’距离最小
A-> F_B->A’ : 使得A 与A’距离最小
self.idt_A = self.netG_A(self.real_B)
self.loss_idt_A= self.L1Loss(self.idt_A, self.real_B)
self.idt_B = self.netG_B(self.real_A)
self.loss_idt_B = self.L1Loss(self.idt_B, self.real_A)
生成器生成的数据,让判别器都判别为真
(备注:判别器输出不是一个值,而是一个矩阵,需要使判别器输出矩阵每一个值都接近1)
# GAN loss D_A(G_A(A))
self.loss_G_A = self.MSELoss(self.netD_A(self.fake_B), True)
# GAN loss D_B(G_B(B))
self.loss_G_B = self.criterioMSELossGAN(self.netD_B(self.fake_A), True)
使得重构的A与原始A距离最近,使用L1Loss
# Forward cycle loss || G_B(G_A(A)) - A||
self.loss_cycle_A = self.L1Loss(self.rec_A, self.real_A)
# Backward cycle loss || G_A(G_B(B)) - B||
self.loss_cycle_B = self.L1Loss(self.rec_B, self.real_B)
上面6个生成器损失求和即为总的生成损失函数
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
判别器:使真实图片为判别为真,假图片判别为假
pred_real = netD(real)
loss_D_real = self.MSELoss(pred_real, True)
# Fake
pred_fake = netD(fake.detach())
loss_D_fake = self.MSELoss(pred_fake, False)
# Combined loss and calculate gradients
loss_D = (loss_D_real + loss_D_fake) * 0.5