生成对抗网络(GAN)

注:本文理论部分的截图全部来自李宏毅机器学习及其深层与结构化课件。
生成对抗网络有两个模型:判别模型和生成模型。
判别网络的目的:判别图像的真假,来自真样本集还是假样本集;
生成网络的目的:使造样本的能力尽可能强,判别网络无法判别是真样本还是假样本。

生成对抗网络理论部分

Generation

  1. 找到原数据集的数据分布Pdata(x),可以从中采样生成image;
  2. 设定一个有参数的θ的分布PG(x;θ),寻找使PG(x;θ)最接近Pdata(x)的θ;
     (1)从Pdata(x)中采样{x1,x2,...,xm};
     (2)计算PG(xi;θ);
     (3)产生以上样本的概率,L=PG(xi;θ),最大化L。
     其中,最大似然估计=最小KL松散度。
    最大似然估计推导

    为了得到与原样本相似的数据分布,参数为θ,在这里求使Pdata(x)与PG(x)散度最小的分布。

Generator

一个生成器G是一个网络,网络定义了概率分布PG.

PG与Pdata

Discriminator



训练:
将来自和的数据进行训练,最大化。
对于给定的G,最优的D最大化:

最大化V

最每个给定的x,最优D
去最大化:
每个x分别最大化

求导求最值

式子整理1

式子整理2

GAN关键含义

求得最大化V的判别器,本质上是实现真样本与假样本的二分类,然后寻找数据分布最小的生成器,以此迭代进行执行。上图中最后找到的最佳生成器为G3
步骤如下:

  1. 初始化生成器和判别器;
  2. 在每次训练迭代时:
      step1:固定生成器G,更新判别器D;
      step2:固定判别器D,更新生成器G。

算法流程

  1. 寻找最好的G最小化损失函数L(G)使用梯度下降进行求解,其中。
    相当于每个分段最大值
  2. 给定,寻找最大化中的。
  3. ,以此得到。
  4. 给定,寻找最大化中的。
    1. ,以此得到。
      ……
      相关问题

      本质上,是与的JS divergence.
      当得到新的G之后,可能存在D发生了变化的现象,可能此时D中的V(G,D)不是最大的情况。这里假设,在进行实现时,将D多进行几次迭代,train到底,找到最大的情况,将G进行小几步更新,以免出现上述不成立的现象。
      Discriminator leads the Generator.

代码实现

。。。未完待续

你可能感兴趣的:(生成对抗网络(GAN))