详解变分自编码器——VAE

文章目录

  • 详解变分自编码器——VAE
    • VAE的目标
    • 模型结构
    • 原理推导
    • 重参数技巧

本文将介绍另一生成模型——变分自编码器VAE。

详解变分自编码器——VAE

VAE全称(Variational Auto-Encoder)即变分自编码器。是一个生成模型。

了解VAE之间,我们先简单了解一下自编码器,也就是常说的Auto-Encoder

Auto-Encoder包括一个编码器(Encoder)和一个解码器(Decoder)。其结构如下:

详解变分自编码器——VAE_第1张图片

中间的这层code也称embedding。

VAE的目标

先假设一个隐变量Z的分布,构建一个从Z到目标数据X的模型,即构建 X = g ( Z ) X=g(Z) X=g(Z),使得学出来的目标数据与真实数据的概率分布相近。与GAN基本一致,GAN学的也是概率分布。

模型结构

VAE的结构图(图源自苏老师的博客,侵删)如下:

详解变分自编码器——VAE_第2张图片

VAE对每一个样本 X k X_k Xk匹配一个高斯分布,隐变量Z就是从高斯分布中采样得到的。对K个样本来说,每个样本的高斯分布假设为 N ( μ k , σ k 2 ) \mathcal N(\mu_k,\sigma_k^2) N(μk,σk2),问题就在于如何拟合这些分布。

VAE构建两个神经网络来进行拟合均值与方差。即 μ k = f 1 ( X k ) , l o g σ k 2 = f 2 ( X k ) \mu_k=f_1(X_k),log\sigma_k^2=f_2(X_k) μk=f1(Xk),logσk2=f2(Xk),拟合 l o g σ k 2 log\sigma_k^2 logσk2的原因是这样无需加激活函数。

此外,VAE让每个高斯分布尽可能地趋于标准高斯分布 N ( 0 , 1 ) \mathcal N(0,1) N(0,1)。这拟合过程中的误差损失则是采用KL散度作为计算。

下面做详细推导。

原理推导

其实,VAE与同为生成模型的GMM(高斯混合模型)也有很相似,实际上VAE可看成是GMM的一个distributed representation的版本。我们知道,GMM是有限个高斯分布的隐变量 z z z 的混合,而VAE可看成是无穷个隐变量 z z z 的混合,注意,VAE中的 z z z 可以是高斯也可以是非高斯的。只不过一般用的比较多的是高斯的。

原始样本数据 x x x 的概率分布:
P ( x ) = ∫ Z P ( x ) P ( x ∣ z ) d z P(x)=\int_Z P(x)P(x|z)dz P(x)=ZP(x)P(xz)dz
我们假设 z z z 服从标准高斯分布,先验分布 P ( x ∣ z ) P(x|z) P(xz) 是高斯的,即 x ∣ z ∼ N ( μ ( z ) , σ ( z ) ) x|z \sim N(\mu(z),\sigma(z)) xzN(μ(z),σ(z)) μ ( z ) 、 σ ( z ) \mu(z)、\sigma(z) μ(z)σ(z)是两个函数, 分别是 z z z对应的高斯分布的均值和方差(如下图),则 P ( x ) P(x) P(x) 就是在积分域上所有高斯分布的累加。

详解变分自编码器——VAE_第3张图片

由于 P ( z ) P(z) P(z) 是已知的, P ( x ∣ z ) P(x|z) P(xz) 未知,所以求解问题实际上就是求 μ , σ \mu,\sigma μ,σ这两个函数。我们最开始的目标是求解 P ( x ) P(x) P(x),且我们希望 P ( x ) P(x) P(x)越大越好,这等价于求解关于 x x x 最大对数似然:
L = ∑ x l o g P ( x ) L=\sum_x logP(x) L=xlogP(x)

