上一篇文章简单介绍了WGAN-GP的原理,本文来实现WGAN-GP的实战。
还是建议大家先读机器学习-生成对抗网络变种(三)
之前的博客写了DCGAN的实战代码,实际上在生成器和判别器网络构建方面都相差不大。
大家可以参照机器学习-生成对抗网络实战(二-1),进行对照学习。
目录
Part1判别器和生成器网络的设计:
自定义生成器类:
自定义判别器类:
class Generator(keras.Model):
def __init__(self):
super(Generator, self).__init__()
# z: [b, 100] => [b, 3*3*512] => [b, 3, 3, 512] => [b, 64, 64, 3]
self.fc = layers.Dense(3*3*512)
self.conv1 = layers.Conv2DTranspose(256, 3, 3, 'valid')
self.bn1 = layers.BatchNormalization()
self.conv2 = layers.Conv2DTranspose(128, 5, 2, 'valid')
self.bn2 = layers.BatchNormalization()
self.conv3 = layers.Conv2DTranspose(3, 4, 3, 'valid')
其本质作用还是利用转置卷积来实现图片的生成,但是前向传播略有不同。
def call(self, inputs, training=None):
# [z, 100] => [z, 3*3*512]
x = self.fc(inputs)
x = tf.reshape(x, [-1, 3, 3, 512])
x = tf.nn.leaky_relu(x)
x = tf.nn.leaky_relu(self.bn1(self.conv1(x), training=training))
x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
x = self.conv3(x)
x = tf.tanh(x)
return x
大家应该能注意到此时网络的激活函数除了最后一层都使用的leaky_relu激活函数,而最后一层使用的是tanh激活函数。这实际上是一系列的训练技巧,并不能从理论层面解释为什么这些激活函数比之前使用的relu效果好,大家记住就OK。
class Discriminator(keras.Model):
def __init__(self):
super(Discriminator, self).__init__()
# [b, 64, 64, 3] => [b, 1]
self.conv1 = layers.Conv2D(64, 5, 3, 'valid')
self.conv2 = layers.Conv2D(128, 5, 3, 'valid')
self.bn2 = layers.BatchNormalization()
self.conv3 = layers.Conv2D(256, 5, 3, 'valid')
self.bn3 = layers.BatchNormalization()
# [b, h, w ,c] => [b, -1]
self.flatten = layers.Flatten()
self.fc = layers.Dense(1)
这一块和前面的DCGAN原理基本类似。最后卷积层提取完特征值之后打平输入全连接层,最后输出一个二分结果。
def call(self, inputs, training=None):
x = tf.nn.leaky_relu(self.conv1(inputs))
x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))
# [b, h, w, c] => [b, -1]
x = self.flatten(x)
# [b, -1] => [b, 1]
logits = self.fc(x)
return logits
此时使用的激活函数是leaky_relu大家注意区分,最后的二分输出此处不必激活优化,后面会自动优化。
代码来自于《TensorFlow深度学习》-龙龙老师