[GAN] Learning CycleGAN: A Walkthrough of the Training Flow

Code discussed in this post: https://github.com/eriklindernoren/PyTorch-GAN/tree/master/implementations/cyclegan

Official source code: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

The official source and the code discussed here follow the same approach. This post focuses on the overall training flow.

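If you want to study the implementation in depth, though, the official source is the better choice, and it is actually fairly simple too.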

Training process

1. Train Generators

Loss function:

loss_identity = (loss_id_A + loss_id_B) / 2
 
loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
 
loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
 
# Total loss
loss_G =    loss_GAN + \
            lambda_cyc * loss_cycle + \
            lambda_id * loss_identity
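
To make the weighting concrete, here is a small worked example in plain Python. The individual loss values are made up for illustration, and the weights lambda_cyc = 10.0 and lambda_id = 5.0 are assumptions (commonly used CycleGAN settings), not values stated in this post; check the training script's arguments for the real ones.

```python
def generator_total_loss(loss_gan, loss_cycle, loss_identity,
                         lambda_cyc=10.0, lambda_id=5.0):
    """Combine the three generator loss terms into one scalar.

    The default weights are assumed values, not taken from this post.
    """
    return loss_gan + lambda_cyc * loss_cycle + lambda_id * loss_identity


# Hypothetical per-direction losses, chosen only for illustration.
loss_identity = (0.25 + 0.75) / 2    # (loss_id_A + loss_id_B) / 2
loss_gan = (0.5 + 1.5) / 2           # (loss_GAN_AB + loss_GAN_BA) / 2
loss_cycle = (0.125 + 0.375) / 2     # (loss_cycle_A + loss_cycle_B) / 2

loss_g = generator_total_loss(loss_gan, loss_cycle, loss_identity)
print(loss_g)
```

With these numbers the cycle and identity terms each contribute 2.5 against 1.0 from the adversarial term, so the two weights decide how strongly reconstruction is favored over fooling the discriminators.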

Code:

        # ------------------
        #  Train Generators
        # ------------------
 
        optimizer_G.zero_grad()
 
        # Identity loss: a generator given an image that is already in its
        # target domain should return it unchanged
        loss_id_A = criterion_identity(G_BA(real_A), real_A)
        loss_id_B = criterion_identity(G_AB(real_B), real_B)
 
        loss_identity = (loss_id_A + loss_id_B) / 2
 
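        # GAN loss: each generator tries to make the discriminator of its
        # target domain score the generated image as `valid`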
        fake_B = G_AB(real_A)
        loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
        fake_A = G_BA(real_B)
        loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
 
        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
 
        # Cycle loss: translating to the other domain and back should
        # recover the original image
        recov_A = G_BA(fake_B)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        recov_B = G_AB(fake_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)
 
        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
 
        # Total loss
        loss_G =    loss_GAN + \
                    lambda_cyc * loss_cycle + \
                    lambda_id * loss_identity
 
        loss_G.backward()
        optimizer_G.step()
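
The code above calls criterion_GAN, criterion_cycle and criterion_identity without showing how they are defined. Here is a minimal sketch, assuming a least-squares (mean squared error) adversarial criterion and a mean absolute error (L1) criterion for the cycle and identity terms, which is a common CycleGAN configuration. The helper names mse_loss and l1_loss are hypothetical, and plain Python lists stand in for tensors so the example runs without dependencies.

```python
def mse_loss(pred, target):
    """Mean squared error, the assumed form of criterion_GAN."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def l1_loss(pred, target):
    """Mean absolute error, the assumed form of criterion_cycle/identity."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


# Hypothetical discriminator scores for a generated image (one per patch).
d_b_on_fake_b = [0.5, 1.0, 0.0, 0.5]
valid = [1.0] * len(d_b_on_fake_b)   # the generator wants every score to be 1
loss_gan_ab = mse_loss(d_b_on_fake_b, valid)

# Hypothetical pixel values: a reconstructed image vs. the original.
recov_a = [0.0, 0.5, 1.0]
real_a = [0.5, 0.5, 0.0]
loss_cycle_a = l1_loss(recov_a, real_a)

print(loss_gan_ab, loss_cycle_a)
```

The adversarial term shrinks as the discriminator's scores on generated images approach 1, while the L1 terms shrink as the reconstructed pixels approach the originals.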

2. Train Discriminator A

Loss function:

loss_D_A = (loss_real + loss_fake) / 2

Code:

        # -----------------------
        #  Train Discriminator A
        # -----------------------
 
        optimizer_D_A.zero_grad()
 
        # Real loss
        loss_real = criterion_GAN(D_A(real_A), valid)
        # Fake loss (on a batch of previously generated samples from the buffer);
        # detach() keeps this backward pass from reaching the generator
        fake_A_ = fake_A_buffer.push_and_pop(fake_A)
        loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
        # Total loss
        loss_D_A = (loss_real + loss_fake) / 2
 
        loss_D_A.backward()
        optimizer_D_A.step()
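
fake_A_buffer.push_and_pop is not explained above. The sketch below shows one way such an image history buffer can work, assuming the usual behavior: until the buffer is full, every new sample is stored and returned as is; afterwards, each new sample either replaces a randomly chosen stored one (and the stored one is returned instead) or is passed straight through, with equal probability. The max_size of 50 and the seed parameter are assumptions for illustration, and the items are plain Python objects rather than tensors.

```python
import random


class ReplayBuffer:
    """Simplified history buffer of generated samples (illustrative sketch)."""

    def __init__(self, max_size=50, seed=None):
        if max_size <= 0:
            raise ValueError("max_size must be positive")
        self.max_size = max_size
        self.data = []
        self.rng = random.Random(seed)  # seed is only here for reproducibility

    def push_and_pop(self, batch):
        """Store new samples and return a batch of the same length."""
        out = []
        for item in batch:
            if len(self.data) < self.max_size:
                # Buffer not full yet: keep the sample and return it unchanged.
                self.data.append(item)
                out.append(item)
            elif self.rng.random() > 0.5:
                # Swap: return an older sample and store the new one in its slot.
                i = self.rng.randrange(self.max_size)
                out.append(self.data[i])
                self.data[i] = item
            else:
                # Pass the new sample through without storing it.
                out.append(item)
        return out


buf = ReplayBuffer(max_size=2, seed=0)
first = buf.push_and_pop(["fake_1", "fake_2"])
second = buf.push_and_pop(["fake_3", "fake_4"])
print(first, second)
```

Showing the discriminator a mix of fresh and older fakes, rather than only the generator's latest output, is intended to make training more stable.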

3. Train Discriminator B

Loss function:

loss_D_B = (loss_real + loss_fake) / 2

Code:

        # -----------------------
        #  Train Discriminator B
        # -----------------------
 
        optimizer_D_B.zero_grad()
 
        # Real loss
        loss_real = criterion_GAN(D_B(real_B), valid)
        # Fake loss (on a batch of previously generated samples from the buffer);
        # detach() keeps this backward pass from reaching the generator
        fake_B_ = fake_B_buffer.push_and_pop(fake_B)
        loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
        # Total loss
        loss_D_B = (loss_real + loss_fake) / 2
 
        loss_D_B.backward()
        optimizer_D_B.step()
 
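        # Mean of the two discriminator losses; it is not backpropagated here,
        # since each discriminator was already updated from its own loss above
        loss_D = (loss_D_A + loss_D_B) / 2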
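
Both discriminators use the same objective: push their outputs on real images toward the valid target and their outputs on generated images toward the fake target. Here is a minimal sketch, assuming valid is all ones, fake is all zeros, and the criterion is mean squared error. The helper discriminator_loss is hypothetical, and lists stand in for the discriminator's output.

```python
def mse_loss(pred, target):
    """Mean squared error, the assumed form of criterion_GAN."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def discriminator_loss(d_real, d_fake):
    """Average of the real and fake losses, mirroring loss_D_A / loss_D_B."""
    valid = [1.0] * len(d_real)   # assumed target for real images
    fake = [0.0] * len(d_fake)    # assumed target for generated images
    loss_real = mse_loss(d_real, valid)
    loss_fake = mse_loss(d_fake, fake)
    return (loss_real + loss_fake) / 2


# A discriminator that separates real from fake perfectly.
perfect = discriminator_loss([1.0, 1.0], [0.0, 0.0])
# A discriminator that outputs 0.5 everywhere, i.e. cannot tell them apart.
unsure = discriminator_loss([0.5, 0.5], [0.5, 0.5])
print(perfect, unsure)
```

Under these assumptions the loss is zero for a perfect discriminator and 0.25 for one that always answers 0.5.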

 