l o g P ( x ) logP(x) logP(x) 可变换为:
l o g P ( x ) = ∫ z q ( z ∣ x ) l o g P ( x ) d z = ∫ z q ( z ∣ x ) l o g ( P ( z , x ) P ( z ∣ x ) ) d z = ∫ z q ( z ∣ x ) l o g ( P ( z , x ) q ( z ∣ x ) q ( z ∣ x ) P ( z ∣ x ) ) d z = ∫ z q ( z ∣ x ) l o g ( P ( z , x ) q ( z ∣ x ) ) d z + ∫ z q ( z ∣ x ) l o g ( q ( z ∣ x ) P ( z ∣ x ) ) d z = ∫ z q ( z ∣ x ) l o g ( P ( x ∣ z ) P ( z ) q ( z ∣ x ) ) d z + ∫ z q ( z ∣ x ) l o g ( q ( z ∣ x ) P ( z ∣ x ) ) d z \begin{aligned} logP(x)&=\int_z q(z|x)logP(x)dz \\ &=\int_z q(z|x)log(\dfrac{P(z,x)}{P(z|x)})dz \\ &=\int_z q(z|x)log(\dfrac{P(z,x)}{q(z|x)}\dfrac{q(z|x)}{P(z|x)})dz\\ &=\int_z q(z|x)log(\dfrac{P(z,x)}{q(z|x)})dz+ \int_z q(z|x)log(\dfrac{q(z|x)}{P(z|x)})dz\\ &=\int_z q(z|x)log(\dfrac{P(x|z)P(z)}{q(z|x)})dz + \int_z q(z|x)log(\dfrac{q(z|x)}{P(z|x)})dz \end{aligned} logP(x)=zq(zx)logP(x)dz=zq(zx)log(P(zx)P(z,x))dz=zq(zx)log(q(zx)P(z,x)P(zx)q(zx))dz=zq(zx)log(q(zx)P(z,x))dz+zq(zx)log(P(zx)q(zx))dz=zq(zx)log(q(zx)P(xz)P(z))dz+zq(zx)log(P(zx)q(zx))dz

到这里我们发现,第二项 ∫ z q ( z ∣ x ) l o g ( q ( z ∣ x ) P ( z ∣ x ) ) d z \int_z q(z|x)log(\dfrac{q(z|x)}{P(z|x)})dz zq(zx)log(P(zx)q(zx))dz 其实就是 q q q P P P 的KL散度,即 K L ( q ( z ∣ x )    ∣ ∣    P ( z ∣ x ) ) KL(q(z|x)\;||\;P(z|x)) KL(q(zx)P(zx)),因为KL散度是大于等于0的,

所以上式进一步可写成:
l o g P ( x ) ≥ ∫ z q ( z ∣ x ) l o g ( P ( x ∣ z ) P ( z ) q ( z ∣ x ) ) d z logP(x)\geq \int_z q(z|x)log(\dfrac{P(x|z)P(z)}{q(z|x)})dz logP(x)zq(zx)log(q(zx)P(xz)P(z))dz

这样我们就找到了一个下界(lower bound),也就是式子的右项,即
L b = ∫ z q ( z ∣ x ) l o g ( P ( x ∣ z ) P ( z ) q ( z ∣ x ) ) d z L_b=\int_z q(z|x)log(\dfrac{P(x|z)P(z)}{q(z|x)})dz Lb=zq(zx)log(q(zx)P(xz)P(z))dz

原式也可表示成:
l o g P ( x ) = L b + K L ( q ( z ∣ x )    ∣ ∣    P ( z ∣ x ) ) logP(x)=L_b+KL(q(z|x)\;||\;P(z|x)) logP(x)=Lb+KL(q(zx)P(zx))

为了让 l o g P ( x ) logP(x) logP(x) 越大,我们目的就是要最大化它的这个下界。

推到这里,可能会有个疑问:为什么要引入 q ( z ∣ x ) q(z|x) q(zx)(这里的 q ( z ∣ x ) q(z|x) q(zx)可以是任何分布)?

实际上,因为后验分布 P ( z ∣ x ) P(z|x) P(zx) 很难求(intractable),所以才用 q ( z ∣ x ) q(z|x) q(zx) 来逼近这个后验分布。在优化的过程中我们发现,首先 q ( z ∣ x ) q(z|x) q(zx) l o g P ( x ) logP(x) logP(x) 是完全没有关系的, l o g P ( x ) logP(x) logP(x) 只跟 P ( z ∣ x ) P(z|x) P(zx) 有关,调节 q ( z ∣ x ) q(z|x) q(zx) 是不会影响似然也就是 l o g P ( x ) logP(x) logP(x) 的。所以,当我们固定住 P ( x ∣ z ) P(x|z) P(xz) 时,调节 q ( z ∣ x ) q(z|x) q(zx) 最大化下界 L b L_b Lb,KL则越小。当 q ( z ∣ x ) q(z|x) q(zx)与不断逼近后验分布 P ( z ∣ x ) P(z|x) P(zx)时,KL散度趋于为0, l o g P ( x ) logP(x) logP(x)就和 L b L_b Lb 等价。所以最大化 l o g P ( x ) logP(x) logP(x) 就等价于最大化 L b L_b Lb

