什么是生成式对抗神经网络GAN

当你看到以假乱真的图片或视频,看到风格迁移的图片或视频,你应当知道,其背后的机器学习技术是GAN!

GAN, generative adversarial network, 生成式对抗神经网络, 是生成模型的一种。

生成模型主要分两种,一种由输入数据,得到概率密度分布,另外一种,由输入数据,得到与输入数据相同分布的输出数据,GAN属于第二种。更多的关于生成模型的分类,见下图。

什么是生成式对抗神经网络GAN_第1张图片

GAN是怎样工作的呢?

GAN有两个网络,一个是生成器,希望生成同训练数据相同分布的样本,一个是判别器,希望将生成数据(fake)和训练数据(real)区分开来。

判别器希望real的output接近1,fake的output接近0,下面是判别器的损失函数的定义(只有一种):

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

生成器的损失函数有三种(零和游戏,非饱和游戏和最大似然游戏), 在非饱和游戏中,生成器希望fake经过判别器判别的output接近1,非饱和游戏中生成器的损失函数的定义如下:

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

下图中的D表示判别器(函数),G表示生成器(函数):

什么是生成式对抗神经网络GAN_第2张图片

需要注意,生成器的损失函数既依赖于生成器神经网络的参数,也依赖于判别器神经网络的参数;同样判别器的损失函数既依赖于判别器神经网络的参数,也依赖于生成器神经网络的参数。

训练GAN是一个博弈的过程,需要找到纳什均衡。

 

以上图片来自于 NIPS 2016 Tutorial: Generative Adversarial Networks by Ian Goodfellow

用来表示损失函数的定义的示例代码来自于 tensorflow tutorials https://www.tensorflow.org/tutorials/generative/dcgan

 

祖国翔,于上海

https://www.linkedin.com/in/guoxiang-zu/

你可能感兴趣的:(机器学习)