Auto-Encoding Variational Bayes 算法主要是针对连续隐含变量的统计推断模型,因为统计推断模型中常常会遇到后验概率分布形式比较难以获得以及样本数据过大等困难,所以作者便提出了用变分推断的方法,结合自动编码器来对以前的算法进行改进,其改进方向主要是证明了Stochastic gradient methods(类似于一种随机梯度下降算法)可以用来对变分下界进行最优估计;并且可以利用参数下界的最优估计来对模型的后验概率分布进行推断。
本文将跟随论文作者的思路,从变分方法的背景谈起,对比不同的算法,重点关注我们AEVB算法当中的SGVB(Stochastic Gradient Variational Bayes)估计方法。
根据作者的说法,AEVB算法对多种dataset都非常适用,为了解释方便,本文对样本集做出了如下的假设:
样本集的 X = { x ( i ) } i = 1 N X=\lbrace x^{(i)}\rbrace_{i=1}^N X={x(i)}i=1N是N个独立同分布的离散或者连续的样本,他们是通过某种随机过程产生的,我们假设他们的产生是源于一种我们目前无法观测到的随机变量z,z由如下过程产生:
(1)z的产生源于某种概率分布 p θ ∗ ( z ) p_{\theta^*}(z) pθ∗(z)。
(2)x的产生源自于某种条件概率分布 p θ ∗ ( x ∣ z ) p_{\theta^*}(x|z) pθ∗(x∣z)。
并且我们假设对于 θ \theta θ和z p θ ( z ) p_{\theta}(z) pθ(z)和 p θ ( x ∣ z ) p_{\theta}(x|z) pθ(x∣z)都是几乎处处可导。但是实际上对于 θ \theta θ和z的的情况,我们时常都是无法获得的。
论文没有对其边缘分布或者后验概率分布做任何限制和假设,论文致力于提出一种比较通用的算法对
(1) p θ ( x ) = ∫ p θ ( z ) p θ ( x ∣ z ) p_{\theta}(x)=\int p_{\theta}(z)p_{\theta}(x|z) pθ(x)=∫pθ(z)pθ(x∣z)比较难以计算或者是后验概率 p θ ( z ∣ x ) = p θ ( x ∣ z ) p θ ( z ) p θ ( x ) p_{\theta}(z|x)=\frac {p_{\theta}(x|z)p_{\theta}(z)}{p_{\theta}(x)} pθ(z∣x)=pθ(x)pθ(x∣z)pθ(z)比较难以计算。这会导致似然估计的方法(likehood)以及EM算法,还有一些传统的VB算法全部无效。
(2)样本量过大,这会导致类似于Monte Carlo 算法等迭代速度过慢。
算法需要解决以下三个相关问题:
MLE和MAP算法都是对参数进行估计的一种算法。
MLE算法是对于似然函数进行一个估计:
θ M L E = a r g max θ P ( X ∣ θ ) = a r g max θ ∏ i P ( x i ∣ θ ) \theta_{MLE}=arg \max_{\theta} P(X|\theta) \\ =arg \max_{\theta} \prod_i P(x_i|\theta) θMLE=argθmaxP(X∣θ)=argθmaxi∏P(xi∣θ)
而MAP算法是对于其贝叶斯概率做一个估计
P ( θ ∣ X ) = P ( X ∣ θ ) P ( θ ) P ( X ) ∝ P ( X ∣ θ ) P ( θ ) P(\theta|X)=\frac{P(X|\theta)P(\theta)}{P(X)}\\ \propto P(X|\theta)P(\theta) P(θ∣X)=P(X)P(X∣θ)P(θ)∝P(X∣θ)P(θ)
所以 θ M A P = a r g max θ P ( X ∣ θ ) P ( θ ) = a r g max θ ∑ i log P ( x i ∣ θ ) + log P ( θ ) \theta_{MAP}=arg \max_{\theta} P(X|\theta)P(\theta)\\ = arg \max_{\theta} \sum_i \log P(x_i|\theta)+\log P(\theta) θMAP=argθmaxP(X∣θ)P(θ)=argθmaxi∑logP(xi∣θ)+logP(θ)
此方法出现在论文的2.1节当中,mean-field variational inference,它的核心思想也是用一个分布来近似得到 ϕ \phi ϕ的估计,与我们论文所述方法不同的是,此方法希望直接利用对KL散度 K L ( Q ∣ ∣ P ) = ∑ z ∈ Z q ϕ ( z ∣ x ) log q ϕ ( z ∣ x ) p ( x ) p ( z , x ) KL(Q||P)=\sum_{z\in Z}q_{\phi}(z|x)\log \frac{q_{\phi}(z|x)p(x)}{p(z,x)} KL(Q∣∣P)=∑z∈Zqϕ(z∣x)logp(z,x)qϕ(z∣x)p(x)的最优化获得结果。对于KL散度的参数最优化问题我们之前在变分推断读书笔记当中已经比较详细地讨论过,这里不再赘述了,我们本文的目标,是能够整体的一起推理出 ϕ \phi ϕ和 θ \theta θ以及他们之间的联系,这就是本文方法和mean-field variational inference方法的不同与改进。
针对变分问题的推理,我们在之前已经有过比较详细的叙述,这里提出了
log p θ ( x ( i ) ) = D K L ( q ϕ ( z ∣ x ) ∣ ∣ p θ ( z ∣ x ( i ) ) + ζ ( θ , ϕ ; x ( i ) ) \log p_{\theta}(x^{(i)})=D_{KL}(q_{\phi}(z|x)||p_{\theta}(z|x^{(i)})+\zeta (\theta,\phi;x^{(i)}) logpθ(x(i))=DKL(qϕ(z∣x)∣∣pθ(z∣x(i))+ζ(θ,ϕ;x(i))
根据 ζ ( θ , ϕ ; x ( i ) ) \zeta (\theta,\phi;x^{(i)}) ζ(θ,ϕ;x(i))的分解定义,其分解为 ζ ( θ , ϕ ; x ( i ) ) = ∫ q ϕ ( z ∣ x ) ln p θ ( z ∣ x ( i ) ) q ϕ ( z ∣ x ) d Z \zeta (\theta,\phi;x^{(i)})=\int q_{\phi}(z|x)\ln \frac{p_{\theta}(z|x^{(i)})}{q_{\phi}(z|x)} dZ ζ(θ,ϕ;x(i))=∫qϕ(z∣x)lnqϕ(z∣x)pθ(z∣x(i))dZ 可以看出来其是关于 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(z∣x)的期望 E q ϕ ( z ∣ x ) [ − log q ϕ ( z ∣ x ) + log p θ ( z ∣ x ( i ) ) ] E_{q_{\phi}(z|x)}[-\log q_{\phi}(z|x)+\log p_{\theta}(z|x^{(i)})] Eqϕ(z∣x)[−logqϕ(z∣x)+logpθ(z∣x(i))]
又根据KL散度的非负性,所以我们可以得到对于每个 p θ ( x ( i ) ) p_{\theta}(x^{(i)}) pθ(x(i))的一个下界 log p θ ( x ( i ) ) ≥ ζ ( θ , ϕ ; x ( i ) ) = E q ϕ ( z ∣ x ) [ − log q ϕ ( z ∣ x ) + log p θ ( z ∣ x ( i ) ) ] (1) \log p_{\theta}(x^{(i)})\ge \zeta (\theta,\phi;x^{(i)})=E_{q_{\phi}(z|x)}[-\log q_{\phi}(z|x)+\log p_{\theta}(z|x^{(i)})]\tag{1} logpθ(x(i))≥ζ(θ,ϕ;x(i))=Eqϕ(z∣x)[−logqϕ(z∣x)+logpθ(z∣x(i))](1)
同时根据 ζ ( θ , ϕ ; x ( i ) ) \zeta (\theta,\phi;x^{(i)}) ζ(θ,ϕ;x(i))本身的分解定义:
ζ ( θ , ϕ ; x ( i ) ) = ∫ q ϕ ( z ∣ x ) ( log p θ ( x ( i ) ∣ z ) + log p θ ( z ) − log q ϕ ( z ∣ x ( i ) ) ) d z \zeta (\theta,\phi;x^{(i)})=\int q_{\phi}(z|x)(\log p_{\theta}(x^{(i)}|z)+\log p_{\theta}(z)-\log q_{\phi}(z|x^{(i)}))dz ζ(θ,ϕ;x(i))=∫qϕ(z∣x)(logpθ(x(i)∣z)+logpθ(z)−logqϕ(z∣x(i)))dz
可以得到关于 ζ ( θ , ϕ ; x ( i ) ) \zeta (\theta,\phi;x^{(i)}) ζ(θ,ϕ;x(i))的另一个表述:
ζ ( θ , ϕ ; x ( i ) ) = − D K L ( q ϕ ( z ∣ x ( i ) ) ∣ ∣ p θ ( z ) ) + E q ϕ ( z ∣ x ( i ) ) [ log p θ ( x ( i ) ∣ z ) ] (2) \zeta (\theta,\phi;x^{(i)})=-D_{KL}(q_{\phi}(z|x^{(i)})||p_{\theta}(z))+E_{q_{\phi}(z|x^{(i)})}[\log p_{\theta}(x^{(i)}|z)] \tag{2} ζ(θ,ϕ;x(i))=−DKL(qϕ(z∣x(i))∣∣pθ(z))+Eqϕ(z∣x(i))[logpθ(x(i)∣z)](2)
这里的推导和假设建立在 p θ ( z ) = N ( 0 , I ) p_{\theta}(z)=N(0,I) pθ(z)=N(0,I)以及 q ϕ ( z ∣ x ( i ) ) q_{\phi}(z|x^{(i)}) qϕ(z∣x(i))是高斯分布的基础上。
因为 ∫ q θ ( z ) log p ( z ) = ∫ N ( z ; μ , σ 2 ) log N ( z ; 0 , I ) d z = − J 2 log ( 2 π ) − 1 2 ∑ j = 1 J ( μ j 2 + σ j 2 ) \int q_{\theta}(z)\log p(z)=\int N(z;\mu,\sigma^2)\log N(z;0,I)dz\\=-\frac{J}{2}\log(2\pi)-\frac{1}{2}\sum_{j=1}^J(\mu_j^2+\sigma_j^2) ∫qθ(z)logp(z)=∫N(z;μ,σ2)logN(z;0,I)dz=−2Jlog(2π)−21j=1∑J(μj2+σj2)
∫ q θ ( z ) log q θ ( z ) d z = ∫ N ( z ; μ , s i g m a 2 ) d z = − J 2 log ( 2 π ) − 1 2 ∑ j = 1 J ( 1 + log σ j 2 ) \int q_{\theta}(z)\log q_{\theta}(z)dz=\int N(z;\mu,sigma^2)dz\\=-\frac{J}{2}\log (2\pi)-\frac{1}{2}\sum_{j=1}^J(1+\log \sigma_j^2) ∫qθ(z)logqθ(z)dz=∫N(z;μ,sigma2)dz=−2Jlog(2π)−21j=1∑J(1+logσj2)
所以有:
− D K L ( ( q ϕ ( z ) ∣ ∣ p θ ( z ) ) = ∫ q θ ( z ) ( log p θ ( z ) − log q θ ( z ) ) d z = 1 2 ∑ j = 1 J ( 1 + log ( ( σ j 2 ) − ( μ j ) 2 − ( σ j ) 2 ) -D_{KL}((q_{\phi}(z)||p_{\theta}(z))=\int q_{\theta}(z)(\log p_{\theta}(z)-\log q_{\theta}(z))dz\\=\frac{1}{2}\sum_{j=1}^J(1+\log((\sigma_j^2)-(\mu_j)^2-(\sigma_j)^2) −DKL((qϕ(z)∣∣pθ(z))=∫qθ(z)(logpθ(z)−logqθ(z))dz=21j=1∑J(1+log((σj2)−(μj)2−(σj)2)
算法用了一种新的表述 z ~ = g ϕ ( ϵ , x ) \tilde{z}=g_{\phi}(\epsilon,x) z~=gϕ(ϵ,x),其中 ϵ ∼ p ( ϵ ) \epsilon\sim p(\epsilon) ϵ∼p(ϵ)
使用这样的新的表述的好处作者在2.4节进行了概述,这样表述就有:
q ϕ ( z ∣ x ) ∏ i d z i = p ( ϵ ) ∏ i d ϵ i q_{\phi}(z|x) \prod_idz_i=p(\epsilon)\prod_id\epsilon_i qϕ(z∣x)i∏dzi=p(ϵ)i∏dϵi
所以可以有:
∫ q ϕ ( z ∣ x ) f ( z ) d z = ∫ p ( ϵ ) f ( z ) d ϵ = ∫ p ( ϵ ) f ( g ϕ ( ϵ , x ) ) d ϵ \int q_{\phi}(z|x)f(z)dz=\int p(\epsilon)f(z)d{\epsilon}=\int p(\epsilon)f(g_{\phi}(\epsilon,x))d{\epsilon} ∫qϕ(z∣x)f(z)dz=∫p(ϵ)f(z)dϵ=∫p(ϵ)f(gϕ(ϵ,x))dϵ
利用这个写法我们可以得到一个估计:
∫ q ϕ ( z ∣ x ) f ( z ) d z ≃ 1 L ∑ l = 1 L f ( g ϕ ( x , ϵ ( l ) ) ) \int q_{\phi}(z|x)f(z)dz\simeq \frac{1}{L}\sum_{l=1}^{L}f(g_{\phi}(x,\epsilon^{(l)})) ∫qϕ(z∣x)f(z)dz≃L1l=1∑Lf(gϕ(x,ϵ(l)))
这里的话主要是用于对(1)式进行优化估计,有
E q ϕ ( z ∣ x ( i ) ) [ f ( z ) ] = E p ( ϵ ) [ f ( g ϕ ( ϵ , x ( i ) ) ) ] ≃ 1 L ∑ l = 1 L f ( g ϕ ( ϵ ( l ) , x ( i ) ) ) E_{q_{\phi}(z|x^{(i)})}[f(z)] =E_{p(\epsilon)}[f(g_{\phi}(\epsilon,x^{(i)}))]\simeq \frac{1}{L}\sum_{l=1}^Lf(g_{\phi}(\epsilon^{(l)},x^{(i)})) Eqϕ(z∣x(i))[f(z)]=Ep(ϵ)[f(gϕ(ϵ,x(i)))]≃L1l=1∑Lf(gϕ(ϵ(l),x(i)))其中 ϵ ∼ p ( ϵ ) \epsilon\sim p(\epsilon) ϵ∼p(ϵ)
在(1)式当中,我们可以对比发现这里的 f ( z ) f(z) f(z)相当于(1)当中的 − log q ϕ ( z ∣ x ) + log p θ ( z ∣ x ( i ) ) -\log q_{\phi}(z|x)+\log p_{\theta}(z|x^{(i)}) −logqϕ(z∣x)+logpθ(z∣x(i))
把对应的结果代入即可得到:
ζ ~ A = 1 L ∑ l = 1 L log p θ ( x ( i ) , z ( i , l ) ) − log q ϕ ( z ( i , l ) ∣ x ( i ) ) (3) \tilde{\zeta}^A=\frac{1}{L}\sum_{l=1}^L\log p_{\theta}(x^{(i)},z^{(i,l)})-\log q_{\phi}(z^{(i,l)}|x^{(i)})\tag{3} ζ~A=L1l=1∑Llogpθ(x(i),z(i,l))−logqϕ(z(i,l)∣x(i))(3)
上述是通过(1)式得到的一个下界的估计,实际上我们还可以通过对(2)的推导,同样也能得到一个比较好的结果。
ζ ~ B = − D K L ( q ϕ ( z ∣ x ( i ) ) ∣ ∣ p θ ( z ) ) + 1 L ∑ l = 1 L log ( p θ ( x ( i ) , z ( i , l ) ) ) ) (4) \tilde{\zeta}^B=-D_{KL}(q_{\phi}(z|x^{(i)})||p_{\theta}(z))+\frac{1}{L}\sum_{l=1}^L\log (p_{\theta}(x^{(i)},z^{(i,l)})))\tag{4} ζ~B=−DKL(qϕ(z∣x(i))∣∣pθ(z))+L1l=1∑Llog(pθ(x(i),z(i,l))))(4)
这种做法通常比(3)式得到的误差会小一些。
核心算法的步骤如下:
该算法的好处还有一点是有如下的近似式:
ζ ~ M ( θ , ϕ ; x M ) = N M ∑ i = 1 M ζ ( θ , ϕ ; x ( i ) ) \tilde{\zeta}^M (\theta,\phi;x^{M})=\frac{N}{M}\sum_{i=1}^M\zeta (\theta,\phi;x^{(i)}) ζ~M(θ,ϕ;xM)=MNi=1∑Mζ(θ,ϕ;x(i))
我们可以根据上式,选择一个样本集,来对于全局进行估计。
变分自动编码器的应用选择了一个高斯分布来模拟z的产生,即z服从一个 N ( z ; 0 , I ) N(z;0,I) N(z;0,I)。选择 p θ ( x ∣ z ) p_{\theta}(x|z) pθ(x∣z)是一个混合高斯分布或者伯努利分布。而因为实际的 p θ ( z ∣ x ) p_{\theta}(z|x) pθ(z∣x)很难获得,所以选择了 q θ ( x ∣ z ) q_{\theta}(x|z) qθ(x∣z)来对其进行近似,在例子当中我们对其选择了 log q ϕ ( z ∣ x ( i ) ) = log N ( z ; μ ( i ) , σ ( i ) I ) \log q_{\phi}(z|x^{(i)})=\log N(z;\mu^{(i)},\sigma^{(i)}I) logqϕ(z∣x(i))=logN(z;μ(i),σ(i)I),文章用了全连接网络来对于变分自动编码器进行模拟,其网络设置如下:
这个例子中在对 ζ ( θ , ϕ ; x ( i ) ) \zeta (\theta,\phi;x^{(i)}) ζ(θ,ϕ;x(i))做估计的时候使用了(4)的算法,从之前的数学推导当中可以看出(4)的算法是相对比较好求的。
本文的代码可以在github上面找到,链接如下:https://github.com/hwalsuklee/tensorflow-mnist-VAE
实际上本算法的核心是通过SGVB方法来对神经网络当中的loss函数进行计算,便于去更新参数值。
本算法可以作为生成器,重现输入的样本,效果图如下:
本文的主要创新点在于,直接采取了一种 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(z∣x)的方式近似 p θ ( z ∣ x ) p_{\theta}(z|x) pθ(z∣x),而且 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(z∣x)并不是直接去通过最优化其KL散度的方式来获得,因为直接优化其KL散度可能会出现梯度无法求解的情况。本文也不是通过蒙特卡洛算法来实现的,因为蒙特卡洛算法过于粗暴,其有一项类似于 f ( z ) ▽ q ϕ ( z ( l ) ) log q ϕ ( z ( l ) ) f(z)\bigtriangledown_{q_{\phi}(z^{(l)})}\log q_{\phi}(z^{(l)}) f(z)▽qϕ(z(l))logqϕ(z(l))的项,其方差(Hessian矩阵?)的计算复杂度极大,实际情况下效果非常差,所以我们VAE方法利用一种reparameterization的方式,巧妙地解决了这一问题。这应该是本文的核心,也是对学术界的巨大贡献之处。