一篇讲的很好的博客
理论推导博客
论文原文
斯坦福课件
上面的博客已经很好很深入了,下面记录一下我个人的直观理解。具体理论移步上面的博客。这里只是作为日后使用时的快速查阅。不具有理论推导的严谨性。
第一次接触VAE还是在World Model这篇论文。VAE主要由三部分组成:
VAE训练好后,可以用中间变量z作为其他模型的输入World Model就是这么做的,这样Encoder就相当于一个降维的作用。也可以将Decoder作为生成器,生成和训练集类似的样例,这就和GAN的功能类似。
本质上,VAE就是我给一堆输入到编码器,解码器能输出同样分布的输出。
生成模型的难题就是判断生成分布与真实分布的相似度,因为我们只知道两者的采样结果,不知道它们的分布表达式。
KL散度的虽然能衡量两种分布的近似度,但是必须知道分布的表达式
我们的假设是 p ( z ∣ x ) p(z|x) p(z∣x)是高斯分布。这是VAE模型的重点,正因为这个假设,我们才设计成如下模型:
当然,如果 p ( z ∣ x ) p(z|x) p(z∣x)是高斯分布, p ( z ) p(z) p(z)也满足正态分布。推理如下(不区分积分与求和):
p ( z ) = ∑ x p ( z ∣ x ) p ( x ) = ∑ x N ( 0 , 1 ) p ( x ) = N ( 0 , 1 ) ∑ x p ( x ) = N ( 0 , 1 ) p(z) = \sum_x p(z|x)p(x) = \sum_xN(0,1)p(x) = N(0,1)\sum_xp(x) = N(0,1) p(z)=∑xp(z∣x)p(x)=∑xN(0,1)p(x)=N(0,1)∑xp(x)=N(0,1)
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-gIG4oJZ3-1605507771174)(24-VAE.assets/image-20201115160705282.png)]
图片来源
结构示意图如上图所示。
以图片为例:
对于均值方差计算模块:可能是多个卷积层和池化层
生成器:可能是多个反卷积层
均值和方差的计算则是全连接网络。
我们的z是根据均值和方差采样而来,在这里方差相当于噪声,如果方差是0的话,则采样结果则一定是均值。我们通过最小化生成的 x ^ \hat x x^与输入的 x x x之间的距离,来进行训练。那么我们的方差网络会逐渐趋近于结果为0。这时就退化成了AutoEncoder。
VAE通过在损失函数中引入生成的高斯分布 N ( μ , σ 2 ) N(\mu,\sigma^2) N(μ,σ2)与标准的高斯分布 N ( 0 , 1 ) N(0,1) N(0,1)之间的KL散度,来让 p ( z ∣ x ) p(z|x) p(z∣x)的分布趋近于标准正态分布。
VAE相对于之前的AutoEncoder的一个显著提升就是它的生成能力。从正态分布中采样生成一个z,就可以生成一个比较合理的结果。而AutoEncoder不能保证中间的z向量是某一种分布,所以它对于没有见过的(训练过的)z生成能力比较差。
直接sample出z是不行的,采样的过程是不可导,没办法BP啊!!
解决办法就是:
z = μ + ϵ × σ z = \mu + \epsilon \times \sigma z=μ+ϵ×σ
ϵ \epsilon ϵ 是从N(0,1)中采样来的。
这种技巧叫做重参数。反向传播时候,需要让z能够分别对 μ \mu μ和 σ \sigma σ求偏导,而对于 ϵ \epsilon ϵ则不需要对他求导。故才采样出来也没关系。
损失函数:
L o s s ( θ ) = D ( x , x ^ ) + K L ( N ( μ , σ 2 ) ∣ ∣ N ( 0 , 1 ) ) Loss(\theta) = D(x,\hat x) + KL(N(\mu, \sigma^2) || N(0,1)) Loss(θ)=D(x,x^)+KL(N(μ,σ2)∣∣N(0,1))
D ( x , x ^ ) D(x,\hat x) D(x,x^)是输入样本与生成样本之间的距离,可以使均方误等。
KL部分的推导,对于一维情况:
KL的公式:
D K L ( p ∣ ∣ q ) = ∑ i = 1 N = p ( x i ) log p ( x i ) q ( x i ) D_{KL}(p||q) = \sum_{i = 1}^N = p(x_i) \log \frac{p(x_i)}{q(x_i)} DKL(p∣∣q)=i=1∑N=p(xi)logq(xi)p(xi)
K L ( N ( μ , σ 2 ) ∣ ∣ N ( 0 , 1 ) ) = ∫ 1 2 π σ 2 exp { − ( x − μ ) 2 2 σ 2 } × log { 1 2 π σ 2 exp ( − ( x − μ ) 2 / 2 σ ) 1 2 π exp ( − x 2 / 2 ) } d x = 一 顿 猛 如 虎 的 化 简 = 1 2 ∫ 1 2 π σ 2 exp { − ( x − μ ) 2 2 σ 2 } [ − log σ 2 + x 2 − ( x − μ ) 2 / σ 2 ] d x KL(N(\mu, \sigma^2) || N(0,1)) = \int \frac{1}{\sqrt{2 \pi \sigma^2}} \exp\{\frac{-(x-\mu)^2}{2\sigma^2}\} \times \log \{\frac{\frac{1}{\sqrt{2 \pi \sigma^2}}\exp(-(x-\mu)^2/2\sigma^)} {\frac{1}{\sqrt{2 \pi}}\exp(-x^2/2)}\} dx\\ =一顿猛如虎的化简 \\ =\frac{1}{2} \int \frac{1}{\sqrt{2 \pi \sigma^2}} \exp\{\frac{-(x-\mu)^2}{2\sigma^2}\} [-\log \sigma^2 + x^2 - (x-\mu)^2/\sigma^2]dx KL(N(μ,σ2)∣∣N(0,1))=∫2πσ21exp{2σ2−(x−μ)2}×log{2π1exp(−x2/2)2πσ21exp(−(x−μ)2/2σ)}dx=一顿猛如虎的化简=21∫2πσ21exp{2σ2−(x−μ)2}[−logσ2+x2−(x−μ)2/σ2]dx
积分结果计算:
可分成三个积分加和(就是分别乘以中括号里那三部分):
第一个是 − log σ 2 -\log \sigma^2 −logσ2可作为常数提出来,剩下是标准正态分布的积分值为1, 故结果为 − log σ 2 -\log \sigma^2 −logσ2
第二项是二阶矩,结果为 μ 2 + σ 2 \mu^2 + \sigma^2 μ2+σ2
第三项是
∫ − ∞ + ∞ 1 2 π σ 2 exp { − ( x − μ ) 2 2 σ 2 } ( − ( x − μ ) 2 / σ 2 ) d x = ∫ − ∞ + ∞ 1 2 π exp { − ( x − μ ) 2 2 σ 2 } ( − ( x − μ ) 2 / σ 2 ) d ( x − μ ) σ = − 1 2 π ∫ − ∞ + ∞ e − 1 2 t 2 t 2 d t = − 1 2 π ∫ − ∞ + ∞ e − 1 2 t 2 t d t 2 2 = − 2 1 2 π ∫ 0 + ∞ e − m 2 m 1 / 2 d m = − 2 1 π ∫ 0 + ∞ e − m m 3 2 − 1 d m = − 2 1 π Γ ( 3 2 ) Γ ( 3 2 ) = Γ ( 1 2 + 1 ) = 1 2 Γ ( 1 / 2 ) = 1 2 π 所 以 上 述 积 分 = − 2 1 π × 1 2 π = − 1 \int_{-\infty}^{+\infty} \frac{1}{\sqrt{2 \pi \sigma^2}}\exp\{\frac{-(x-\mu)^2}{2\sigma^2}\} (-(x-\mu)^2/\sigma^2)dx \\ = \int_{-\infty}^{+\infty} \frac{1}{\sqrt{2 \pi }}\exp\{\frac{-(x-\mu)^2}{2\sigma^2}\} (-(x-\mu)^2/\sigma^2)d \frac{(x-\mu)}{\sigma} \\ = - \frac{1}{\sqrt{2\pi}} \int_{-\infty}^{+\infty} e^{-\frac{1}{2}t^2} t^2 dt \\ = -\frac{1}{\sqrt{2\pi}} \int_{-\infty}^{+\infty} e^{-\frac{1}{2}t^2} t d \frac{t^2}{2} \\ = -2\frac{1}{\sqrt{2\pi}}\int_{0}^{+\infty} e^{-m} \sqrt 2 m^{1/2} dm \\ = -2\frac{1}{\sqrt \pi} \int_{0}^{+\infty} e^{-m} m^{\frac{3}{2}-1} dm \\ = - 2\frac{1}{\sqrt \pi} \Gamma(\frac{3}{2}) \\ \Gamma(\frac{3}{2}) = \Gamma(\frac{1}{2}+1) = \frac{1}{2}\Gamma(1/2) = \frac{1}{2} \sqrt{\pi} \\ 所以上述积分=- 2\frac{1}{\sqrt \pi} \times\frac{1}{2} \sqrt{\pi} = -1 ∫−∞+∞2πσ21exp{2σ2−(x−μ)2}(−(x−μ)2/σ2)dx=∫−∞+∞2π1exp{2σ2−(x−μ)2}(−(x−μ)2/σ2)dσ(x−μ)=−2π1∫−∞+∞e−21t2t2dt=−2π1∫−∞+∞e−21t2td2t2=−22π1∫0+∞e−m2m1/2dm=−2π1∫0+∞e−mm23−1dm=−2π1Γ(23)Γ(23)=Γ(21+1)=21Γ(1/2)=21π所以上述积分=−2π1×21π=−1
最终,KL散度为:
K L ( N ( μ , σ 2 ) ∣ ∣ N ( 0 , 1 ) ) = 1 2 ( − log σ 2 + μ 2 + σ 2 − 1 ) KL(N(\mu, \sigma^2) || N(0,1)) = \frac{1}{2} (-\log \sigma^2 + \mu^2 + \sigma^2 -1) KL(N(μ,σ2)∣∣N(0,1))=21(−logσ2+μ2+σ2−1)
上述只是针对一个维度。如果一共有j个维度,则需要把每个维度的KL散度都想相加。