生成对抗式网络GAN 的 loss

GAN同时要训练一个生成网络(Generator)和一个判别网络(Discriminator),前者输入一个noise变量 z ,输出一个伪图片数据 G(z;θg) ,后者输入一个图片(real image)以及伪图片(fake image)数据 x ,输出一个表示该输入是自然图片或者伪造图片的二分类置信度 D(x;θd) ,理想情况下,判别器 D 需要尽可能准确的判断输入数据到底是一个真实的图片还是某种伪造的图片,而生成器 G 又需要尽最大可能去欺骗 D ,让 D 把自己产生的伪造图片全部判断成真实的图片。
根据上述训练过程的描述,我们可以定义一个损失函数:

Loss=1mmi=1[logD(xi)+log(1D(G(zi)))]

其中 xi , zi 分别是真实的图片数据以及noise变量。
而优化目标则是:

minGmaxDLoss

不过需要注意的一点是,实际训练过程中并不是直接在上述优化目标上对 θd , θg 计算梯度,而是分成几个步骤:

训练判别器即更新 θd :循环 k 次,每次准备一组real image数据 x=x1,x2,,xm 和一组fake image数据 z=z1,z2,,zm ,计算
θd1mmi=1[logD(xi)+log(1D(G(zi)))]
然后梯度上升法更新 θd
训练生成器即更新 θg :准备一组fake image数据 z=z1,z2,,zm ,计算
θg1mmi=1log(1D(G(zi)))
然后梯度下降法更新 θg
可以看出,第一步内部有一个 k 层的循环,某种程度上可以认为是因为我们的训练首先要保证判别器足够好然后才能开始训练生成器,否则对应的生成器也没有什么作用,然后第二步求提督时只计算fake image那部分数据,这是因为real image不由生成器产生,因此对应的梯度为0。

你可能感兴趣的:(DL)