StarGAN-多领域间的转换

1. StarGAN 的简介

        Pix2pix解决了两个领域之间匹配数据集之间的转换,然而在很多情况下匹配数据集很难获得,于是出现了CycleGAN。它可以实现两个领域之间非匹配数据集之间的转换, 然而这些转换每次只能在两个领域之间进行,当需要进行多个领域间的转换时,就需要训练多个网络,非常麻烦。2018年的CVPR上发表了一篇文章提出了StarGAN,它仅使用一个网络就实现多个领域之间的图像转换,而且图像转换的效果也比较好。


2. StarGAN 的网络结构:

StarGAN的结构拆分

注:上图中的两个生成器 G 是同一个生成器,整个StarGAN中只使用了一个生成器和一个判别器。

        StarGAN实现了多个领域图像之间的转换,但是网络结构比CycleGAN 更简单,整个网络只包含一个生成器和一个判别器。从结构上StarGAN与ACGAN比较相似,生成器是输入除了图像之外还有目标领域的标签;判别器是输入除了图像之外还有相应的类别标签,而且判别器的输出除了判别图像真假之外还要对图像进行分类。

        生成器的输入包含两个部分,一部分是输入图像imgs,大小为(batch_size, n_channel, cols, rows);一部分是目标领域的标签domain,大小为(batch_size, n_dim)。为了将这两部拼接,需要通过repeat操作来对domain进行扩展,将其扩展为(batch_size, n_dim, cols, rows),因此,生成器输入的大小为(batch_size, n_channel + n_dim, cols, rows),生成器的输出为(batch_size, n_channel, cols, rows)。判别器的输入为图像imgs,大小为(batch_size, n_channel, cols, rows),判别器的输出分为两部分,一部分是图像的真假判断,大小为(batch_size, 1, s1, s2),另一部分为图像的类别划分,大小为(batch_size, n_dim)。


3. StarGAN的损失函数

(1)对抗损失:即常规的生成对抗网络的损失的损失函数,判别器在努力地判别输入图像的真假,生成器在努力地生成假图像来欺骗判别器。

                

(2)分类损失:即将输入图像进行分类的损失。对于判别器D而言,需要将真实图像分到正确的类别中;对于生成器G而言,需要使得生成图像分到目标类别中。

对于判别器D:             其中,代表判别器将真实样本归为相应标签类别 的概率分布,判别器D的目标是最小化损失函数。

对于生成器G:        生成器希望生成数据能够被判别器判断为目标分类c, 因此生成器的目标是最小化损失函数。

(3)重建损失:为了确保生成数据能够很好地还原到原来的领域分类中,此处将原始图像和经过两次生成的图像的L1范数作为重建损失。

                

因此,StarGAN的生成器和判别器总的损失函数分别为:

生成器G损失函数: 

gen_imgs = generator(imgs, sampled_c)       #  生成图像,sampled_c 为随机生成的目标类标签

recov_imgs = generator(gen_imgs, labels)          #   图像重建

fake_validity, pred_cls = discriminator(gen_imgs)         #  生成图像的判别

loss_G_adv = -torch.mean(fake_validity)          # 对抗损失

loss_G_cls = torch.nn.functional.binary_cross_entropy_with_logits(sampled_c, pred_cls, size_average=False) / sampled.size(0)        #  分类损失

loss_G_rec = torch.nn.L1Loss(recov_imgs, imgs)     # 重建损失

Loss_G = loss_G_adv + lambda_cls * loss_G_cls + lambda_rec * loss_G_rec     #  生成器总的损失

判别器D损失函数: 

fake_imgs = generator(imgs, sampled_c)      #  生成图像

real_validity, pred_cls = discriminator(imgs)      # 真实图像的判别

fake_validity, _ = discriminator( fake_imgs.detach())     #  生成图像的判别

gradient_penalty = compute_gradient_penalty(discriminator, imgs.data, fake_imgs.data)       # 梯度惩罚

loss_D_adv = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty       # 对抗损失

loss_D_cls = torch.nn.functional.binary_cross_entropy_with_logits(labels, pred_cls, size_average=False) / sampled.size(0)           # 分类损失

Loss_D = loss_D_adv + lambda_cls * loss_D_cls     # 判别器总的损失

你可能感兴趣的:(StarGAN-多领域间的转换)