I have recently been reviewing things I studied in the past, and generative adversarial networks (GANs) were one of my main research directions. Today I came across a piece of code from quite a while ago, so I am dusting it off to revisit the basic ideas behind GANs, and hopefully give students interested in this area a useful reference.
A generative adversarial network is, at its core, a generator and a discriminator pitted against each other: the generator tries to produce samples the discriminator cannot distinguish from real data, while the discriminator tries to tell the two apart.
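This adversarial game is usually written as the minimax objective from Goodfellow et al.'s original GAN paper, where D(x) is the discriminator's estimated probability that x is real and G(z) maps a noise vector z to a sample:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

Below is the discriminator code: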
import tensorflow as tf  # TensorFlow 1.x

def discriminator(images, reuse=None):
    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
        # Convolution + activation + pooling
        d_w1 = tf.get_variable('d_w1', [5, 5, 1, 32], initializer=tf.truncated_normal_initializer(stddev=0.02))
        d_b1 = tf.get_variable('d_b1', [32], initializer=tf.constant_initializer(0))
        d1 = tf.nn.conv2d(input=images, filter=d_w1, strides=[1, 1, 1, 1], padding='SAME')
        d1 = d1 + d_b1
        d1 = tf.nn.relu(d1)
        d1 = tf.nn.avg_pool(d1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')  # 28x28 -> 14x14
        # Convolution + activation + pooling
        d_w2 = tf.get_variable('d_w2', [5, 5, 32, 64], initializer=tf.truncated_normal_initializer(stddev=0.02))
        d_b2 = tf.get_variable('d_b2', [64], initializer=tf.constant_initializer(0))
        d2 = tf.nn.conv2d(input=d1, filter=d_w2, strides=[1, 1, 1, 1], padding='SAME')
        d2 = d2 + d_b2
        d2 = tf.nn.relu(d2)
        d2 = tf.nn.avg_pool(d2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')  # 14x14 -> 7x7
        # Fully connected + activation
        d_w3 = tf.get_variable('d_w3', [7 * 7 * 64, 1024], initializer=tf.truncated_normal_initializer(stddev=0.02))
        d_b3 = tf.get_variable('d_b3', [1024], initializer=tf.constant_initializer(0))
        d3 = tf.reshape(d2, [-1, 7 * 7 * 64])
        d3 = tf.matmul(d3, d_w3)
        d3 = d3 + d_b3
        d3 = tf.nn.relu(d3)
        # Fully connected
        d_w4 = tf.get_variable('d_w4', [1024, 1], initializer=tf.truncated_normal_initializer(stddev=0.02))
        d_b4 = tf.get_variable('d_b4', [1], initializer=tf.constant_initializer(0))
        d4 = tf.matmul(d3, d_w4) + d_b4
        # Return an unscaled logit; the sigmoid is applied later inside the loss
        return d4
The discriminator above consists of two convolutional layers followed by two fully connected layers. The activation function throughout is ReLU, implemented by tf.nn.relu, and the pooling is average pooling rather than max pooling, implemented by tf.nn.avg_pool.
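To make the layer sizes concrete: each SAME-padded stride-1 convolution preserves the spatial size, and each 2x2 average pool halves it, so a 28x28 MNIST-sized image shrinks 28 -> 14 -> 7, which is where the 7 * 7 * 64 flatten size comes from. A minimal shape check (just a sketch, assuming TensorFlow 1.x and the discriminator defined above):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 28, 28, 1])
logits = discriminator(x)
print(logits.shape)  # (?, 1) -- one unscaled score per image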
One more thing worth noting is the difference between tf.truncated_normal_initializer and tf.random_normal_initializer: the truncated version discards and re-draws any sample that falls more than two standard deviations from the mean, so no initial weight starts out extreme.
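A minimal sketch to see the difference (the variable names w_trunc and w_norm are purely illustrative, assuming TensorFlow 1.x):

import numpy as np
import tensorflow as tf

# Two 10000-element weight vectors, identical except for the initializer
w_trunc = tf.get_variable('w_trunc', [10000],
                          initializer=tf.truncated_normal_initializer(stddev=0.02))
w_norm = tf.get_variable('w_norm', [10000],
                         initializer=tf.random_normal_initializer(stddev=0.02))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    t, n = sess.run([w_trunc, w_norm])
    # Truncated draws never exceed 2 * stddev = 0.04 in absolute value;
    # with 10000 samples the plain normal almost certainly does
    print(np.abs(t).max())  # <= 0.04
    print(np.abs(n).max())  # usually > 0.04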
def generator(z, batch_size, z_dim, reuse=False):
    '''Takes a noise vector z and generates an image from it.'''
    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
        # Fully connected + batch normalization + activation
        # z_dim -> 3136 -> 56*56*1
        g_w1 = tf.get_variable('g_w1', [z_dim, 3136], dtype=tf.float32,
                               initializer=tf.truncated_normal_initializer(stddev=0.02))
        g_b1 = tf.get_variable('g_b1', [3136], initializer=tf.truncated_normal_initializer(stddev=0.02))
        g1 = tf.matmul(z, g_w1) + g_b1
        g1 = tf.reshape(g1, [-1, 56, 56, 1])
        g1 = tf.contrib.layers.batch_norm(g1, epsilon=1e-5, scope='bn1')
        g1 = tf.nn.relu(g1)
        # Convolution + batch normalization + activation
        # Note the integer division: z_dim / 2 yields a float in Python 3 and breaks the shape
        g_w2 = tf.get_variable('g_w2', [3, 3, 1, z_dim // 2], dtype=tf.float32,
                               initializer=tf.truncated_normal_initializer(stddev=0.02))
        g_b2 = tf.get_variable('g_b2', [z_dim // 2], initializer=tf.truncated_normal_initializer(stddev=0.02))
        g2 = tf.nn.conv2d(g1, g_w2, strides=[1, 2, 2, 1], padding='SAME')  # 56x56 -> 28x28
        g2 = g2 + g_b2
        g2 = tf.contrib.layers.batch_norm(g2, epsilon=1e-5, scope='bn2')
        g2 = tf.nn.relu(g2)
        g2 = tf.image.resize_images(g2, [56, 56])  # upsample back to 56x56
        # Convolution + batch normalization + activation
        g_w3 = tf.get_variable('g_w3', [3, 3, z_dim // 2, z_dim // 4], dtype=tf.float32,
                               initializer=tf.truncated_normal_initializer(stddev=0.02))
        g_b3 = tf.get_variable('g_b3', [z_dim // 4], initializer=tf.truncated_normal_initializer(stddev=0.02))
        g3 = tf.nn.conv2d(g2, g_w3, strides=[1, 2, 2, 1], padding='SAME')  # 56x56 -> 28x28
        g3 = g3 + g_b3
        g3 = tf.contrib.layers.batch_norm(g3, epsilon=1e-5, scope='bn3')
        g3 = tf.nn.relu(g3)
        g3 = tf.image.resize_images(g3, [56, 56])  # upsample back to 56x56
        # Convolution + activation
        g_w4 = tf.get_variable('g_w4', [1, 1, z_dim // 4, 1], dtype=tf.float32,
                               initializer=tf.truncated_normal_initializer(stddev=0.02))
        g_b4 = tf.get_variable('g_b4', [1], initializer=tf.truncated_normal_initializer(stddev=0.02))
        g4 = tf.nn.conv2d(g3, g_w4, strides=[1, 2, 2, 1], padding='SAME')  # 56x56 -> 28x28
        g4 = g4 + g_b4
        g4 = tf.sigmoid(g4)
        # Output shape of g4: batch_size x 28 x 28 x 1
        return g4
That is the generator: every intermediate layer uses the ReLU activation, while the final output goes through a sigmoid, which squashes pixel values into the (0, 1) range.
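Since the discriminator returns an unscaled logit, the adversarial losses pair naturally with tf.nn.sigmoid_cross_entropy_with_logits. The original training code is not shown here, so the following is only a sketch of the standard wiring, assuming TensorFlow 1.x and the two functions above (batch_size, z_dim, and the learning rate are placeholder choices):

import tensorflow as tf

batch_size, z_dim = 64, 100
z = tf.placeholder(tf.float32, [None, z_dim], name='z')
x_real = tf.placeholder(tf.float32, [None, 28, 28, 1], name='x_real')

g = generator(z, batch_size, z_dim)    # fake images with pixel values in (0, 1)
d_real = discriminator(x_real)         # logits for real images
d_fake = discriminator(g, reuse=True)  # second call reuses the same D variables

# Discriminator loss: push real logits toward 1 and fake logits toward 0
d_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=d_real, labels=tf.ones_like(d_real)) +
    tf.nn.sigmoid_cross_entropy_with_logits(logits=d_fake, labels=tf.zeros_like(d_fake)))
# Generator loss: push fake logits toward 1
g_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=d_fake, labels=tf.ones_like(d_fake)))

# Each optimizer updates only its own network, selected by the d_ / g_ name
# prefixes (the bnX scopes hold the generator's batch-norm parameters)
t_vars = tf.trainable_variables()
d_vars = [v for v in t_vars if v.name.startswith('d_')]
g_vars = [v for v in t_vars if v.name.startswith('g_') or v.name.startswith('bn')]
d_train = tf.train.AdamOptimizer(0.0002).minimize(d_loss, var_list=d_vars)
g_train = tf.train.AdamOptimizer(0.0002).minimize(g_loss, var_list=g_vars)

Training then alternates: run d_train on a batch of real images plus a batch of noise, then run g_train on fresh noise, so that neither network gets too far ahead of the other.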