第十三章 生成对抗网络

在生成对抗网络(Generative Adversarial Network)之前,VAE被认为是理论完美,实现简单,使用神经网络训练起来很稳定,生成的图片逼近度也较高,但是人类还是可以轻易区分。

但是 Ian Goodfellow 提出了生成对抗网络,最新的算法在图片生成上的效果甚至达到了肉眼难辨的程度。

13.1 博弈学习实例

GAN 网络借鉴了博弈学习的思想,分别设立了两个子网络:负责生成样本的生成器G,负责鉴别真伪的鉴别器D。

鉴别器D通过观察真实的样本和生成器G产生的样本之间的区别,学会如何鉴别真假(真实样本为真。生成器G产生的样本为假)

生成器G希望产生的样本能够骗过鉴别器D,因此生成器G通过优化自身的参数,尝试使得自己产生的样本在鉴别器D中判别为真

生成器G和鉴别器D相互博弈,直至达到平衡点(此时生成器G生成的样本非常逼真,使得鉴别器D真假难分)。

在原始的GAN论文中,Ian Goodfellow 对GAN有一个形象的比喻:
生成器网络G的功能是产生一系列非常逼真的假钞试图欺骗鉴别器D,而鉴别器D通过学习真钞和生成器G产生的假钞来掌握钞票的鉴别方法;这两个网络在相互博弈的过程中间同步提升,直到生成器G产生的假钞非常逼真,连鉴别器D都真假难辨。

13.2 GAN 原理

这部分介绍生成对抗网络的网络结构和训练方法

13.2.1 网络结构

生成对抗网络包含了两个子网络:

  • 生成网络(负责学习样本的真实分布)
  • 判别网络(负责将生成网络采样的样本与真实样本区分开来)

13.2.1.1 生成网络

从先验分布中采样隐藏变量,通过生成网络参数化的获得生成样本(其中隐藏变量的先验分布可以假设属于某中已知的分布)

可以用深度神经网络来参数话。

e.g. 从均匀分布中采样出隐藏变量,经过多层转置卷积层网络参数化的分布中采样出样本

13.2.1.2 判别网络

判别网络和普通的二分类网络功能类似,它接受输入样本,包含了采样自真实数据分布的样本,也同时包含了采样自生成网络的假样本。

判别网络输出为属于真实样本的概率,把所有真实样本的标签标注为真(1),所有生成网络产生的样本标注为假(0),通过最小化判别网络预测值与标签之间的误差来优化判断网络参数。

13.2.2 网络训练

GAN 博弈学习的思想体现在它的训练方式上,由于生成器G和判别器D的优化目标不一样,不能和之前的网络训练一样,只采用一个损失函数。

对于判别网络D,它的目标是能够很好地分辨出真样本与假样本。

以图片生成为例,它的目标是最小化图片的预测值和真实值之间的交叉熵损失函数:

其中,代表真实样本在判别网络的输出;代表生成样本在判别网络的输出,为的标签,由于真实样本标注为真(),为的标签,由于生成样本标注为假()。

根据二分类问题的交叉熵损失函数定义:

判别网络的优化目标是:

把问题转换为,并写成期望形式:

希望样本在判别网络的输出越接近真实标签越好,意味着,在训练生成网络时,希望判别网络的输出越逼近1越好

交叉熵损失函数为:

把问题转换成最大化问题,并写出期望形式:

再次等价转化为:

13.2.3 统一目标函数

把判别网络的目标和生成网络的目标合并,写成min-max 博弈形式:
\begin{split} \min_{\phi}\max_{\theta}\mathcal{L}(D,G)=&\mathbb{E}_{x_{r}\sim p_{r}}log\;D_{\theta}(x_{r})+\mathbb{E}_{x_{f}\sim p_{g}}log\;(1-D_{\theta}(x_{f})) \\ =&\mathbb{E}_{x_{r}\sim p_{r}}log\;D_{\theta}(x_{r})+\mathbb{E}_{x_{f}\sim p_{g}}log\;(1-D_{\theta}(G_{\phi}(z))) \end{split}

13.3 DCGAN 实战

import os
import glob
import tensorflow as tf
import numpy as np
from PIL import Image
resize = 64
batch_size = 64
def preprocess(path):
    img = tf.io.read_file(path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img,[resize,resize])
    img = tf.clip_by_value(img,0,255)
    img = img / 127.5 - 1
    
    return img
