【GANs】Wasserstein GAN

【GANs】Wasserstein GAN

  • 4 W-GAN
    • 4.1 W-GAN简介
      • 评价网络
      • 生成网络
    • 4.2 散度
      • KL散度
      • JS散度
      • Wasserstein距离
    • 4.3 代码实现

4 W-GAN

在生成对抗网络中, J S JS JS散度不适合衡量生成数据分布和真实数据分布的距离。由于通过优化交叉熵( J S JS JS散度)训练生成对抗网络会导致训练稳定性和模型坍塌问题,因此改进GAN,就需要改变其损失函数。

4.1 W-GAN简介

Wasserstein GAN原文链接

W-GAN通过使用 W a s s e r s t e i n Wasserstein Wasserstein距离代替优化 J S JS JS散度来优化训练的生成对抗网络。
对于真实分布 p r p_r pr和模型分布 p θ p_{\theta} pθ,他们的 1 s t − W a s s e r s t e i n 1st-Wasserstein 1stWasserstein距离为:
W 1 ( p r , p θ ) = inf ⁡ γ ∼ Γ ( p r , p θ ) E ( x , y ) ∼ γ [ ∥ x − y ∥ ] \begin{align} {{\bf{W}}^1}({p_r},{p_\theta }) = \mathop {\inf }\limits_{\gamma \sim \Gamma ({p_r},{p_\theta })} {{\rm E}_{(x,y)\sim \gamma }}\left[ {\left\| {x - y} \right\|} \right] \end{align} W1(pr,pθ)=γΓ(pr,pθ)infE(x,y)γ[xy]
其中 Γ ( p r , p θ ) {\Gamma ({p_r},{p_\theta }}) Γ(pr,pθ)是边界分布为 p r p_r pr p θ p_{\theta} pθ的所有可能的联合分布集合。

当两个分布没有重叠或者重叠非常少时,他们之间的 K L KL KL散度为 + ∞ + \infty + J S JS JS散度为 l o g 2 log2 log2,并不随着两个分布之间的距离而变化。而 1 s t − W a s s e r s t e i n 1st-Wasserstein 1stWasserstein距离依然可以衡量两个没有重叠分布之间的距离。

两个分布 p r p_r pr p θ p_{\theta} pθ 1 s t − W a s s e r s t e i n 1st-Wasserstein 1stWasserstein距离通常难以直接计算,但是两个分布的 1 s t − W a s s e r s t e i n 1st-Wasserstein 1stWasserstein距离有一个对偶形式:
W 1 ( p r , p θ ) = sup ⁡ ∥ f ∥ L ≤ 1 ( E x ∼ p r [ f ( x ) ] − E x ∼ p θ [ f ( x ) ] ) \begin{align} {{\bf{W}}^1}({p_r},{p_\theta }) = \mathop {\sup }\limits_{{{\left\| f \right\|}_L} \le 1} \left( {{{\rm E}_{x\sim{p_r}}}\left[ {f(x)} \right] - {{\rm E}_{x\sim{p_\theta }}}\left[ {f(x)} \right]} \right) \end{align} W1(pr,pθ)=fL1sup(Expr[f(x)]Expθ[f(x)])
其中 f : R d → R f:{\mathbb{R}^d} \to \mathbb{R} f:RdR 1 − L i p s c h i t z 1-Lipschitz 1Lipschitz函数,满足:
∥ f ∥ L ≤ 1 ≜ sup ⁡ x ≠ y ∣ f ( x ) − f ( y ) ∣ ∣ x − y ∣ ⩽ 1 \begin{align} {{{\left\| f \right\|}_L} \le 1} \triangleq \mathop {\sup }\limits_{x \ne y} \frac{{\left| {f(x) - f(y)} \right|}}{{\left| {x - y} \right|}} \leqslant 1 \end{align} fL1x=ysupxyf(x)f(y)1
公式 ( 3 ) (3) (3)称为 K a n t o r o v i c h − R u b i n s t e i n Kantorovich-Rubinstein KantorovichRubinstein对偶定理。

