GAN原理详解

GAN的基本原理其实非常简单,这里以生成图片为例进行说明。假设我们有两个网络,G(Generator)和D(Discriminator)。正如它的名字所暗示的那样,它们的功能分别是:
GAN原理详解_第1张图片

  • G是一个生成图片的网络,它接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)。
  • D是一个判别网络,判别一张图片是不是“真实的”。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。

数学公式:
在这里插入图片描述
这里主要理解是怎么训练网络的。

  • 整个式子由两项构成。x表示真实图片,z表示输入G网络的噪声,而G(z)表示G网络生成的图片。
  • D(x)表示D网络判断真实图片是否真实的概率(因为x就是真实的,所以对于D来说,这个值越接近1越好)。而D(G(z))是D网络判断G生成的图片的是否真实的概率。
  • G的目的:上面提到过,D(G(z))是D网络判断G生成的图片是否真实的概率,G应该希望自己生成的图片“越接近真实越好”。也就是说,G希望D(G(z))尽可能得大,这时V(D, G)会变小。因此我们看到式子的最前面的记号是min_G。
  • D的目的:D的能力越强,D(x)应该越大,D(G(x))应该越小。这时V(D,G)会变大。因此式子对于D来说是求最大(max_D)

还是不太理解,怎么做到网络的训练

假设我们已经有了G网络和D网络,以及real_data。那么我们怎么对GD训练呢。
再一次迭代中:
min ⁡ G \min\limits_{G} Gmin
1. 首先输入g_input(符合G输入的随机噪声),对G前向传播获得g_fake_data=G(g_input)
2. 然后输入至D网络对其进行鉴别dg_fake_describe = D(g_fake_data)
3. 此时损失定义为g_error = criterion(dg_fake_describe,1)。此时对这个损失进行反向传播,为了让这个减少这个损失,就会更新G网络的参数,是它生成的数据越来越接近真实。(记住这里这个损失也对D网络进行了反向传播,但是并不对D网络的参数进行更新)
max ⁡ D \max\limits_{D} Dmax
4. 获得真实标签输入d_real_input, 并对D网络前向传播获得d_real_descirbe = D(d_real_input)。计算其损失d_real_error= criterion(d_real_descirbe, 1)。这个损失反向传播的时候可以告诉D网络这是真实标签。
5. 获得G网络的生成的假数据g_fake_data ,并计算其输入到D网络的损失函数d_fake_error = criterion(D(g_input), 0),这个损失函数可以保证D网络能识别出假的图像。
(这里虽然计算G网络的反向传播梯度,但是并不对G网络进行参数进行更新)

直到最后D网络认为生成的假图像接近真实==(即D(g_fake_data)=0.5)==,表明网络训练成功。
这样就实现了两个网络的训练。

你可能感兴趣的:(深度学习)