StarGAN中的损失函数

学习笔记
论文地址:http://openaccess.thecvf.com/content_cvpr_2018/papers/Choi_StarGAN_Unified_Generative_CVPR_2018_paper.pdf

我们用生成器G把输入的图像x转换为目标域c,通过生成器G输出图像y,即G(x,c)→y。为了便生成器G可以学会灵活地转换输入图像,我们随机生成的目标域标签c。我们的判别器在源和域标签上产生概率分布为在这里插入图片描述

损失函数

一、对抗损失:
在这里插入图片描述
使用GAN的常规函数,生成器G通过输入图像x和目标域标签c生成图像G(x,c),而判别器D来辨别真实图像与生成图像。生成器G试图最小化该目标,而鉴别器D试图使其最大化。
实际操作中换成了WGAN的对抗损失:
StarGAN中的损失函数_第1张图片

二、类别损失:

判别器
在这里插入图片描述
我们的目地是将输入图像x转换为输出图像y,让正确地分类到目标域c。
Dcls(c’| x)代表判别器将真实样本归为原始标签c’ 的概率分布,判别器D的目标是最小化Loss 。图像x和和原域标签c’是由训练集给出的。

生成器
在这里插入图片描述
使生成器G生成假图片(x’)让它尽可能被判别器D分类成目标域c(比如愤怒),因此最小化Loss

重建损失
最小化损失并不能保证翻译的图像仅改变输入图片的与域相关的信息部分,而不改变其输入图像的内容。因此我们在此加上一个重建损失。
StarGAN中的损失函数_第2张图片
将G(x,c)和图片x的原始标签c’结合喂入到G中,将生成的图片和x计算1范数差异。(1范数就是向量中非零元素的绝对值之和,L1范数可以度量两个向量间的差异)
在这里插入图片描述

完整的loss

判别器:
在这里插入图片描述
生成器:
在这里插入图片描述
超参数:

#生成器
x_fake = self.G(x_real, c_trg)  # G(x, c)生成图像
out_src, out_cls = self.D(x_fake)  # 生成图像的判别
g_loss_fake = - torch.mean(out_src)  # 对抗损失
g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset)  # G分类损失
x_reconst = self.G(x_fake, c_org)  # G(G(x,c),c')
g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))  # 重建损失
g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls  # G完整的损失函数
# 判别器
x_fake = self.G(x_real, c_trg)  # G(x, c)生成图像
out_src, out_cls = self.D(x_real)  # 真实图像的判别
d_loss_real = - torch.mean(out_src)  

d_loss_cls = self.classification_loss(out_cls, label_org, self.dataset)  # D分类损失

out_src, out_cls = self.D(x_fake.detach())  # 生成图像的判别
d_loss_fake = torch.mean(out_src)  

#计算梯度惩罚因子alpha,根据alpha结合x_real,x_fake,输入判别网络,计算梯度
alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)  #  alpha是一个随机数 tensor([[[[ 0.7610]]]])
x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)  # x_hat是一个图像大小的张量数据,随着alpha的改变而变化
out_src, _ = self.D(x_hat)  # x_hat 表示梯度惩罚因子
d_loss_gp = self.gradient_penalty(out_src, x_hat)  # 梯度惩罚d_loss_gp 在0.9954~ 0.9956 波动
#d_loss_real + d_loss_fake  + self.lambda_gp * d_loss_gp是WGAN对抗损失
d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp  # 总的损失函数

你可能感兴趣的:(深度学习,机器学习,python,计算机视觉,神经网络)