机器学习-生成对抗网络WGAN-GP实战(四-1)

上一篇文章简单介绍了WGAN-GP的原理,本文来实现WGAN-GP的实战。

还是建议大家先读机器学习-生成对抗网络变种(三)

之前的博客写了DCGAN的实战代码,实际上在生成器和判别器网络构建方面都相差不大。

大家可以参照机器学习-生成对抗网络实战(二-1),进行对照学习。


目录

Part1判别器和生成器网络的设计:

自定义生成器类:

自定义判别器类:

Part1判别器和生成器网络的设计:

自定义生成器类:

class Generator(keras.Model):

    def __init__(self):
        super(Generator, self).__init__()

        # z: [b, 100] => [b, 3*3*512] => [b, 3, 3, 512] => [b, 64, 64, 3]
        self.fc = layers.Dense(3*3*512)

        self.conv1 = layers.Conv2DTranspose(256, 3, 3, 'valid')
        self.bn1 = layers.BatchNormalization()

        self.conv2 = layers.Conv2DTranspose(128, 5, 2, 'valid')
        self.bn2 = layers.BatchNormalization()

        self.conv3 = layers.Conv2DTranspose(3, 4, 3, 'valid')

其本质作用还是利用转置卷积来实现图片的生成,但是前向传播略有不同。

    def call(self, inputs, training=None):
        # [z, 100] => [z, 3*3*512]
        x = self.fc(inputs)
        x = tf.reshape(x, [-1, 3, 3, 512])
        x = tf.nn.leaky_relu(x)
        x = tf.nn.leaky_relu(self.bn1(self.conv1(x), training=training))
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
        x = self.conv3(x)
        x = tf.tanh(x)

        return x

大家应该能注意到此时网络的激活函数除了最后一层都使用的leaky_relu激活函数,而最后一层使用的是tanh激活函数。这实际上是一系列的训练技巧,并不能从理论层面解释为什么这些激活函数比之前使用的relu效果好,大家记住就OK。


自定义判别器类:

class Discriminator(keras.Model):

    def __init__(self):
        super(Discriminator, self).__init__()

        # [b, 64, 64, 3] => [b, 1]
        self.conv1 = layers.Conv2D(64, 5, 3, 'valid')

        self.conv2 = layers.Conv2D(128, 5, 3, 'valid')
        self.bn2 = layers.BatchNormalization()

        self.conv3 = layers.Conv2D(256, 5, 3, 'valid')
        self.bn3 = layers.BatchNormalization()

        # [b, h, w ,c] => [b, -1]
        self.flatten = layers.Flatten()
        self.fc = layers.Dense(1)

这一块和前面的DCGAN原理基本类似。最后卷积层提取完特征值之后打平输入全连接层,最后输出一个二分结果。

    def call(self, inputs, training=None):

        x = tf.nn.leaky_relu(self.conv1(inputs))
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
        x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))

        # [b, h, w, c] => [b, -1]
        x = self.flatten(x)
        # [b, -1] => [b, 1]
        logits = self.fc(x)

        return logits

此时使用的激活函数是leaky_relu大家注意区分,最后的二分输出此处不必激活优化,后面会自动优化。


代码来自于《TensorFlow深度学习》-龙龙老师

机器学习-生成对抗网络WGAN-GP实战(四-1)_第1张图片

 

你可能感兴趣的:(深度学习,机器学习,神经网络,tensorflow,生成对抗网络)