Pytorch实现论文:对GAN的交替优化

简介

这次带来的是Closing the Gap Between Theory and Practice During Alternating Optimization for GANs,Gans交替优化中缩小理论与实践的差距这篇论文的一个核心代码在ACGAN模型上的效果测试,核心是修改了损失函数部分的计算。作者的实验是在StyleGAN上进行的。

论文简介

论文题目:Closing the Gap Between Theory and Practice During Alternating Optimization for GANs

论文出处:IEEE TRANSACTIONS ON NEURAL NETWORKS AND LEARNING SYSTEMS

论文摘要:融合高质量和多种样本是生成模型的主要目标。 尽管最近在生成对抗网络(GAN)方面取得了巨大进展,但模式崩溃仍然是一个开放的问题,并且减轻它将使生成器受益,以更好地捕获目标数据分布。 本文重新考虑了GANs的交替优化,这是一种经典的培训GAN的方法。 我们发现原始甘斯中提出的理论不能适应这种实用的解决方案。 在交替的优化方式下,香草损耗函数为发电机提供了不适当的目标。 该目标迫使生成器产生具有鉴别器的最高区分概率的输出,从而导致GAN中的模式崩溃。 为了解决这个问题,引入了一个新颖的损失功能,以使发电机适应交替的优化性质。 当通过提出的损耗函数更新发电机时,理论上优化了模型分布和目标分布之间的反向kullback -leibler差异,这鼓励模型学习目标分布。 广泛的实验结果表明,我们的方法可以始终如一地提高各种数据集和网络结构上的模型性能。

读后记录

 阅读论文之后,我个人认为作者提出的重要的一段在这里:

Pytorch实现论文:对GAN的交替优化_第1张图片

这相当于修改了鉴别器和生成器的损失函数,因此在这篇论文的基础上,我采用了ACGAN+以上损失设计尝试完成了代码测试。

在文中的实验中,超参数设置为1,因此我们直接在训练的过程中直接在损失计算上修改:

代码部分

训练

# 训练过程
for epoch in range(opt.nepochs):
    for i, (imgs, _) in enumerate(dataloader):
        batch_size = imgs.shape[0]

        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)

        real_imgs = Variable(imgs.type(FloatTensor))

        # 训练生成器
        optimizer_G.zero_grad()
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latentdim))))
        gen_imgs = generator(z)

        # 计算生成器的对抗损失
        validity = discriminator(gen_imgs)
        g_loss = adversarial_loss(validity, valid)

        # 加入生成器正则化项
        # 调整生成图像尺寸
        if gen_imgs.size(2) != real_imgs.size(2) or gen_imgs.size(3) != real_imgs.size(3):
            gen_imgs = F.interpolate(gen_imgs, size=(real_imgs.size(2), real_imgs.size(3)), mode='bilinear',
                                     align_corners=False)

        g_loss += torch.mean((gen_imgs - real_imgs) ** 2)

        g_loss.backward()
        optimizer_G.step()

        # 训练判别器
        optimizer_D.zero_grad()

        # 计算真实图片的对抗损失
        real_pred = discriminator(real_imgs)
        d_real_loss = adversarial_loss(real_pred, valid)

        # 计算假图片的对抗损失
        fake_pred = discriminator(gen_imgs.detach())
        d_fake_loss = adversarial_loss(fake_pred, fake)

        # 判别器的总损失
        d_loss 

你可能感兴趣的:(GAN系列,生成对抗网络,计算机视觉,人工智能,pytorch,机器学习,深度学习)