详解变分自编码器——VAE_第4张图片

回顾 L b L_b Lb,
L b = ∫ z q ( z ∣ x ) l o g ( P ( x ∣ z ) P ( z ) q ( z ∣ x ) ) d z = ∫ z q ( z ∣ x ) l o g ( P ( z ) q ( z ∣ x ) ) d z + ∫ z q ( z ∣ x ) l o g P ( x ∣ z ) d z = − K L ( q ( z ∣ x )    ∣ ∣    P ( z ) ) + ∫ z q ( z ∣ x ) l o g P ( x ∣ z ) d z = − K L ( q ( z ∣ x )    ∣ ∣    P ( z ) ) + E q ( z ∣ x ) [ l o g ( P ( x ∣ z ) ) ] \begin{aligned} L_b&=\int_z q(z|x)log(\dfrac{P(x|z)P(z)}{q(z|x)})dz \\ &=\int_z q(z|x)log(\dfrac{P(z)}{q(z|x)})dz+\int_z q(z|x)logP(x|z)dz \\ &=-KL(q(z|x)\;||\;P(z)) + \int_z q(z|x)logP(x|z)dz \\ &=-KL(q(z|x)\;||\;P(z)) + E_{q(z|x)}[log(P(x|z))] \end{aligned} Lb=zq(zx)log(q(zx)P(xz)P(z))dz=zq(zx)log(q(zx)P(z))dz+zq(zx)logP(xz)dz=KL(q(zx)P(z))+zq(zx)logP(xz)dz=KL(q(zx)P(z))+Eq(zx)[log(P(xz))]

显然,最大化 L b L_b Lb 就是等价于最小化 K L ( q ( z ∣ x )    ∣ ∣    P ( z ) ) KL(q(z|x)\;||\;P(z)) KL(q(zx)P(z)) 和最大化 E q ( z ∣ x ) [ l o g ( P ( x ∣ z ) ) ] E_{q(z|x)}[log(P(x|z))] Eq(zx)[log(P(xz))]

第一项,最小化KL散度。我们前面已假设了 P ( z ) P(z) P(z) 是服从标准高斯分布的,且 q ( z ∣ x ) q(z|x) q(zx) 是服从高斯分布 N ( μ , σ 2 ) \mathcal N(\mu,\sigma^2) N(μ,σ2) ,于是代入计算可得:
K L ( q ( z ∣ x )    ∣ ∣    P ( z ) ) = K L ( N ( μ , σ 2 )    ∣ ∣    N ( 0 , 1 ) ) = ∫ 1 2 π σ 2 e − ( x − μ ) 2 2 σ 2 ( l o g e − ( x − μ ) 2 2 σ 2 / 2 π σ 2 e − x 2 2 / 2 π ) d x . . . 化简得到 = 1 2 1 2 π σ 2 ∫ e − ( x − μ ) 2 2 σ 2 ( − l o g σ 2 + x 2 − ( x − μ ) 2 σ 2 ) d x = 1 2 ∫ 1 2 π σ 2 e − ( x − μ ) 2 2 σ 2 ( − l o g σ 2 + x 2 − ( x − μ ) 2 σ 2 ) d x \begin{aligned} KL(q(z|x)\;||\;P(z))=KL(\mathcal N(\mu,\sigma^2)\;||\;\mathcal N(0,1))=&\int\dfrac{1}{\sqrt{2\pi\sigma^2}}e^{\frac{-(x-\mu)^2}{2\sigma^2}} \left( log\dfrac{e^{\frac{-(x-\mu)^2}{2\sigma^2}}/\sqrt{2\pi\sigma^2}}{ e^{\frac{-x^2}{2}}/\sqrt{2\pi} } \right)dx \\&...\text{化简得到} \\=&\dfrac{1}{2}\dfrac{1}{\sqrt{2\pi\sigma^2}}\int e^{\frac{-(x-\mu)^2}{2\sigma^2}} \left(-log\sigma^2 +x^2-\dfrac{(x-\mu)^2}{\sigma^2} \right)dx \\=&\dfrac{1}{2}\int \dfrac{1}{\sqrt{2\pi\sigma^2}} e^{\frac{-(x-\mu)^2}{2\sigma^2}} \left(-log\sigma^2 +x^2-\dfrac{(x-\mu)^2}{\sigma^2} \right)dx \end{aligned} KL(q(zx)P(z))=KL(N(μ,σ2)N(0,1))===2πσ2 1e2σ2(xμ)2loge2x2/2π e2σ2(xμ)2/2πσ2 dx...化简得到212πσ2 1e2σ2(xμ)2(logσ2+x2σ2(xμ)2)dx212πσ2 1e2σ2(xμ)2(logσ2+x2σ2(xμ)2)dx

