1. StarGAN 的简介
Pix2pix解决了两个领域之间匹配数据集之间的转换,然而在很多情况下匹配数据集很难获得,于是出现了CycleGAN。它可以实现两个领域之间非匹配数据集之间的转换, 然而这些转换每次只能在两个领域之间进行,当需要进行多个领域间的转换时,就需要训练多个网络,非常麻烦。2018年的CVPR上发表了一篇文章提出了StarGAN,它仅使用一个网络就实现多个领域之间的图像转换,而且图像转换的效果也比较好。
2. StarGAN 的网络结构:
注:上图中的两个生成器 G 是同一个生成器,整个StarGAN中只使用了一个生成器和一个判别器。
StarGAN实现了多个领域图像之间的转换,但是网络结构比CycleGAN 更简单,整个网络只包含一个生成器和一个判别器。从结构上StarGAN与ACGAN比较相似,生成器是输入除了图像之外还有目标领域的标签;判别器是输入除了图像之外还有相应的类别标签,而且判别器的输出除了判别图像真假之外还要对图像进行分类。
生成器的输入包含两个部分,一部分是输入图像imgs,大小为(batch_size, n_channel, cols, rows);一部分是目标领域的标签domain,大小为(batch_size, n_dim)。为了将这两部拼接,需要通过repeat操作来对domain进行扩展,将其扩展为(batch_size, n_dim, cols, rows),因此,生成器输入的大小为(batch_size, n_channel + n_dim, cols, rows),生成器的输出为(batch_size, n_channel, cols, rows)。判别器的输入为图像imgs,大小为(batch_size, n_channel, cols, rows),判别器的输出分为两部分,一部分是图像的真假判断,大小为(batch_size, 1, s1, s2),另一部分为图像的类别划分,大小为(batch_size, n_dim)。
3. StarGAN的损失函数
(1)对抗损失:即常规的生成对抗网络的损失的损失函数,判别器在努力地判别输入图像的真假,生成器在努力地生成假图像来欺骗判别器。
(2)分类损失:即将输入图像进行分类的损失。对于判别器D而言,需要将真实图像分到正确的类别中;对于生成器G而言,需要使得生成图像分到目标类别中。
对于判别器D: 其中,代表判别器将真实样本归为相应标签类别 的概率分布,判别器D的目标是最小化损失函数。
对于生成器G: 生成器希望生成数据能够被判别器判断为目标分类c, 因此生成器的目标是最小化损失函数。
(3)重建损失:为了确保生成数据能够很好地还原到原来的领域分类中,此处将原始图像和经过两次生成的图像的L1范数作为重建损失。
因此,StarGAN的生成器和判别器总的损失函数分别为:
生成器G损失函数:
gen_imgs = generator(imgs, sampled_c) # 生成图像,sampled_c 为随机生成的目标类标签
recov_imgs = generator(gen_imgs, labels) # 图像重建
fake_validity, pred_cls = discriminator(gen_imgs) # 生成图像的判别
loss_G_adv = -torch.mean(fake_validity) # 对抗损失
loss_G_cls = torch.nn.functional.binary_cross_entropy_with_logits(sampled_c, pred_cls, size_average=False) / sampled.size(0) # 分类损失
loss_G_rec = torch.nn.L1Loss(recov_imgs, imgs) # 重建损失
Loss_G = loss_G_adv + lambda_cls * loss_G_cls + lambda_rec * loss_G_rec # 生成器总的损失
判别器D损失函数:
fake_imgs = generator(imgs, sampled_c) # 生成图像
real_validity, pred_cls = discriminator(imgs) # 真实图像的判别
fake_validity, _ = discriminator( fake_imgs.detach()) # 生成图像的判别
gradient_penalty = compute_gradient_penalty(discriminator, imgs.data, fake_imgs.data) # 梯度惩罚
loss_D_adv = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty # 对抗损失
loss_D_cls = torch.nn.functional.binary_cross_entropy_with_logits(labels, pred_cls, size_average=False) / sampled.size(0) # 分类损失
Loss_D = loss_D_adv + lambda_cls * loss_D_cls # 判别器总的损失