根据 K a n t o r o v i c h − R u b i n s t e i n Kantorovich-Rubinstein KantorovichRubinstein对偶定理,两个分布 p r p_r pr p θ p_{\theta} pθ之间的 1 s t − W a s s e r s t e i n 1st-Wasserstein 1stWasserstein距离可以转化为一个满足 1 − L i p s c h i t z 1-Lipschitz 1Lipschitz连续的函数在分布 p r p_r pr p θ p_{\theta} pθ下期望的差的上界。通常情况下, 1 − L i p s c h i t z 1-Lipschitz 1Lipschitz连续的约束可以宽松为 K − L i p s c h i t z K-Lipschitz KLipschitz连续。这样分布 p r p_r pr p θ p_{\theta} pθ之间的 1 s t − W a s s e r s t e i n 1st-Wasserstein 1stWasserstein距离为:
W 1 ( p r , p θ ) = 1 K sup ⁡ ∥ f ∥ L ≤ K ( E x ∼ p r [ f ( x ) ] − E x ∼ p θ [ f ( x ) ] ) \begin{align} {{\bf{W}}^1}({p_r},{p_\theta }) =\frac{1}{K} \mathop {\sup }\limits_{{{\left\| f \right\|}_L} \le K} \left( {{{\rm E}_{x\sim{p_r}}}\left[ {f(x)} \right] - {{\rm E}_{x\sim{p_\theta }}}\left[ {f(x)} \right]} \right) \end{align} W1(pr,pθ)=K1fLKsup(Expr[f(x)]Expθ[f(x)])

评价网络

计算 ( 4 ) (4) (4)的上界并不容易。根据神经网络的通用近似定理,假设存在一个神经网络使得可以达到这个上界。

f ( x ; ϕ ) f(x;\phi) f(x;ϕ)为一个神经网络,假设存在参数集合 Φ \Phi Φ,对于所有的 ϕ ∈ Φ \phi \in \Phi ϕΦ f ( x ; ϕ ) f(x;\phi) f(x;ϕ) K − L i p s c h i t z K-Lipschitz KLipschitz连续函数,那么 ( 4 ) (4) (4)中的上界可以近似转换为:
max ⁡ ϕ ∈ Φ ( E x ∼ p r [ f ( x ; ϕ ) ] − E x ∼ p θ [ f ( x ; ϕ ) ] ) \begin{align} \mathop {\max }\limits_{\phi \in \Phi} \left( {{{\rm E}_{x\sim{p_r}}}\left[ {f(x;\phi)} \right] - {{\rm E}_{x\sim{p_\theta }}}\left[ {f(x;\phi)} \right]} \right) \end{align} ϕΦmax(Expr[f(x;ϕ)]Expθ[f(x;ϕ)])
其中 f ( x ; ϕ ) f(x;\phi) f(x;ϕ)称为评价网络(Critic Network)

与标准GAN中的判别网络的值域为 [ 0 , 1 ] [0,1] [0,1]不同,评价网络 f ( x ; ϕ ) f(x;\phi) f(x;ϕ)的最后一层为线性层,其值域没有限制。这样只需要找到一个网络 f ( x ; ϕ ) f(x;\phi) f(x;ϕ)使其在两个分布 p r p_r pr p θ p_{\theta} pθ下的期望最大。即对于真实样本, f ( x ; ϕ ) f(x;\phi) f(x;ϕ)的打分要尽可能低。

为了使得 f ( x ; ϕ ) f(x;\phi) f(x;ϕ)满足 K − L i p s c h i t z K-Lipschitz KLipschitz连续,一种近似的方法使限制参数的取值范围。因为神经网络为连续可导函数,满足 K − L i p s c h i t z K-Lipschitz KLipschitz连续可以近似为其关于 x x x的偏导数的模 ∥ ∂ f ( x ; ϕ ) ∂ x ∥ \left\| {\frac{{\partial f(x;\phi )}}{{\partial x}}} \right\| xf(x;ϕ) 小于某个上界。由于这个偏导数的大小一般和参数的取值范围相关,我们可以通过限制参数 ϕ \phi ϕ的取值范围来近似,令 ϕ ∈ [ − c , c ] \phi \in[-c,c] ϕ[c,c] c c c为一个较小的正数,比如 0.01 0.01 0.01

生成网络