对上式中的积分进一步求解, 1 2 π σ 2 e − ( x − μ ) 2 2 σ 2 \dfrac{1}{\sqrt{2\pi\sigma^2}}e^{\frac{-(x-\mu)^2}{2\sigma^2}} 2πσ2 1e2σ2(xμ)2实际就是概率密度 f ( x ) f(x) f(x),而概率密度函数的积分就是1,所以积分第一项等于 − l o g σ 2 -log\sigma^2 logσ2;而又因为高斯分布的二阶矩就是 E ( X 2 ) = ∫ x 2 f ( x ) d x = μ 2 + σ 2 E(X^2)=\int x^2f(x)dx=\mu^2+\sigma^2 E(X2)=x2f(x)dx=μ2+σ2,正好对应积分第二项。又根据方差的定义可知 σ = ∫ ( x − μ ) d x \sigma=\int (x-\mu)dx σ=(xμ)dx,所以积分第三项为 − 1 -1 1

最终化简得到的结果如下:
K L ( q ( z ∣ x )    ∣ ∣    P ( z ) ) = K L ( N ( μ , σ 2 )    ∣ ∣    N ( 0 , 1 ) ) = 1 2 ( − l o g σ 2 + μ 2 + σ 2 − 1 ) KL(q(z|x)\;||\;P(z))=KL(\mathcal N(\mu,\sigma^2)\;||\;\mathcal N(0,1))=\dfrac{1}{2}(-log\sigma^2+\mu^2+\sigma^2-1) KL(q(zx)P(z))=KL(N(μ,σ2)N(0,1))=21(logσ2+μ2+σ21)

第二项,最大化期望。也就是表明在给定 q ( z ∣ x ) q(z|x) q(zx)(编码器输出)的情况下 P ( x ∣ z ) P(x|z) P(xz)(解码器输出)的值尽可能高。具体来讲,第一步,利用encoder的神经网络计算出均值与方差,从中采样得到 z z z,这一过程就对应式子中的 q ( z ∣ x ) q(z|x) q(zx);第二步,利用decoder的NN计算 z z z 的均值方差,让均值(或也考虑方差)越接近 x x x ,则产生 x x x 的几率 l o g P ( x ∣ z ) logP(x|z) logP(xz) 越大,对应于式子中的最大化 l o g P ( x ∣ z ) logP(x|z) logP(xz) 这一部分。

详解变分自编码器——VAE_第5张图片

推导至此完毕。

重参数技巧

最后模型在实现的时候,有一个重参数技巧,就是我们想从高斯分布 N ( μ , σ 2 ) \mathcal N(\mu,\sigma^2) N(μ,σ2) 中采样Z时,其实是相当于从 N ( 0 , 1 ) \mathcal N(0,1) N(0,1) 中采样一个 ϵ \epsilon ϵ,然后再来计算 Z = μ + ϵ × σ Z=\mu+\epsilon\times\sigma Z=μ+ϵ×σ。这么做的原因是,采样这个操作是不可导的,而采样的结果是可导的,这样做个参数变换, Z = μ + ϵ × σ Z=\mu+\epsilon\times\sigma Z=μ+ϵ×σ 这个就可以参与梯度下降,模型就可以训练了。

参考

  1. 苏剑林:变分自编码器(一):原来是这么一回事
  2. 李宏毅老师 Machine Learning (2017,秋,台湾大学) 国语

你可能感兴趣的:(机器学习,机器学习,生成模型,VAE,变分自编码器)