本篇blog的内容基于原始论文WassersteinGAN和《生成对抗网络入门指南》第五章。
WGAN前作:TOWARDS PRINCIPLED METHODS FOR TRAINING GENERATIVE ADVERSARIAL NETWORKS
关于GAN的一些问题:训练的不稳定性;理论上,应该先把判别器训练到足够好,但是实际操作发现反而更难去优化生成器。
原始GAN中判别器要最小化下面损失函数
假定x固定,和进行求导:
对于形式如下:
然而GAN训练有一个trick,就是别把判别器训练得太好,否则在实验中生成器会完全学不动(loss降不下去)
先了解一些理论知识。从理论和经验上说,真实数据的分布通常是一个低维度流形(manifold)。流形是数据虽然分布在高维度空间里,但是实际上数据并不具备高维度特性,二世嵌入在高维度的低维度空间里。
现在再回顾之前的生成器,要将低维度的空间Z映射到与真实数据相同的高维度空间上,就是希望我们生成的低维度的manifold能高度逼近真实数据的manifold。
JS散度和KL散度相似,设定,JS散度公式为:
把KL公式代入展开:
可以继续写成
根据原始GAN定义的判别器loss,我们可以得到最优判别器的形式;而在最优判别器的下,我们可以把原始GAN定义的生成器loss等价变换为最小化真实分布与生成分布之间的JS散度。我们越训练判别器,它就越接近最优,最小化生成器的loss也就会越近似于最小化和之间的JS散度。
如果真实数据和生成数据在空间上完全不相交,可以得到一个完美的判别器划分真实数据和生成数据。实际生活中,生成空间和真实空间完美重合的概率是十分低的,所以大部分情况我们都能找到一个完美的判别器进行划分。也就会导致在网络训练的反向传播中,梯度更新几乎为0,网络难以学到东西。
根据散度公式发现只要生成数据和真实数据没有交集,JS散度始终未常数log2,而他们之间KL散度永远为正无穷。
但是与不重叠或重叠部分可忽略的可能性有多大?不严谨的答案是:非常大。比较严谨的答案是:当与的支撑集(support)是高维空间中的低维流形(manifold)时,与重叠部分测度(measure)为0的概率为1。
不用被奇怪的术语吓得关掉页面,虽然论文给出的是严格的数学表述,但是直观上其实很容易理解。首先简单介绍一下这几个概念:
有了这些理论分析,原始GAN不稳定的原因就彻底清楚了:判别器训练得太好,生成器梯度消失,生成器loss降不下去;判别器训练得不好,生成器梯度不准,四处乱跑。只有判别器训练得不好不坏才行,但是这个火候又很难把握,甚至在同一轮训练的前后不同阶段这个火候都可能不一样,所以GAN才那么难训练。
所以有时候尽管生成器表现很好了,与真实数据逼近,但是散度表现依然很差。所以我们更换一种合适的方法计算相似度距离。
1. 这里我们看到GAN很容易发生梯度消失,在训练1/10/25个epoch都很快就迭代掉下了5个数量级。
但是,很多时候还会导致网络更新不稳定的情况。
2. 而且从上图发现曲线噪声也很大。
但是,当生成数据与真实数据本身相似度距离较远的话,添加噪声的方案可能就无效了。
对于真实数据分布与生成数据分布,给出以下几种分布距离公式:
总变差距离(total variation distance)和KL散度
然后是JS散度
最后是本篇主角Wasserstein距离(EM距离):
这里可以用一个例子来形容,有两堆泥土,每一堆有 n 个位置,标号从1~n。第一堆泥土的第 i 个位置有 克泥土,第二堆泥土的第 i 个位置有 克泥土。小埃可以在第一堆泥土中任意移挪动泥土,具体地从第 i 个位置移动 k 克泥土到第 j 个位置,但是会消耗 的体力。小埃的最终目的是通过在第一堆中挪动泥土,使得第一堆泥土最终的形态和第二堆相同,也就是, 但是要求所花费的体力最小。
设想一个二维空间,真实数据分布是X轴为零,Y轴为随机变量的分布,而生成数据的分布是X轴为 ,Y轴为随机变量的分布,是生成数据分布的一个变量。根据上述四个公式:
也就是说当 逼近零时候,只有EM距离在减小,而其他几种距离的公式都是一个固定的值或者无穷大。EM
距离具备一个连续可用的梯度。
对于真实数据分布的输入x与生成数据分布的输入x,求满足1-Liposchitz条件的函数f(x)的期望值差值的上确界。
根据1-Liposchitz条件成立,继续改写成
继续对比GAN和WGAN
看一下WGAN的伪代码:
①分别从真实数据分布和前置随机分布中采样批次。然后进行梯度下降训练判别器:
②结束训练后再从前置随机分布中采样一个批次,使用梯度法训练生成器:
③完整伪代码:
具体的差别可以看NG视频的笔记[coursera/ImprovingDL/week2]Optimization algorithms
使用keras实现。
from __future__ import print_function, division
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import RMSprop
import keras.backend as K
import matplotlib.pyplot as plt
import sys
import numpy as np
class WGAN():
def __init__(self):
self.img_rows = 28
self.img_cols = 28
self.channels = 1
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
# Following parameter and optimizer set as recommended in paper
self.n_critic = 5
self.clip_value = 0.01
optimizer = RMSprop(lr=0.00005)
# Build and compile the critic
self.critic = self.build_critic()
self.critic.compile(loss=self.wasserstein_loss,
optimizer=optimizer,
metrics=['accuracy'])
# Build the generator
self.generator = self.build_generator()
# The generator takes noise as input and generated imgs
z = Input(shape=(100,))
img = self.generator(z)
# For the combined model we will only train the generator
self.critic.trainable = False
# The critic takes generated images as input and determines validity
valid = self.critic(img)
# The combined model (stacked generator and critic)
self.combined = Model(z, valid)
self.combined.compile(loss=self.wasserstein_loss,
optimizer=optimizer,
metrics=['accuracy'])
def wasserstein_loss(self, y_true, y_pred):
return K.mean(y_true * y_pred)
def build_generator(self):
model = Sequential()
model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))
model.add(Reshape((7, 7, 128)))
model.add(UpSampling2D())
model.add(Conv2D(128, kernel_size=4, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
model.add(UpSampling2D())
model.add(Conv2D(64, kernel_size=4, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
model.add(Conv2D(self.channels, kernel_size=4, padding="same"))
model.add(Activation("tanh"))
model.summary()
noise = Input(shape=(self.latent_dim,))
img = model(noise)
return Model(noise, img)
def build_critic(self):
model = Sequential()
model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
model.add(ZeroPadding2D(padding=((0,1),(0,1))))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(1))
model.summary()
img = Input(shape=self.img_shape)
validity = model(img)
return Model(img, validity)
训练过程使用权重裁剪使得网络参数保持在一定范围内
def train(self, epochs, batch_size=128, sample_interval=50):
# Load the dataset
(X_train, _), (_, _) = mnist.load_data()
# Rescale -1 to 1
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=3)
# Adversarial ground truths
valid = -np.ones((batch_size, 1))
fake = np.ones((batch_size, 1))
for epoch in range(epochs):
for _ in range(self.n_critic):
# ---------------------
# Train Discriminator
# ---------------------
# Select a random batch of images
idx = np.random.randint(0, X_train.shape[0], batch_size)
imgs = X_train[idx]
# Sample noise as generator input
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
# Generate a batch of new images
gen_imgs = self.generator.predict(noise)
# Train the critic
d_loss_real = self.critic.train_on_batch(imgs, valid)
d_loss_fake = self.critic.train_on_batch(gen_imgs, fake)
d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
# Clip critic weights
for l in self.critic.layers:
weights = l.get_weights()
weights = [np.clip(w, -self.clip_value, self.clip_value) for w in weights]
l.set_weights(weights)
# ---------------------
# Train Generator
# ---------------------
g_loss = self.combined.train_on_batch(noise, valid)
# Plot the progress
print ("%d [D loss: %f] [G loss: %f]" % (epoch, 1 - d_loss[0], 1 - g_loss[0]))
# If at save interval => save generated image samples
if epoch % sample_interval == 0:
self.sample_images(epoch)
由于训练速度原因放出前5轮训练结果
0 [D loss: 0.999914] [G loss: 1.000178]
50 [D loss: 0.999974] [G loss: 1.000072]
100 [D loss: 0.999964] [G loss: 1.000120]
150 [D loss: 0.999967] [G loss: 1.000081]
从第一、二组看出,随着W距离的降低,图像生成质量越来越高;
随着生成器的迭代此处上升,一开始W距离快速下降,慢慢变温度;
最后一组实验不好,随着生成器迭代次数上升,W距离没有下降,但也看到实验效果没有变好,说明理论仍然正确。
可以看出JS散度变化和生成图像效果没有正相关。且JS散度值趋近常数log2,约等于0.69,最后一组也可以发现两者没有关联。
随着网络的训练,生成器产生的结果是在各个点之间跳跃,但是每次只能产生一个点的数据。
研究人员发表了一些解决模式崩溃的方法,
例如:minibatch:Improved Techniques for Training GANs(NIPs 2016, Ian Goodfellow)
UnrolledGAN:UNROLLED GENERATIVE ADVERSARIAL NETWORKS(ICLR 2017)
参考令人拍案叫绝的Wasserstein GAN