生成对抗网络GAN(Generative adversarial nets)是由Goodfellow等人于2014年提出的基于深度学习模型的生成框架,可用于多种生成任务。从名称也不难看出,在GAN中包括了两个部分,分别为”生成”和“对抗”,整两个部分也分别对应了两个网络,即生成网络(Generator) G G G和判别网络(Discriminator) D D D,为描述简单,以图像生成为例:
生成网络(Generator) G G G用于生成图片,其输入是一个随机的噪声 z \boldsymbol{z} z,通过这个噪声生成图片,记作 G ( z ) G\left ( \boldsymbol{z} \right ) G(z)。
判别网络(Discriminator) D D D用于判别一张图片是否是真实的,对应的,其输入是一整图片 x \boldsymbol{x} x,输出 D ( x ) D\left ( \boldsymbol{x} \right ) D(x)表示的是图片 x \boldsymbol{x} x为真实图片的概率。
在GAN框架的训练过程中,希望生成网络 G G G生成的图片尽量真实,能够欺骗过判别网络 D D D;而希望判别网络 D D D能够把 G G G生成的图片从真实图片中区分开。这样的一个过程就构成了一个动态的“博弈”。最终,GAN希望能够使得训练好的生成网络 G G G生成的图片能够以假乱真,即对于判别网络 D D D来说,无法判断 G G G生成的网络是不是真实的。
综上,训练好的生成网络 G G G便可以用于生成“以假乱真”的图片。
GAN的框架是由生成网络 G G G和判别网络 D D D这两种网络结构组成,通过两种网络的“对抗”过程完成两个网络的训练。由生成网络 G G G生成一张“Fake image”,判别网络 D D D判断这张图片是否来自真实图片。
生成网络和判别网络更像是自然界的一对生产者和捕食者,它们在各自独立迭代,进化的同时,也伴随着协同学习,共同进步。用论文里的一句话来描述训练过程就是
The training procedure for G is to maximize the probability of D making a mistake.
GAN希望的是对于判别网络,其能够正确判定数据是否来自真实的分布,对于生成网络,其能够尽可能使得生成的数据能够“以假乱真”,使得判别网络分辨不了。这样的训练过程是一个动态的“博弈”过程,通过交替训练,最终使得生成网络 G G G生成的图片能够“以假乱真”,其具体过程如下图所示:
如上图(a)中,黑色的虚线表示的是从真实的分布 p x p_{\boldsymbol{x}} px,绿色的实线表示的是需要训练的生成网络的生成的分布 p g ( G ) p_g\left ( G \right ) pg(G),蓝色的虚线表示的是判别网络,最下面的横线 z \boldsymbol{z} z表示的是从一个先验分布(如图中是一个均匀分布)采样得到的数据点,中间的横线 x \boldsymbol{x} x表示的真实分布,两条横线之间的对应关系表示的是生成网络将先验分布映射成一个生成分布 p g ( G ) p_g\left ( G \right ) pg(G)。从图(a)到图(d)表示了一个完整的交替训练过程,首先,如图(a)所示,当通过先验分布采样后的数据经过生成网络 G G G映射后得到了图上绿色的实线代表的分布,此时判别网络 D D D并不能区分数据是否来自真实数据,通过对判别网络的训练,其能够正确地判断生成的数据是否来自真实数据,如图(b)所示;此时更新生成网络 G G G,通过对先验分布重新映射到新的生成分布上,如图©中的绿色实线所示。依次交替完成上述步骤,当达到一定迭代的代数后,达到一个平衡状态,此时 p g = p d a t a p_g=p_{data} pg=pdata,判别网络 D D D将不能区分图片是否来自真实分布,且 D ( x ) = 1 2 D\left ( \boldsymbol{x} \right )=\frac{1}{2} D(x)=21。
首先对于二分类问题,一般使用交叉熵作为损失函数:
J ( θ ) = − 1 m ∑ i = 1 m [ y ( i ) l o g y ^ ( i ) + ( 1 − y ( i ) ) l o g ( 1 − y ^ ( i ) ) ] J\left ( \theta \right )=-\frac{1}{m}\sum_{i=1}^{m}\left [ y^{\left ( i \right )}log\; \hat{y}^{\left ( i \right )}+\left ( 1-y^{\left ( i \right )} \right )log\; \left ( 1-\hat{y}^{\left ( i \right )} \right ) \right ] J(θ)=−m1i=1∑m[y(i)logy^(i)+(1−y(i))log(1−y^(i))]其中, y ( i ) y^{\left ( i \right )} y(i) 表示的是真实的样本标签, y ^ ( i ) \hat{y}^{\left ( i \right )} y^(i)表示的是模型的预测值。
对于GAN的两部分样本,一个是真实的样本 { ( x ( 1 ) , 1 ) , ( x ( 2 ) , 1 ) , ⋯ , ( x ( m ) , 1 ) } \left \{ \left ( \boldsymbol{x}^{\left ( 1 \right )},1 \right ),\left ( \boldsymbol{x}^{\left ( 2 \right )},1 \right ),\cdots ,\left ( \boldsymbol{x}^{\left ( m \right )},1 \right ) \right \} {(x(1),1),(x(2),1),⋯,(x(m),1)},另一部分来自生成模型 { ( G ( z ( 1 ) ) , 0 ) , ( G ( z ( 2 ) ) , 0 ) , ⋯ , ( G ( z ( m ) ) , 0 ) } \left \{ \left ( G\left ( \boldsymbol{z}^{\left ( 1 \right )} \right ),0 \right ),\left ( G\left ( \boldsymbol{z}^{\left ( 2 \right )} \right ),0 \right ),\cdots ,\left ( G\left ( \boldsymbol{z}^{\left ( m \right )} \right ),0 \right ) \right \} {(G(z(1)),0),(G(z(2)),0),⋯,(G(z(m)),0)}。
将两部分带入交叉熵,可得GAN价值函数 V ( G , D ) V\left ( G,D \right ) V(G,D)为:
m i n G m a x D V ( D , G ) = E x ∼ p d a t a ( x ) [ l o g D ( x ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] \underset{G}{min}\; \underset{D}{max}\; V\left ( D,G \right )=\mathbb{E}_{\boldsymbol{x}\sim p_{data}\left ( \boldsymbol{x} \right )}\left [ log\; D\left ( \boldsymbol{x} \right ) \right ]+\mathbb{E}_{\boldsymbol{z}\sim p_{\boldsymbol{z}}\left ( \boldsymbol{z} \right )}\left [ log\; \left ( 1-D\left ( G\left ( \boldsymbol{z} \right ) \right ) \right ) \right ] GminDmaxV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]
其中, E x ∼ p d a t a ( x ) [ l o g D ( x ) ] \mathbb{E}_{\boldsymbol{x}\sim p_{data}\left ( \boldsymbol{x} \right )}\left [ log\; D\left ( \boldsymbol{x} \right ) \right ] Ex∼pdata(x)[logD(x)]表示的是 D ( x ) D\left ( \boldsymbol{x} \right ) D(x)的期望, E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] \mathbb{E}_{\boldsymbol{z}\sim p_{\boldsymbol{z}}\left ( \boldsymbol{z} \right )}\left [ log\; \left ( 1-D\left ( G\left ( \boldsymbol{z} \right ) \right ) \right ) \right ] Ez∼pz(z)[log(1−D(G(z)))]表示的是 l o g ( 1 − D ( G ( z ) ) ) log\; \left ( 1-D\left ( G\left ( \boldsymbol{z} \right ) \right ) \right ) log(1−D(G(z)))的期望。
假设从真实数据中采样 m m m个样本 { x ( 1 ) , x ( 2 ) , ⋯ , x ( m ) } \left \{ \boldsymbol{x}^{\left ( 1 \right )},\boldsymbol{x}^{\left ( 2 \right )},\cdots ,\boldsymbol{x}^{\left ( m \right )} \right \} {x(1),x(2),⋯,x(m)},从噪音分布 p g ( z ) p_g\left ( \boldsymbol{z} \right ) pg(z)中同样采样 m m m个样本,记为 { z ( 1 ) , z ( 2 ) , ⋯ , z ( m ) } \left \{ \boldsymbol{z}^{\left ( 1 \right )},\boldsymbol{z}^{\left ( 2 \right )},\cdots ,\boldsymbol{z}^{\left ( m \right )} \right \} {z(1),z(2),⋯,z(m)},此时,上述价值函数可以近似表示为:
m i n G m a x D V ( D , G ) ≈ ∑ i = 1 m [ l o g D ( x ( i ) ) ] + ∑ i = 1 m [ l o g ( 1 − D ( G ( z ( i ) ) ) ) ] \underset{G}{min}\underset{D}{max}V(D,G)\approx \sum_{i=1}^{m}{[logD(x^{(i)})]}+ \sum_{i=1}^{m}{[log(1−D(G(z^{(i)})))]} GminDmaxV(D,G)≈i=1∑m[logD(x(i))]+i=1∑m[log(1−D(G(z(i))))]
用随机梯度下降法训练过程在论文中如下流程所示:(分两步ascending,descending,加梯度减梯度交替训练生成器判别器)
需要刻画两个分布是否相似,需要用到KL散度(KL divergence)。KL散度是统计学中的一个基本概念,用于衡量两个分布的相似程度,数值越小,表示两种概率分布越接近。
对于离散的概率分布,定义如下:
( P ∥ Q ) = ∑ i P ( i ) l o g P ( i ) Q ( i ) \left ( P\parallel Q \right )=\sum_{i}P\left ( i \right )log\frac{P\left ( i \right )}{Q\left ( i \right )} (P∥Q)=i∑P(i)logQ(i)P(i)对于连续的概率分布,定义如下:
( P ∥ Q ) = ∫ − ∞ + ∞ p ( x ) l o g p ( x ) q ( x ) \left ( P\parallel Q \right )=\int_{-\infty }^{+\infty }p\left ( x \right )log\frac{p\left ( x \right )}{q\left ( x \right )} (P∥Q)=∫−∞+∞p(x)logq(x)p(x)
上述需要求解生成分布 p g ( x ; θ ) p_g\left ( \boldsymbol{x};\theta \right ) pg(x;θ)中的参数 θ \theta θ,需要用到极大似然估计。根据极大似然估计的方式,由于最终是希望生成的分布 p g ( x ; θ ) p_g\left ( \boldsymbol{x};\theta \right ) pg(x;θ)与原始的真实分布 p d a t a ( x ) p_{data}\left ( \boldsymbol{x} \right ) pdata(x),首先从真实分布 p d a t a ( x ) p_{data}\left ( \boldsymbol{x} \right ) pdata(x)采样 m m m个数据点,记为 { x ( 1 ) , x ( 2 ) , ⋯ , x ( m ) } \left \{ \boldsymbol{x}^{\left ( 1 \right )},\boldsymbol{x}^{\left ( 2 \right )},\cdots ,\boldsymbol{x}^{\left ( m \right )} \right \} {x(1),x(2),⋯,x(m)},根据生成的分布,得到似然函数为:
L = ∏ i = 1 m p g ( x ( i ) ; θ ) L=\prod_{i=1}^{m}p_g\left ( \boldsymbol{x}^{\left ( i \right )};\theta \right ) L=i=1∏mpg(x(i);θ)
取log后,得到等价的log似然:
L = ∑ i = 1 m l o g p g ( x ( i ) ; θ ) L=\sum_{i=1}^{m}log\; p_g\left ( \boldsymbol{x}^{\left ( i \right )};\theta \right ) L=i=1∑mlogpg(x(i);θ)
此时, θ ∗ \theta ^{\ast } θ∗为:
θ ∗ = a r g m a x θ ∑ i = 1 m l o g p g ( x ( i ) ; θ ) θ∗=arg\underset{\theta}{max}\sum_{i=1}^{m}{logp_{g}(x^{(i)};θ)} θ∗=argθmax∑i=1mlogpg(x(i);θ)
≈ a r g m a x θ E x ∼ p d a t a [ l o g p g ( x ; θ ) ] ≈arg\underset{\theta}{max}\mathbb{E}_{\boldsymbol{x}\sim p_{data}}[logp_{g}(x;θ)] ≈argθmaxEx∼pdata[logpg(x;θ)]
= a r g m a x θ ∫ x p d a t a ( x ) l o g p g ( x ; θ ) d x =arg\underset{\theta}{max}∫_{x}p_{data}(x)logp_{g}(x;θ)dx =argθmax∫xpdata(x)logpg(x;θ)dx
对上述的公式做一些修改,增加一个与 θ \theta θ无关的项 ∫ x p d a t a ( x ) l o g p d a t a ( x ) d x \int _{\boldsymbol{x}}p_{data}\left ( \boldsymbol{x} \right )log\; p_{data}\left ( \boldsymbol{x} \right )d\boldsymbol{x} ∫xpdata(x)logpdata(x)dx,这样并不改变对 θ ∗ \theta ^{\ast } θ∗的求解,此时,公式变为:
a r g m a x θ ∫ x p d a t a ( x ) l o g p g ( x ; θ ) d x − ∫ x p d a t a ( x ) l o g p d a t a ( x ) d x \underset{\theta }{argmax}\int _{\boldsymbol{x}}p_{data}\left ( \boldsymbol{x} \right )log\; p_g\left ( \boldsymbol{x};\theta \right )d\boldsymbol{x}-\int _{\boldsymbol{x}}p_{data}\left ( \boldsymbol{x} \right )log\; p_{data}\left ( \boldsymbol{x} \right )d\boldsymbol{x} θargmax∫xpdata(x)logpg(x;θ)dx−∫xpdata(x)logpdata(x)dx
将最大值求解变成最小值为:
a r g m i n θ ∫ x p d a t a ( x ) l o g p d a t a ( x ) d x − ∫ x p d a t a ( x ) l o g p g ( x ; θ ) d x \underset{\theta }{argmin}\int _{\boldsymbol{x}}p_{data}\left ( \boldsymbol{x} \right )log\; p_{data}\left ( \boldsymbol{x} \right )d\boldsymbol{x}-\int _{\boldsymbol{x}}p_{data}\left ( \boldsymbol{x} \right )log\; p_g\left ( \boldsymbol{x};\theta \right )d\boldsymbol{x} θargmin∫xpdata(x)logpdata(x)dx−∫xpdata(x)logpg(x;θ)dx
通过积分公式的合并,得到:
a r g m i n θ ∫ x p d a t a ( x ) l o g p d a t a ( x ) l o g p g ( x ; θ ) d x \underset{\theta }{argmin}\int _{\boldsymbol{x}}p_{data}\left ( \boldsymbol{x} \right )log\;\frac{p_{data}\left ( \boldsymbol{x} \right )}{log\; p_g\left ( \boldsymbol{x};\theta \right )} d\boldsymbol{x} θargmin∫xpdata(x)loglogpg(x;θ)pdata(x)dx
由KL散度可知,上述可以表示为:
a r g m i n θ K L ( p d a t a ( x ) ∥ p g ( x ; θ ) ) \underset{\theta }{argmin}\; KL\left ( p_{data}\left ( \boldsymbol{x} \right )\parallel p_g\left ( \boldsymbol{x};\theta \right ) \right ) θargminKL(pdata(x)∥pg(x;θ))
由此可以看出最小化KL散度等价于最大化似然函数。
当生成网络G确定后,价值函数可以表示为:
m a x D V ( D ) = ∫ x p d a t a ( x ) l o g ( D ( x ) ) d x + ∫ z p z ( z ) l o g ( 1 − D ( G ( z ) ) ) d z \underset{D}{max}\;V(D)=∫_{x}p_{data}(x)log(D(x))dx+∫_{z}p_{z}(z)log(1−D(G(z)))dz DmaxV(D)=∫xpdata(x)log(D(x))dx+∫zpz(z)log(1−D(G(z)))dz
= ∫ x p d a t a ( x ) l o g ( D ( x ) ) + p g ( x ) l o g ( 1 − D ( x ) ) d x =∫_{x}p_{data}(x)log(D(x))+p_{g}(x)log(1−D(x))dx =∫xpdata(x)log(D(x))+pg(x)log(1−D(x))dx
由于上述的积分与D无关,上述可以简化成求解:
m a x D [ p d a t a ( x ) l o g ( D ( x ) ) + p g ( x ) l o g ( 1 − D ( x ) ) ] \underset{D}{max}\left [ p_{data}\left ( \boldsymbol{x} \right )log\; \left ( D\left ( \boldsymbol{x} \right ) \right )+p_{g}\left ( \boldsymbol{x} \right )log\; \left ( 1-D\left ( \boldsymbol{x} \right ) \right )\right ] Dmax[pdata(x)log(D(x))+pg(x)log(1−D(x))]
求导数并令其为0,便可以得到最大的D:
D ∗ = p d a t a ( x ) p d a t a ( x ) + p g ( x ) D^{\ast }=\frac{p_{data}\left ( \boldsymbol{x} \right )}{p_{data}\left ( \boldsymbol{x} \right )+p_{g}\left ( \boldsymbol{x} \right )} D∗=pdata(x)+pg(x)pdata(x)
且 D ( x ) ∈ [ 0 , 1 ] D\left ( \boldsymbol{x} \right )\in \left [ 0,1 \right ] D(x)∈[0,1],将其带入到价值函数中,可得
V ( D ∗ , G ) = E x ∼ p d a t a ( x ) [ l o g p d a t a ( x ) p d a t a ( x ) + p g ( x ) ] + E x ∼ p g ( x ) [ l o g ( 1 − p d a t a ( x ) p d a t a ( x ) + p g ( x ) ) ] V\left ( D^{\ast },G \right )=\mathbb{E}_{\boldsymbol{x}\sim p_{data}\left ( \boldsymbol{x} \right )}\left [ log\; \frac{p_{data}\left ( \boldsymbol{x} \right )}{p_{data}\left ( \boldsymbol{x} \right )+p_{g}\left ( \boldsymbol{x} \right )} \right ]+\mathbb{E}_{\boldsymbol{x}\sim p_g\left ( \boldsymbol{x} \right )}\left [ log\; \left ( 1-\frac{p_{data}\left ( \boldsymbol{x} \right )}{p_{data}\left ( \boldsymbol{x} \right )+p_{g}\left ( \boldsymbol{x} \right )} \right ) \right ] V(D∗,G)=Ex∼pdata(x)[logpdata(x)+pg(x)pdata(x)]+Ex∼pg(x)[log(1−pdata(x)+pg(x)pdata(x))]
对上式简化,可得:
V ( D ∗ , G ) = E x ∼ p d a t a ( x ) [ l o g p d a t a ( x ) p d a t a ( x ) + p g ( x ) ] + E x ∼ p g ( x ) [ l o g ( 1 − p d a t a ( x ) p d a t a ( x ) + p g ( x ) ) ] V\left ( D^{\ast },G \right )=\mathbb{E}_{\boldsymbol{x}\sim p_{data}\left ( \boldsymbol{x} \right )}\left [ log\; \frac{p_{data}\left ( \boldsymbol{x} \right )}{p_{data}\left ( \boldsymbol{x} \right )+p_{g}\left ( \boldsymbol{x} \right )} \right ]+\mathbb{E}_{\boldsymbol{x}\sim p_g\left ( \boldsymbol{x} \right )}\left [ log\; \left ( 1-\frac{p_{data}\left ( \boldsymbol{x} \right )}{p_{data}\left ( \boldsymbol{x} \right )+p_{g}\left ( \boldsymbol{x} \right )} \right ) \right ] V(D∗,G)=Ex∼pdata(x)[logpdata(x)+pg(x)pdata(x)]+Ex∼pg(x)[log(1−pdata(x)+pg(x)pdata(x))]
= ∫ x p d a t a ( x ) l o g p d a t a ( x ) p d a t a ( x ) + p g ( x ) d x + ∫ x p g ( x ) l o g p g ( x ) p d a t a ( x ) + p g ( x ) d x =∫_{x}p_{data}(x)log\frac{p_{data}(x)}{p_{data}(x)+p_{g}(x)}dx+∫_{x}p_{g}(x)log\frac{p_{g}(x)}{p_{data}(x)+p_{g}(x)}dx =∫xpdata(x)logpdata(x)+pg(x)pdata(x)dx+∫xpg(x)logpdata(x)+pg(x)pg(x)dx
V ( D ∗ , G ) = ∫ x p d a t a ( x ) l o g 1 2 p d a t a ( x ) p d a t a ( x ) + p g ( x ) 2 d x + ∫ x p g ( x ) l o g 1 2 p g ( x ) p d a t a ( x ) + p g ( x ) 2 d x V\left ( D^{\ast },G \right )=∫_{x}p_{data}(x)log\frac{\frac{1}{2}p_{data}(x)}{\frac{p_{data}(x)+p_{g}(x)}{2}}dx+∫_{x}p_{g}(x)log\frac{\frac{1}{2}p_{g}(x)}{\frac{p_{data}(x)+p_{g}(x)}{2}}dx V(D∗,G)=∫xpdata(x)log2pdata(x)+pg(x)21pdata(x)dx+∫xpg(x)log2pdata(x)+pg(x)21pg(x)dx
= − 2 l o g 2 + K L ( p d a t a ( x ) ∥ p d a t a ( x ) + p g ( x ) 2 ) + K L ( p g ( x ) ∥ p d a t a ( x ) + p g ( x ) 2 ) =−2log2+KL(p_{data}(x)∥\frac{p_{data}(x)+p_{g}(x)}{2})+KL(p_{g}(x)∥\frac{p_{data}(x)+p_{g}(x)}{2}) =−2log2+KL(pdata(x)∥2pdata(x)+pg(x))+KL(pg(x)∥2pdata(x)+pg(x))
这里引入另一个散度:JS散度(Jensen-Shannon Divergence)
J S D ( P ∥ Q ) = 1 2 [ K L ( P ∥ M ) + K L ( Q ∥ M ) ] JSD\left ( P\parallel Q \right )=\frac{1}{2}\left [ KL\left ( P\parallel M \right ) + KL\left ( Q\parallel M \right )\right ] JSD(P∥Q)=21[KL(P∥M)+KL(Q∥M)]
其中, M = P + Q 2 M=\frac{P+Q}{2} M=2P+Q 。因此 V ( D ∗ , G ) V\left ( D^{\ast },G \right ) V(D∗,G)可以表示为:
V ( D ∗ , G ) = − l o g 4 + 2 J S D ( p d a t a ( x ) ∥ p g ( x ) ) V\left ( D^{\ast },G \right )=-log\; 4+2JSD\left ( p_{data}\left ( \boldsymbol{x} \right )\parallel p_g\left ( \boldsymbol{x} \right ) \right ) V(D∗,G)=−log4+2JSD(pdata(x)∥pg(x))
已知JS散度是一个非负值,且值域为 [ 0 , 1 ] \left [ 0,1 \right ] [0,1],当两个分布相同时取0,不同时取1。对于 V ( D ∗ , G ) V\left ( D^{\ast },G \right ) V(D∗,G)的最小值为当 J S = 0 JS=0 JS=0时,即最小值是 − l o g 4 -log\;4 −log4。此时 p d a t a ( x ) = p g ( x ) p_{data}\left ( \boldsymbol{x} \right )=p_g\left ( \boldsymbol{x} \right ) pdata(x)=pg(x),求得的生成网络 G G G生成的数据分布与真实的数据分布差异性最小,即GAN所要求的目标: p g = p d a t a p_g=p_{data} pg=pdata。
本周重点是研究Ian Goodfellow在2014年发表的论文《Generative Adversarial Nets》中提出的GAN网络,结合李宏毅老师的深度学习课程相关部分针对对抗生成网络进行系统的了解学习。
GAN的生成器和判别器就像自然界的一对物种,生成器为了骗过判别器不断优化,判别器也通过不断训练努力提升辨别能力,最终我们将会得到一个非常优秀的生成器。
在训练过程推导价值函数时,接触到了散度这个新概念,结合之前学过的极大似然估计,发现这里的最小化KL散度等价于最大化似然函数。JSD散度相比较于KL散度,具备了对称性,且JS 散度的取值为 0 到 log2。若两个分布完全没有交集,那么 JS 散度取最大值 log2;若两个分布完全一样,那么 JS 散度取最小值 0,当且仅当 P=Q,JSD散度取最小值0,求得的G会使得真实的数据分布和生成的数据分布差异性最小,这样自然可以生成一个和原分布尽可能接近的分布,同时也摆脱了计算极大似然估计,所以GAN的本质是通过改变训练的过程来避免繁琐的计算。所以GAN在建模数据分布的优势不言而喻,同时它的局限性也很明显:难训练不稳定,生成器判别器需要很好的同步,还有别的论文中提到的多次训练梯度消失问题,模式缺失(Mode Collapse)问题(即只生成简单重复样本点)。