生成网络的目标是使得评价网络 f ( x ; ϕ ) f(x;\phi) f(x;ϕ)对其生成样本的打分尽可能高,即
max ⁡ θ E z ∼ p ( z ) [ f ( G ( z ; θ ) ; ϕ ) ] \begin{align} \mathop {\max }\limits_{\theta} {{{\rm E}_{z\sim{p_{(z)}}}}\left[ {f(G(z;\theta);\phi)} \right] } \end{align} θmaxEzp(z)[f(G(z;θ);ϕ)]
因为 f ( x ; ϕ ) f(x;\phi) f(x;ϕ)为不饱和函数,所以生成网络参数 θ \theta θ的梯度不会消失,理论上解决了原始GAN训练不稳定的问题。并且W-GAN中生成网络的目标函数不再是两个分布的比率,在一定程度上缓解了模型坍塌问题,使得生成的样本具有多样性。

4.2 散度

KL散度

K L 散度 ( K u l l b a c k − L e i b l e r D i v e r g e n c e ) KL散度(Kullback-Leibler Divergence) KL散度(KullbackLeiblerDivergence)也叫 K L 距离 KL距离 KL距离 相对熵 ( R e l a t i v e E n t r o p y ) 相对熵(Relative Entropy) 相对熵(RelativeEntropy),是用概率分布 q q q来近似分布 p p p时造成的信息的损失量。 K L 散度 KL散度 KL散度是按照概率分布 q q q的最优编码来对真实分布为 p p p的信息进行编码,其平均编码长度(即交叉熵) H ( p , q ) H(p,q) H(p,q) p p p的最优平均编码长度(即熵) H ( p ) H(p) H(p)之间的差异。对于离散概率分布 p p p q q q,从 q q q p p p K L 散度 KL散度 KL散度定义为
K L ( p , q ) = H ( p , q ) − H ( p ) = ∑ x p ( x ) log ⁡ p ( x ) q ( x ) \begin{align} KL(p,q) &= H(p,q) - H(p) \\ &= \sum\limits_x {p(x)\log \frac{{p(x)}}{{q(x)}}} \\ \end{align} KL(p,q)=H(p,q)H(p)=xp(x)logq(x)p(x)
其中,为了保证连续性,定义 0 l o g 0 0 = 0 , 0 l o g 0 q = 0 0log\frac{0}{0}=0,0log\frac{0}{q}=0 0log00=0,0logq0=0

K L 散度 KL散度 KL散度总是非负的, K L ( p , q ) ≥ 0 KL(p,q)\ge0 KL(p,q)0,可以衡量两个概率分布之间的距离。 K L 散度 KL散度 KL散度只有当 p = q p=q p=q时, K L ( p , q ) = 0 KL(p,q)=0 KL(p,q)=0

如果两个分布越接近, K L 散度 KL散度 KL散度越小;反之亦反。

K L 散度 KL散度 KL散度并不是一个真正的度量或距离,一是 K L 散度 KL散度 KL散度不满足距离的对称性,二是 K L 散度 KL散度 KL散度不满足距离的三角不等式性质。

JS散度

J S 散度 ( J e n s e n − S h a n n o n D i v e r g e n c e ) JS散度(Jensen-Shannon Divergence) JS散度(JensenShannonDivergence)是一种对称的衡量两个分布相似度的度量方式,定义为
J S ( p , q ) = 1 2 K L ( p , m ) + 1 2 K L ( q , m ) \begin{align} JS(p,q) = \frac{1}{2}KL(p,m) + \frac{1}{2}KL(q,m) \end{align} JS(p,q)=21KL(p,m)+21KL(q,m)
其中 m = 1 2 ( p + q ) m=\frac{1}{2}(p+q) m=21(p+q)

J S 散度 JS散度 JS散度 K L 散度 KL散度 KL散度一种改进,但是两种散度都存在一个问题,即如果两个分布 p , q p,q p,q没有重叠或者重叠非常少, K L 散度 KL散度 KL散度 J S 散度 JS散度 JS散度都很难衡量两个分布的距离。

Wasserstein距离