img_paths = glob.glob('D:/faces/*.jpg')
dataset = tf.data.Dataset.from_tensor_slices(img_paths)
dataset = dataset.map(preprocess)
dataset = dataset.batch(batch_size)
class Generator(tf.keras.Model):
    def __init__(self):
        super(Generator, self).__init__()
        filter = 64
        
        self.conv1 = tf.keras.layers.Conv2DTranspose(filter * 8, 
                                    4, 1, padding = 'valid',
                                    use_bias = False)
        self.bn1 = tf.keras.layers.BatchNormalization()
        
        self.conv2 = tf.keras.layers.Conv2DTranspose(filter * 4,
                                    4, 2, padding = 'same',
                                    use_bias = False)
        self.bn2 = tf.keras.layers.BatchNormalization()
        
        self.conv3 = tf.keras.layers.Conv2DTranspose(filter * 2,
                                    4, 2, padding = 'same',
                                    use_bias = False)
        self.bn3 = tf.keras.layers.BatchNormalization()
        
        self.conv4 = tf.keras.layers.Conv2DTranspose(filter * 1,
                                    4, 2, padding = 'same',
                                    use_bias = False)
        self.bn4 = tf.keras.layers.BatchNormalization()
        
        self.conv5 = tf.keras.layers.Conv2DTranspose(3, 4, 2,
                                    padding = 'same',
                                    use_bias = False)
        
    def call(self, inputs, training = None):
        x = inputs
        x = tf.reshape(x, (x.shape[0], 1, 1, x.shape[1]))
        x = tf.nn.relu(x)
        x = tf.nn.relu(self.bn1(self.conv1(x), training=training))
        x = tf.nn.relu(self.bn2(self.conv2(x), training=training))
        x = tf.nn.relu(self.bn3(self.conv3(x), training=training))
        x = tf.nn.relu(self.bn4(self.conv4(x), training=training))
        x = self.conv5(x)
        x = tf.tanh(x)
        
        return x
class Discriminator(tf.keras.Model):
    def __init__(self):
        super(Discriminator, self).__init__()
        filter = 64
        
        self.conv1 = tf.keras.layers.Conv2D(filter, 4, 2, 
                                    padding = 'valid',
                                    use_bias = False)
        self.bn1 = tf.keras.layers.BatchNormalization()
        
        self.conv2 = tf.keras.layers.Conv2D(filter * 2, 4, 2,
                                    padding = 'valid',
                                    use_bias = False)
        self.bn2 = tf.keras.layers.BatchNormalization()
        
        self.conv3 = tf.keras.layers.Conv2D(filter * 4, 4, 2,
                                    padding = 'valid',
                                    use_bias = False)
        self.bn3 = tf.keras.layers.BatchNormalization()
        
        self.conv4 = tf.keras.layers.Conv2D(filter * 8, 3, 1,
                                    padding = 'same',
                                    use_bias = False)
        self.bn4 = tf.keras.layers.BatchNormalization()
        
        self.conv5 = tf.keras.layers.Conv2D(filter * 16, 3, 1, 
                                    padding = 'same',
                                    use_bias = False)
        self.bn5 = tf.keras.layers.BatchNormalization()
        
        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        
        self.flatten = tf.keras.layers.Flatten()
        
        self.fc = tf.keras.layers.Dense(1)
        
    def call(self, inputs, training = None):
        x = tf.nn.leaky_relu(self.bn1(self.conv1(inputs), training=training)) 
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))

        x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))

        x = tf.nn.leaky_relu(self.bn4(self.conv4(x), training=training))

        x = tf.nn.leaky_relu(self.bn5(self.conv5(x), training=training))
           
        x = self.pool(x)
            
        x = self.flatten(x)
            
        logits = self.fc(x)
            
        return logits   
def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
    fake_image = generator(batch_z, is_training)
    d_fake_logits = discriminator(fake_image, is_training)
    d_real_logits = discriminator(batch_x, is_training)
    
    d_loss_real = celoss_ones(d_real_logits)
    
    d_loss_fake = celoss_zeros(d_fake_logits)
    
    loss = d_loss_fake + d_loss_real
    
    return loss

def celoss_ones(logits):
    y = tf.ones_like(logits)
    loss = tf.keras.losses.binary_crossentropy(y, logits,
                                              from_logits = True)
    
    return tf.reduce_mean(loss)

def celoss_zeros(logits):
    y = tf.zeros_like(logits)
    loss = tf.keras.losses.binary_crossentropy(y, logits,
                                              from_logits=True)
    
    return tf.reduce_mean(loss)

def g_loss_fn(generator, discriminator, batch_z, is_training):
    fake_image = generator(batch_z, is_training)
    d_fake_logits = discriminator(fake_image, is_training)
    loss = celoss_ones(d_fake_logits)
    
    return loss
generator = Generator()
generator.build(input_shape = (4, 100))
discriminator = Discriminator()
discriminator.build(input_shape=(4, 64, 64, 3))
g_optimizer = tf.keras.optimizers.Adam(learning_rate = 0.0002,
                                      beta_1 = 0.5)
d_optimizer = tf.keras.optimizers.Adam(learning_rate = 0.0002,
                                      beta_1 = 0.5)
for epoch in range(300):
    for _ in range(5):
        batch_z = tf.random.normal([batch_size, 100])
        batch_x = next(iter(dataset))
        with tf.GradientTape() as tape:
            d_loss = d_loss_fn(generator,discriminator, batch_z, batch_x, True)
        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))
    batch_z = tf.random.normal([batch_size, 100])
    batch_x = next(iter(dataset))
    with tf.GradientTape() as tape:
        g_loss = g_loss_fn(generator,discriminator, batch_z, True)
        
    grads = tape.gradient(g_loss, generator.trainable_variables)
    g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))
    
    if epoch % 100 ==0:
        print(epoch,'d_loss:',float(d_loss),'g_loss:',float(g_loss))

