GAN(生成对抗网络) 解释

GAN (生成对抗网络)是近几年深度学习中一个比较热门的研究方向,它的变种有上千种。

1.什么是GAN

GAN的英文全称是Generative Adversarial Network,中文名是生成对抗网络。它由两个部分组成,生成器和鉴别器(又称判别器),它们之间的关系可以用竞争或敌对关系来描述。

我们可以拿捕食者与被捕食者之间的例子来类似说明两者之间的关系。在生物进化的过程中,被捕食者会慢慢演化自己的特征,使自己越来越不容易被捕食者识别捕捉到,从而达到欺骗捕食者的目的;与此同时,捕食者也会随着被捕食者的演化来演化自己对被捕食者的识别,使自己越来越容易识别捕捉到捕食者。这样就可以达到两者共同进化的目的。生成器代表的是被捕食者,鉴别器代表的是捕食者。

2.GAN的原理

GAN的工作原理与上述例子还有略微的不同,GAN是已经知道最终鉴别的目标是什么,但不知道假目标是什么,它会对生成器所产生的假目标做惩罚并对真目标进行奖励,这样鉴别器就知道了不好的假目标与好的真目标具体是什么。生成器则是希望通过进化,产生比上一次更好的假目标,使鉴别器对自己的惩罚更小。以上是一个循环,在下一个循环中鉴别器通过学习上一个循环进化出的假目标和真目标,再次进化对假目标的惩罚,同时生成器再次进化,直到与真目标一致,结束进化。

GAN简单代码实现

#是一个卷积神经网络,变量名是D,其中一层构造方式如下。
w = tf.get_variable('w', [4, 4, c_dim, num_filter], 
initializer=tf.truncated_normal_initializer(stddev=stddev))
dconv = tf.nn.conv2d(ddata, w, strides=[1, 2, 2, 1], padding='SAME')
biases = tf.get_variable('biases', [num_filter], 
        initializer=tf.constant_initializer(0.0))
bias = tf.nn.bias_add(dconv, biases)
dconv1 = tf.maximum(bias, leak*bias)

#是一个逆卷积神经网络,变量名是G,其中一层构造方式如下。
w = tf.get_variable('w', [4, 4, num_filter, num_filter*2], 
        initializer=tf.random_normal_initializer(stddev=stddev))
deconv = tf.nn.conv2d_transpose(gconv2, w, 
        output_shape=[batch_size, s2, s2, num_filter], 
        strides=[1, 2, 2, 1])
biases = tf.get_variable('biases', [num_filter],
initializer=tf.constant_initializer(0.0))
bias = tf.nn.bias_add(deconv, biases)
deconv1 = tf.nn.relu(bias, name=scope.name)

#的网络输入为一个维服从-1~1均匀分布的随机变量,这里取的是100.
batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim])
                .astype(np.float32)
#的网络输入是一个batch的64*64的图片,
#既可以是手写体数据也可以是的一个batch的输出。

#这个过程可以参考上图的a状态,判别曲线处于不够稳定的状态,
#两个网络都还没训练好。

#训练判别网络
#判别网络的损失函数由两部分组成,一部分是真实数据判别为1的损失,一部分是的输出self.G#判别为0的损失,需要优化的损失函数定义如下。

self.G = self.generator(self.z)
self.D, self.D_logits = self.discriminator(self.images)
self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True)
self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            self.D_logits, tf.ones_like(self.D)))
self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            self.D_logits_, tf.zeros_like(self.D_)))
self.d_loss = self.d_loss_real + self.d_loss_fake

#然后将一个batch的真实数据batch_images,和随机变量batch_z当做输入,执行session更新的参数。
##### update discriminator on real
d_optim = tf.train.AdamOptimizer(FLAGS.learning_rate, 
        beta1=FLAGS.beta1).minimize(d_loss, var_list=d_vars)
...
out1 = sess.run([d_optim], feed_dict={real_images: batch_images, 
        noise_images: batch_z})
        
#这一步可以对比图b,判别曲线渐渐趋于平稳。
#训练生成网络
#生成网络并没有一个独立的目标函数,它更新网络的梯度来源是判别网络对伪造图片求的梯度,
#并且是在设定伪造图片的label是1的情况下,保持判别网络不变,
#那么判别网络对伪造图片的梯度就是向着真实图片变化的方向。

self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            self.D_logits_, tf.ones_like(self.D_)))
#然后用同样的随机变量batch_z当做输入更新

g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) 
            .minimize(self.g_loss, var_list=self.g_vars)
out2 = sess.run([g_optim], feed_dict={noise_images:batch_z})

参考资料:
link1
link2

你可能感兴趣的:(GAN(生成对抗网络) 解释)