W a s s e r s t e i n 距离 ( W a s s e r s t e i n D i s t a n c e ) Wasserstein距离(Wasserstein Distance) Wasserstein距离(WassersteinDistance)也用于衡量两个分布之间的距离。对于两个分布 q 1 , q 2 q_1,q_2 q1,q2 p t h − W a s s e r s t e i n p^{th}-Wasserstein pthWasserstein距离定义为
W p ( q 1 , q 2 ) = ( inf ⁡ γ ( x , y ) ∈ Γ ( q 1 , q 2 ) E ( x , y ) ∼ γ ( x , y ) [ d ( x , y ) p ] ) 1 p \begin{align} {{\bf{W}}_p}({q_1},{q_2 }) =( {\mathop {\inf }\limits_{\gamma(x,y) \in \Gamma ({q_1},{q_2 })} {{\rm E}_{(x,y)\sim \gamma(x,y) }}\left[ d{(x,y)}^{p} \right]})^{\frac{1}{p}} \end{align} Wp(q1,q2)=(γ(x,y)Γ(q1,q2)infE(x,y)γ(x,y)[d(x,y)p])p1
其中 Γ ( q 1 , q 2 ) \Gamma ({q_1},{q_2 }) Γ(q1,q2)是边际分布为 q 1 q_1 q1 q 2 q_2 q2的所有可能的联合分布集合, d ( x , y ) d(x,y) d(x,y) x x x y y y的距离。

W a s s e r s t e i n 距离 Wasserstein距离 Wasserstein距离相比 J S 散度 JS散度 JS散度 K L 散度 KL散度 KL散度的优势在于:
即使两个分布没有重叠或者重叠非常少, W a s s e r s t e i n 距离 Wasserstein距离 Wasserstein距离仍然能反应两个分布的远近。

4.3 代码实现

# @File    : WGAN_2017.py
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Reshape, Input
from tensorflow.keras.layers import UpSampling2D, Conv2D, Activation, ZeroPadding2D, GlobalAveragePooling2D
from tensorflow.keras.layers import Flatten, Dropout
from tensorflow.keras.datasets import mnist
from tensorflow.keras.optimizers import RMSprop
import tensorflow.keras.backend as K


class WGAN():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        self.n_critic = 5
        self.clip_value = 0.01
        optimizer = RMSprop(lr=0.00005)

        self.critic = self.build_critic()
        self.critic.compile(loss=self.wasserstein_loss,
                            optimizer=optimizer,
                            metrics=['accuracy'])

        self.generator = self.build_generator()

        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        self.critic.trainable = False

        valid = self.critic(img)

        self.combined = Model(z, valid)
        self.combined.compile(loss=self.wasserstein_loss,
                              optimizer=optimizer,
                              metrics=['accuracy'])

    def build_generator(self):
        model = Sequential()

        model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))
        model.add(Reshape((7, 7, 128)))
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=4, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=4, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(Conv2D(self.channels, kernel_size=4, padding="same"))
        model.add(Activation("tanh"))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_critic(self):
        model = Sequential()

        model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
        model.add(ZeroPadding2D(padding=((0,1),(0,1))))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Flatten())
        model.add(Dense(1))

        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def wasserstein_loss(self, y_true, y_pred):
        return K.mean(y_true * y_pred)

    def train(self, epochs, batch_size=128, sample_interval=50):

        (X_train, _), (_, _) = mnist.load_data()

        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)

        valid = -np.ones((batch_size, 1))
        fake = np.ones((batch_size, 1))

        for epoch in range(epochs):

            for _ in range(self.n_critic):

                idx = np.random.randint(0, X_train.shape[0], batch_size)
                imgs = X_train[idx]

                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

                gen_imgs = self.generator.predict(noise)

                d_loss_real = self.critic.train_on_batch(imgs, valid)
                d_loss_fake = self.critic.train_on_batch(gen_imgs, fake)
                d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

                for l in self.critic.layers:
                    weights = l.get_weights()
                    weights = [np.clip(w, -self.clip_value, self.clip_value) for w in weights]
                    l.set_weights(weights)

            g_loss = self.combined.train_on_batch(noise, valid)

            print("%d [D loss: %f] [G loss: %f]" % (epoch, 1 - d_loss[0], 1 - g_loss[0]))

            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/mnist_%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    wgan = WGAN()
    wgan.train(epochs=20000, batch_size=32, sample_interval=100)

tree

test
│  WGAN_2017.py
└─ images

你可能感兴趣的:(GANs,生成对抗网络,机器学习,深度学习)