生成对抗性网络(GAN)

转载:https://www.cnblogs.com/jiangxinyang/p/10156138.html

GAN的全称是 Generative Adversarial Networks,中文名称是生成对抗网络。原始的GAN是一种无监督学习方法,巧妙的利用“博弈”的思想来学习生成式模型。

1 GAN的原理

  GAN的基本原理很简单,其由两个网络组成,一个是生成网络G(Generator) ,另外一个是判别网络D(Discriminator)。它们的功能分别是:

  生成网络G:负责生成图片,它接收一个随机的噪声 zz,通过该噪声生成图片,将生成的图片记为 G(z)。

  判别网络D:负责判别一张图片是真实的图片还是由G生成的假的图片。其输入是一张图片 x,输出是0、1值,0代表图片是由G生成的,1代表是真实图片。

  在训练过程中,生成网路G的目标是尽量生成真实的图片去欺骗判别网络D。而判别网络D的目标就是尽量把G生成的图片和真实的图片区分开来。这样G和D就构成了一个动态的博弈过程。这是GAN的基本思想。

  在最理想的状态下,G可以生成足以“以假乱真”的图片 G(z)。对于D来说,它难以判断G生成的图片究竟是不是真实的,因此 D(G(z))=0.5(在这里我们输入的真实图片和生成的图片是各一半的)。此时得到的生成网络G就可以用来生成图片。

 

2 GAN损失函数

  从数学的角度上来看GAN,假设用于训练的真实图片数据是 xx,图片数据的分布为,生成网络G需要去学习到真实数据分布 。噪声 z的分布假设为,在这里 是已知的,而  是未知的。在理想的状态下G(z)的分布应该是尽可能接近,G将已知分布的z 变量映射到位置分布 x 变量上。

  根据交叉熵损失,可以构造下面的损失函数:

  

  其实从损失函数中可以看出和逻辑回归的损失函数基本一样,唯一不一样的是负例的概率值为 1−D(G(z))。

  损失函数中加号的前一半是训练数据中的真实样本,后一半是从已知的噪声分布中取的样本。下面对这个损失函数详细描述:

  1)整个式子有两项构成。 x表示真实图片,z表示输入G网络的噪声,而G(z) 表示G网络生成的图片。

  2)D(x) 表示D网络判断真实图片是否真实的概率 ,即 P(y=1|x)。而D(G(z))是D网络判断GG生成的图片是否真实的概率。

  3)G的目的:G应该希望自己生成的图片越真实越好。也就是说G希望 D(G(z))尽可能大,即P(G(z)=1|x),这时 V(D,G)尽可能小。

  4)D的目的:D的能力越强,D(x)就应该越大,D(G(x))应该越小(即假的图片都被识别为0)。因此D的目的和G的目的不同,D希望 V(D,G) 越大越好。

 

3 GAN建模流程

  在实际训练中,使用梯度下降法,对D和G交替做优化,具体步骤如下:

  1)从已知的噪声分布 pz(z)pz(z)中选取一些样本

    

  2)从训练数据中选出同样个数的真实图片

    

  3)设判别器D的参数为 θdθd,其损失函数的梯度为

   

  4)设生成器G的参数为 θgθg,其损失函数的梯度为

   

  在上面的步骤中,每更新一次D的参数,紧接着就更新一次G的参数,有时也可以在更新 kk 次D的参数,再更新一次G的参数。

你可能感兴趣的:(机器学习,神经网络)