13.4 GAN 变种

原始GAN 模型在图片生成效果并不突出,和VAE差别不明显,并没有展现出它强大的分布逼近能力。但是由于GAN在理论方面教新颖,实现方面也有很多可以改进的地方,因此激发了学术界的研究兴趣。

13.4.1 DCGAN

DCGAN (2015) 提出了使用转置卷积层实现的生成网络,普通卷积层来实现的判别网络,来降低网络的参数量,同时图片的生成效果也大幅提升。

13.4.2 InfoGAN

InfoGAN (2016) 尝试使用无监督的方式去学习输入的可接受隐向量表示方法。

13.4.3 CycleGAN

CycleGAN (2017)是华人学者朱俊彦提出的无监督的方式进行图片转换的算法,并且其算法清晰简单,实验效果完成的非常好

CycleGAN 基本的假设是,如果由图片转换到图片,再从图片转换到,那么应该和是同一张图片,因此出了设立标准的GAN 损失项外,CycleGAN 还增设了循环一致性损失

13.4.4 WGAN/WGAN-GP

GAN 的训练很容易出现训练不收敛和模式崩塌的现象。

WGAN (2017)从理论层面分析了原始的 GAN 使用 JS 散度存在的缺陷,并提出了可以用 Wasserstein 距离来解决这个问题。

WGAN-GP(2017),作者提出了通过添加梯度惩罚项目,从工程层面很好的实现了 WGAN 算法,并且实验性证实了 WGAN 训练稳定的优点。

13.4.5 Equal GAN

Google Brain 的几位研究员在2018年提出了另一个观点:没有证据表明我们测试的GAN变种算法一直持续地比最初始的GAN要好。论文中对这些GAN变种进行了相对公平,全面的比较,在有足够计算资源的情况下,几乎所有的GAN变种都能达到相似的性能(FID分数)。

13.4.6 Self-Attention GAN

Self-Attention GAN (SAGAN)借鉴了 Attention 机制,提出了基于自我注意力机制的 GAN 变种。SAGAN(2019)把图片的逼真图指标,Inception score(从最好的36.8 提升到52.52),Frechet Inception distance(从27.62降到18.65)。

13.4.7 BigGAN

在 SAGAN 的基础上,BigGAN(2019)尝试将GAN的训练扩展到大规模上,利用正交正则化等技巧保证训练过程的稳定性。

BigGAN 的意义在于:GAN 网络的训练同样可以从大数据,大算力中间受益。

其把图片的逼真图指标,Inception score(提升到166.5),Frechet Inception distance(从27.62降到18.65)。

13.5 纳什均衡

从理论层面进行分析,通过博弈学习的训练方式,生成器G和判别器D分别会达到什么状态。

探索以下两个问题:

  • 固定,会收敛到什么最优状态?
  • 在达到最优状态后,会收敛到什么状态?

13.5.1 判别器状态

回归GAN的损失函数:


GAN损失函数

对于判别器D,优化的目标是最大化函数,需要找出:

公式

公式

不是很能理解这里。

13.5.2 生成器状态

JS 散度,它定义为KL散度的组合:


推导
推导
推导
推导

13.6 GAN 训练难题

GAN 网络训练困难的问题,主要体现在:

13.6.1 超参数敏感

超参数敏感是指网络的结构,学习率,初始化状态等超参数对网络的训练过程影响较大,微量的超参数调整可能导致网络的训练结果截然不同。

为了能较好地训练GAN网络,DCGAN 论文作者提出了不使用 Pooling 层,多使用 Batch Normalization层,不使用全连接层,生成网络中激活函数使用ReLU,最后一层使用tanh,判别网络激活函数室友LeakyLeLU等一系列经验性技巧。

13.6.2 模式崩塌

模式崩塌是指模型生成的样本单一,多样性很差。

由于判别器只能鉴别单个样本是否采样直真实样本,并没有对样本多样性进行显示约束,导致生成模型可能倾向于生成真实分布的部分区间中的少量高质量样本,以此来在判别器的输出中获得较高的概率值,

但是,我们希望生成网络能够逼近真实的分布,而不是真实分布中的某部分。

13.7 WGAN 原理

WGAN 作者提出是因为 JS 散度在不重叠的分布的梯度曲面是恒定为0的,当分布不重叠时,JS散度的梯度始终为0,从而导致此时GAN的训练出现梯度弥散现象,参数长时间得不到更新,网络无法收敛。

要解决此问题,需要使用一种更好的分布距离衡量标准,使得它即使在分布不重叠时,也能平滑反映分布之间的距离变换。

13.7.2 EM 距离

WGAN 论文中发现了JS散度导致GAN训练不稳定的问题,并引入了一种新的分布距离度量方法:Wasserstein Distance,也叫做推土机距离(EM 距离),它表示了从一个分布变换到另一个分布的最小代价,定义为:


公式

image.png

13.7.3 WGAN-GP

你可能感兴趣的:(第十三章 生成对抗网络)