本文简要介绍了生成对抗网络(GAN)的原理,接下来通过tensorflow开发程序实现生成对抗网络(GAN),并且通过实现的GAN完成对等差数列的生成和识别。通过对设计思路和实现方案的介绍,本文可以辅助读者理解GAN的工作原理,并掌握实现方法。有了这样的基础,在面对工作中实际问题时可以将GAN纳入考虑,选择最合适的算法。
代码位置:
https://github.com/wangyaobupt/GAN
TensorFlow版本
>>> tf.version
‘1.1.0-rc2’
Generative Adversarial Nets[1][https://arxiv.org/pdf/1406.2661v1.pdf]是Ian J. Goodfellow等在2014年提出的一种训练模型的方法,此方法通过两个网络(生成网络G和分类网络D)对抗训练,得到符合预期目标的生成模型和分类模型。
要理解GAN的原理,上述论文是最好的教材。但考虑到原文首先是英文撰写,其次包含不少数学推导,新手上手并不容易。因此笔者这里班门弄斧,基于论文简单转述GAN的设计思想要点
GAN的目标,给定一个真实样本(本文也称之为ground truth)集合,训练出两个模型,一个能够从噪声信号生成尽可能像ground truth的样本;另一个能够判断给定样本是否是ground truth。两个模型详细介绍如下
从上述讨论可以看出,G网络和D网络是两个目标完全相反的网络,G网络尽其所能“伪造”出像真实样本的数据,D网络尽可能区分真实与伪造数据。GAN中所谓“对抗”的概念,即来源于此。
GAN的训练过程就是G和D两个网络互相对抗的过程,对抗的结果是G网络被训练到能够生成以假乱真的样本,即G网络从噪声输入得到了尽可能与真实样本相似的输出,或者说G学会了从噪声生成ground truth的方法;D网络也可以区分ground truth与其他样本,即D学会了区分ground truth与其他数据的方法。
参考文献
1. Goodfellow I J, Pougetabadie J, Mirza M, et al. Generative adversarial nets[C]. neural information processing systems, 2014: 2672-2680.
在开始设计神经网络之前,我们首先构造出预期GAN解决的问题。前述GAN论文中提出了一个从噪声学习正态分布的经典问题,读者如果在网络上搜索GAN的案例,除了图像识别,基本上只有这么一个问题和方案实现。
本文重新设计了一个与论文中不同的问题。问题描述如下
G网络:参考论文资料,我们选择多层全连接神经网络
D网络:由于要分辨的是等差数列,我们选择RNN作为D网络。
网络结构如下(下图是tensorboard生成的计算图):图中”G_net”表示G网络,”D_net”/”D_net_1”表示D网络,虽然图中D网络被分成了两份,但是其RNN参数是共享的,即图中正下方”rnn”这个单元。
G网络定义
# generative network
# use multi-layer percepton to generate time sequence from random noise
# input tensor must be in shape of (batch_size, self.seq_len)
def generator(self, inputTensor):
with tf.name_scope('G_net'):
gInputTensor = tf.identity(inputTensor, name='input')
# Multilayer percepton implementation
numNodesInEachLayer = 10
numLayers = 3
previous_output_tensor = gInputTensor
for layerIdx in range(numLayers):
activation,z = self.fullConnectedLayer(previous_output_tensor, numNodesInEachLayer, layerIdx)
previous_output_tensor = activation
g_logit = z
g_logit = tf.identity(g_logit, 'g_logit')
return g_logit
G网络损失函数
下面代码片段中self.d_logit_fake是D网络对G网络生成数据的判定结果。由于G网络的目标是尽可能骗过D网路,如果D网络对于G网络生成数据全部判为1(即TRUE),则损失最小,反之,损失最大。
g_loss_d = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(
logits=self.d_logit_fake,
labels=tf.ones(shape=[self.batch_size_t,1])
),
name='g_loss_d'
)
D网络的定义
RNN+全连接输出层,无论是RNN还是全连接层都必须在对ground truth和G生成样本之间共享同一套参数
def discriminator(self, inputTensor,reuseParam):
with tf.name_scope('D_net'):
num_units_in_LSTMCell = 10
# RNN definition
with tf.variable_scope('d_rnn'):
lstmCell = tf.contrib.rnn.BasicLSTMCell(num_units_in_LSTMCell,reuse=reuseParam)
init_state = lstmCell.zero_state(self.batch_size_t, dtype=tf.float32)
raw_output, final_state = tf.nn.dynamic_rnn(lstmCell, inputTensor, initial_state=init_state)
rnn_output_list = tf.unstack(tf.transpose(raw_output, [1, 0, 2]), name='outList')
rnn_output_tensor = rnn_output_list[-1];
# Full connected network
numberOfInputDims = inputTensor.shape[1].value
numOfNodesInLayer = 1
if not reuseParam:
self.d_w = tf.Variable(initial_value=tf.random_normal([numberOfInputDims, numOfNodesInLayer]),
name=('dnet_w_1'))
self.d_b = tf.Variable(tf.zeros([1, numOfNodesInLayer]), name='dnet_b_1')
self.d_z = tf.matmul(rnn_output_tensor,self.d_w) + self.d_b
self.d_z = tf.identity(self.d_z, name='dnet_z_1')
d_sigmoid = tf.nn.sigmoid(self.d_z, name='dnet_a_1')
d_logit = self.d_z
d_logit = tf.identity(d_logit, 'd_net_logit')
return d_logit
D网络损失函数
D网络使用同一套参数分辨两种输入,一种是ground truth,另一种是G网络的输出。对于ground truth,训练目标为尽可能判为1,对于G网络的输出,训练目标为尽可能判为0,因此Loss函数定义如下
# For D-network, jduge ground truth to TRUE, jduge G-network output to FALSE,making loss low
d_loss_ground_truth = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(
logits=self.d_logit_gnd_truth,
labels=tf.ones(shape=[self.batch_size_t,1])
),
name='d_loss_gnd'
)
d_loss_fake = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(
logits=self.d_logit_fake,
labels=tf.zeros(shape=[self.batch_size_t,1])
),
name='d_loss_fake'
)
d_loss = d_loss_ground_truth + d_loss_fake
对抗训练
对抗训练中,G网络Loss值只用来调整G网络参数,D网络Loss值只用来调整D网络参数
g_net_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G_net')
g_net_var_list = g_net_var_list + tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='g_rnn')
self.train_g = tf.train.AdamOptimizer(self.lr_g).minimize(g_loss,var_list=g_net_var_list)
d_net_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='D_net')
d_net_var_list = d_net_var_list + tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='d_rnn')
self.train_d = tf.train.AdamOptimizer(self.lr_d).minimize(d_loss,var_list=d_net_var_list)
下图是训练过程中D网络对ground truth和G网络输出的分类正确率曲线
从图中可以看到3个阶段
上述3个阶段就体现出对抗训练的特点,G网络和D网络互为对手,互相提高对方的训练难度,最终得到符合预期的模型。
接下来再从数据上给一个直观的认识
Ground truth: 在公差为1的等差数列上加入stddev=0.3, mean=0的正态分布噪声后,得到的一组Ground Truth数据如下
[ 1.1539436 ]
[ 2.08863655]
[ 2.78491645]
[ 3.93027817]
[ 4.75851967]
[ 5.88655699]
[ 7.10540526]
[ 7.43159023]
[ 9.19373617]
[ 10.08779359]
训练开始前G网络的数据
基本无规律,和输入噪声分布接近
[ 1.15080559]
[ 0.66351247]
[-0.39484465]
[-0.41690648]
[ 0.29061955]
[ 0.06131642]
[-2.46439648]
[-1.53692639]
[-0.30550677]
[-0.89200932]
迭代100次之后G网络的输出
出现等差数列的端倪
[ -0.53692651]
[ 0.86063552]
[ 2.47294378]
[ 5.24512053]
[ 7.7618413 ]
[ 9.57867622]
[ 9.15039253]
[ 9.86567402]
[ 10.62975025]
[ 10.24322414]
迭代500次之后G网络的输出
除了最后一个元素,前9个元素已经基本符合预期
[ 1.09549832]
[ 2.21490908]
[ 2.95311546]
[ 4.06684017]
[ 4.96308947]
[ 6.03393888]
[ 6.89026165]
[ 7.93375683]
[ 8.63552094]
[ 9.07077026]
迭代1500次之后G网络的输出
已经足以以假乱真
[ 0.07186054]
[ 1.08289695]
[ 2.55904818]
[ 4.07374573]
[ 5.14763832]
[ 6.07010031]
[ 6.79585028]
[ 8.17086124]
[ 8.81297684]
[ 10.38190079]