Contrastive Divergence:一种结合变分推断与MCMC的方法

本文是对A Contrastive Divergence for Combining Variational Inference and MCMC的笔记整理。

Introduction

这篇文章是将VI和MCMC结合的一篇很有意思的文章。他的基本思想沿用了RBM中的Contrastive Divergence方法。他通过采用MCMC的方法来对变分函数q进行抽样从而得到更加准确的变分函数,然后再以此更新变分函数的参数,以次迭代直到收敛。

我们先回顾一下变分推断,更详细的介绍可以看我之前的两篇文章

  1. 带你理解EM算法
  2. 变分自编码器(VAE)

简单来说,当存在隐变量时,我们只能得到数据的marginal likelihood,此时,模型的优化就变得很困难,为此,我们希望能够找到一个分布 q θ ( z ) \displaystyle q_{\theta } (z) qθ(z)来近似真实的分布 p ( z ∣ x ) \displaystyle p (z|x) p(zx),于是我们发现最小化散度 K L ( q θ ( z ) ∥ p ( z ∣ x ) ) \displaystyle KL( q_{\theta } (z)\| p (z|x)) KL(qθ(z)p(zx))等价于最大化下界ELBO的函数:

L  standard  ( θ ) = E q θ ( z ) [ f θ ( z ) ] = − D K L ( q θ ( z ) ∥ p ( z ) ) + E q θ ( z ) [ log ⁡ p ( x ∣ z ) ] \mathcal{L}_{\text{ standard }} (\theta )=\mathbb{E}_{q_{\theta } (z)}[ f_{\theta } (z)] =-D_{KL}( q_{\boldsymbol{\theta }}(\mathbf{z}) \| p (\mathbf{z} )) +\mathbb{E}_{q_{\boldsymbol{\theta }}(\mathbf{z})}[\log p(\mathbf{x} |\mathbf{z})] L standard (θ)=Eqθ(z)[fθ(z)]=DKL(qθ(z)p(z))+Eqθ(z)[logp(xz)]

其中

f θ ( z ) ≜ log ⁡ p ( x , z ) − log ⁡ q θ ( z ) . f_{\theta } (z)\triangleq \log p(x,z)-\log q_{\theta } (z). fθ(z)logp(x,z)logqθ(z).

但是变分的缺陷在于,那个近似的分布q是无法保证近似的效果的,很有可能近似地很糟糕,与此同时,MCMC只要马尔科夫链长度足够,我们是一定能够恢复出真实的后验概率的,但是MCMC的速度又太慢了。所以,我们能不能将VI学到的 q θ ( z ) \displaystyle q_{\theta } (z) qθ(z)作为MCMC的初始值,用MCMC迭代几轮,让他长得更像 p ( z ∣ x ) \displaystyle p (z|x) p(zx),然后再回去更新近似分布的参数 q θ ( z ) \displaystyle q_{\theta } (z) qθ(z),如此迭代地来即解决了变分近似不准的问题,又解决了MCMC速度慢的问题。

方法

我们用
q θ ( t ) ( z ) = ∫ Q ( t ) ( z ∣ z 0 ) q θ ( z 0 ) d z 0 q^{(t)}_{\theta } (z)=\int Q^{(t)} (z|z_{0} )q_{\theta } (z_{0} )dz_{0} qθ(t)(z)=Q(t)(zz0)qθ(z0)dz0

来表示使用 z 0 z_0 z0作为初始值,并经过t次 Q ( t ) ( z ∣ z 0 ) \displaystyle Q^{(t)} (z|z_{0} ) Q(t)(zz0)的转移而达到的分布记为 q θ ( t ) ( z ) \displaystyle q^{(t)}_{\theta } (z) qθ(t)(z)。注意这个分布的解析解是无法计算的,他只是可以表示成这样而已。用这个新的q,我们的变分下界可以改进成:

L  improved  ( θ ) = E q θ ( t ) ( z ) [ log ⁡ p ( x , z ) − log ⁡ q θ ( t ) ( z ) ] \mathcal{L}_{\text{ improved }} (\theta )=\mathbb{E}_{q^{(t)}_{\theta } (z)}\left[\log p(x,z)-\log q^{(t)}_{\theta } (z)\right] L improved (θ)=Eqθ(t)(z)[logp(x,z)logqθ(t)(z)]

但问题是,现在 θ \displaystyle \theta θ没法求导。一般的求导方法有两种,一种是重参数化,另一种是则是reinforce,像这样的:

∇ ϕ E q ϕ ( z ) [ f ( z ) ] = E q ϕ ( z ) [ f ( z ) ∇ q ϕ ( z ) log ⁡ q ϕ ( z ) ] ≃ 1 L ∑ l = 1 L f ( z ) ∇ q ϕ ( z ( l ) ) log ⁡ q ϕ ( z ( l ) ) \nabla _{\phi }\mathbb{E}_{q_{\phi } (\mathbf{z} )} [f(\mathbf{z} )]=\mathbb{E}_{q_{\phi } (\mathbf{z} )}[ f(\mathbf{z} )\nabla _{q_{\phi } (\mathbf{z} )}\log q_{\phi } (\mathbf{z} )] \simeq \frac{1}{L}\sum ^{L}_{l=1} f(\mathbf{z} )\nabla _{q_{\phi }\left(\mathbf{z}^{(l)}\right)}\log q_{\phi }\left(\mathbf{z}^{(l)}\right) ϕEqϕ(z)[f(z)]=Eqϕ(z)[f(z)qϕ(z)logqϕ(z)]L1l=1Lf(z)qϕ(z(l))logqϕ(z(l))

然而现在,因为 q θ ( t ) ( z ) \displaystyle q^{(t)}_{\theta } (z) qθ(t)(z)的实际分布我们不知道,于是他概率密度的值 log ⁡ q θ ( t ) ( z ) \displaystyle \log q^{(t)}_{\theta } (z) logqθ(t)(z)是没法算的,那自然梯度也没法算,怎么办呢?能不能找到一个新的“下界”,使得下界永远是大于等于0的,且下界为0时恰好有 q θ ( z ) = p ( z ∣ x ) \displaystyle q_{\theta } (z)=p( z|x) qθ(z)=p(zx),更重要的是,不需要计算 log ⁡ q θ ( t ) ( z ) \displaystyle \log q^{( t)}_{\theta }(\mathbf{z}) logqθ(t)(z)。为此,该文发现一个三角不等式

KL ⁡ ( q θ ( z ) ∥ p ( z ∣ x ) ) + K L ( q θ ( t ) ( z ) ∥ q θ ( z ) ) ≥ K L ( q θ ( t ) ( z ) ∥ p ( z ∣ x ) ) ⟹ KL ⁡ ( q θ ( z ) ∥ p ( z ∣ x ) ) − K L ( q θ ( t ) ( z ) ∥ p ( z ∣ x ) ) ⎵ L diff ( θ ) + K L ( q θ ( t ) ( z ) ∥ q θ ( z ) ) ≥ 0 \operatorname{KL}( q_{\theta } (z)\| p(z|x)) +\mathrm{KL}\left( q^{(t)}_{\theta } (z)\| q_{\theta } (z)\right) \geq \mathrm{KL}\left( q^{(t)}_{\theta } (z)\| p(z|x)\right)\\ \Longrightarrow \underbrace{\operatorname{KL}( q_{\theta } (z)\| p(z|x)) -\mathrm{KL}\left( q^{(t)}_{\theta } (z)\| p(z|x)\right)}_{\mathcal{L}_{\text{diff}} (\theta )} +\mathrm{KL}\left( q^{(t)}_{\theta } (z)\| q_{\theta } (z)\right) \geq 0 KL(qθ(z)p(zx))+KL(qθ(t)(z)qθ(z))KL(qθ(t)(z)p(zx))Ldiff(θ) KL(qθ(z)p(zx))KL(qθ(t)(z)p(zx))+KL(qθ(t)(z)qθ(z))0

注意,这个三角不等式成立的原因是因为 q θ ( t ) ( z ) \displaystyle q^{(t)}_{\theta } (z) qθ(t)(z)一定更接近 p ( z ∣ x ) \displaystyle p(z|x) p(zx)。于是我们定义

L V C D ( θ ) ≜ L d i f f ( θ ) + K L ( q θ ( t ) ( z ) ∥ q θ ( z ) ) \mathcal{L}_{\mathrm{VCD}} (\theta )\triangleq \mathcal{L}_{\mathrm{diff}} (\theta )+\mathrm{KL}\left( q^{(t)}_{\theta } (z)\| q_{\theta } (z)\right) LVCD(θ)Ldiff(θ)+KL(qθ(t)(z)qθ(z))

这个新的目标函数很有意思,当 q θ ( z ) = p ( z ∣ x ) \displaystyle q_{\theta } (z)=p(z|x) qθ(z)=p(zx)时,一定有 q θ ( t ) ( z ) = p ( z ∣ x ) \displaystyle q^{(t)}_{\theta } (z)=p(z|x) qθ(t)(z)=p(zx)于是 L V C D = 0 \mathcal{L}_{\mathrm{VCD}}=0 LVCD=0,而且根据三角不等式,他一定是大于0的,所以我们完全可以用这个来代替下界的目标函数,而且关键地方在于,他不需要计算 log ⁡ q θ ( t ) ( z ) \displaystyle \log q^{( t)}_{\theta }(\mathbf{z}) logqθ(t)(z),看如下推导:

L V C D ( θ ) = KL ⁡ ( q θ ( z ) ∥ p ( z ∣ x ) ) − K L ( q θ ( t ) ( z ) ∥ p ( z ∣ x ) ) + K L ( q θ ( t ) ( z ) ∥ q θ ( z ) ) = E q θ ( z ) [ log ⁡ q θ ( z ) p ( z ∣ x ) ] − E q θ ( t ) ( z ) [ log ⁡ q θ ( t ) ( z ) p ( z ∣ x ) ] + E q θ ( t ) ( z ) [ log ⁡ q θ ( t ) ( z ) q θ ( z ) ] = E q θ ( z ) [ log ⁡ q θ ( z ) p ( z ∣ x ) ] − E q θ ( t ) ( z ) [ log ⁡ q θ ( z ) p ( z ∣ x ) ] = E q θ ( z ) [ log ⁡ q θ ( z ) p ( x ) p ( z , x ) ] − E q θ ( t ) ( z ) [ log ⁡ q θ ( z ) p ( x ) p ( z , x ) ] = E q θ ( z ) [ log ⁡ q θ ( z ) p ( z , x ) ] − E q θ ( t ) ( z ) [ log ⁡ q θ ( z ) p ( z , x ) ] = − E q θ ( z ) [ f θ ( z ) ] + E q θ ( t ) ( z ) [ f θ ( z ) ] \begin{aligned} \mathcal{L}_{\mathrm{VCD}} (\theta )= & \operatorname{KL}( q_{\theta } (z)\| p(z|x)) -\mathrm{KL}\left( q^{(t)}_{\theta } (z)\| p(z|x)\right) +\mathrm{KL}\left( q^{(t)}_{\theta } (z)\| q_{\theta } (z)\right)\\ = & \mathbb{E}_{q_{\theta } (z)}\left[\log\frac{q_{\theta } (z)}{p(z|x)}\right] -\mathbb{E}_{q^{(t)}_{\theta } (z)}\left[\log\frac{q^{(t)}_{\theta } (z)}{p(z|x)}\right] +\mathbb{E}_{q^{(t)}_{\theta } (z)}\left[\log\frac{q^{(t)}_{\theta } (z)}{q_{\theta } (z)}\right]\\ = & \mathbb{E}_{q_{\theta } (z)}\left[\log\frac{q_{\theta } (z)}{p(z|x)}\right] -\mathbb{E}_{q^{(t)}_{\theta } (z)}\left[\log\frac{q_{\theta } (z)}{p(z|x)}\right]\\ = & \mathbb{E}_{q_{\theta } (z)}\left[\log\frac{q_{\theta } (z)p( x)}{p(z,x)}\right] -\mathbb{E}_{q^{(t)}_{\theta } (z)}\left[\log\frac{q_{\theta } (z)p( x)}{p(z,x)}\right]\\ = & \mathbb{E}_{q_{\theta } (z)}\left[\log\frac{q_{\theta } (z)}{p(z,x)}\right] -\mathbb{E}_{q^{(t)}_{\theta } (z)}\left[\log\frac{q_{\theta } (z)}{p(z,x)}\right]\\ = & -\mathbb{E}_{q_{\theta } (z)}[ f_{\theta } (z)] +\mathbb{E}_{q^{(t)}_{\theta } (z)}[ f_{\theta } (z)] \end{aligned} LVCD(θ)======KL(qθ(z)p(zx))KL(qθ(t)(z)p(zx))+KL(qθ(t)(z)qθ(z))Eqθ(z)[logp(zx)qθ(z)]Eqθ(t)(z)[logp(zx)qθ(t)(z)]+Eqθ(t)(z)[logqθ(z)qθ(t)(z)]Eqθ(z)[logp(zx)qθ(z)]Eqθ(t)(z)[logp(zx)qθ(z)]Eqθ(z)[logp(z,x)qθ(z)p(x)]Eqθ(t)(z)[logp(z,x)qθ(z)p(x)]Eqθ(z)[logp(z,x)qθ(z)]Eqθ(t)(z)[logp(z,x)qθ(z)]Eqθ(z)[fθ(z)]+Eqθ(t)(z)[fθ(z)]

神奇的事情发生了,在期望里面最讨厌的 log ⁡ q θ ( t ) ( z ) \displaystyle \log q^{(t)}_{\theta } (z) logqθ(t)(z)被消去了,接下来事情好办了,对于梯度,第一项的梯度我们可以用传统的方法解决,比如重参数化,对于第二项的梯度,可以这样算:

∇ θ E q θ ( t ) ( z ) [ f θ ( z ) ] = ∫ q θ ( t ) ( z ) × ∇ θ f θ ( z ) d z + ∫ ∇ θ q θ ( t ) ( z ) × f θ ( z ) d z = − ∫ q θ ( t ) ( z ) × ∇ θ log ⁡ q θ ( z 0 ) d z + ∫ ( ∇ θ ∫ Q ( t ) ( z ∣ z 0 ) q θ ( z 0 ) d z 0 ) × f θ ( z ) d z = − E q θ ( t ) ( z ) [ ∇ θ log ⁡ q θ ( z ) ] + ∫ ∫ Q ( t ) ( z ∣ z 0 ) q θ ( z 0 ) ∇ θ log ⁡ q θ ( z 0 ) d z 0 × f θ ( z ) d z = − E q θ ( t ) ( z ) [ ∇ θ log ⁡ q θ ( z ) ] + E q θ ( z 0 ) [ E Q ( t ) ( z ∣ z 0 ) [ f θ ( z ) ] ∇ θ log ⁡ q θ ( z 0 ) ] \begin{aligned} \nabla _{\theta }\mathbb{E}_{q^{(t)}_{\theta } (z)}[ f_{\theta } (z)] & =\int q^{(t)}_{\theta } (z)\times \nabla _{\theta } f_{\theta } (z)dz+\int \nabla _{\theta } q^{(t)}_{\theta } (z)\times f_{\theta } (z)dz\\ & =-\int q^{(t)}_{\theta } (z)\times \nabla _{\theta }\log q_{\theta }( z_{0}) dz+\int \left( \nabla _{\theta }\int Q^{(t)} (z|z_{0} )q_{\theta } (z_{0} )dz_{0}\right) \times f_{\theta } (z)dz\\ & =-\mathbb{E}_{q^{(t)}_{\theta } (z)}[ \nabla _{\theta }\log q_{\theta } (z)] +\int \int Q^{(t)} (z|z_{0} )q_{\theta }( z_{0}) \nabla _{\theta }\log q_{\theta } (z_{0} )dz_{0} \times f_{\theta } (z)dz\\ & =-\mathbb{E}_{q^{(t)}_{\theta } (z)}[ \nabla _{\theta }\log q_{\theta } (z)] +\mathbb{E}_{q_{\theta }( z_{0})}[\mathbb{E}_{Q^{(t)}( z|z_{0})}[ f_{\theta } (z)] \nabla _{\theta }\log q_{\theta }( z_{0})] \end{aligned} θEqθ(t)(z)[fθ(z)]=qθ(t)(z)×θfθ(z)dz+θqθ(t)(z)×fθ(z)dz=qθ(t)(z)×θlogqθ(z0)dz+(θQ(t)(zz0)qθ(z0)dz0)×fθ(z)dz=Eqθ(t)(z)[θlogqθ(z)]+Q(t)(zz0)qθ(z0)θlogqθ(z0)dz0×fθ(z)dz=Eqθ(t)(z)[θlogqθ(z)]+Eqθ(z0)[EQ(t)(zz0)[fθ(z)]θlogqθ(z0)]

第二个等于号,首先第一项因为 p ( x , z ) \displaystyle p( x,z) p(x,z)与参数 θ \displaystyle \theta θ无关,所以 ∇ θ f θ ( z ) = ∇ θ log ⁡ q θ ( z 0 ) \displaystyle \nabla _{\theta } f_{\theta } (z)=\nabla _{\theta }\log q_{\theta }( z_{0}) θfθ(z)=θlogqθ(z0),针对第二项,根据定义 q θ ( t ) ( z ) = ∫ Q ( t ) ( z ∣ z 0 ) q θ ( z 0 ) d z 0 \displaystyle q^{(t)}_{\theta } (z)=\int Q^{(t)} (z|z_{0} )q_{\theta } (z_{0} )dz_{0} qθ(t)(z)=Q(t)(zz0)qθ(z0)dz0代进去得到。第三个等于号是因为

∇ θ q θ ( t ) ( z ) = ∇ θ ∫ Q ( t ) ( z ∣ z 0 ) q θ ( z 0 ) d z 0 = ∫ Q ( t ) ( z ∣ z 0 ) q θ ( z 0 ) ∇ θ log ⁡ q θ ( z 0 ) d z 0 \begin{array}{ c l } \nabla _{\theta } q^{(t)}_{\theta } (z) & =\nabla _{\theta }\int Q^{(t)}( z|z_{0}) q_{\theta }( z_{0}) dz_{0}\\ & =\int Q^{(t)}( z|z_{0}) q_{\theta }( z_{0}) \nabla _{\theta }\log q_{\theta }( z_{0}) dz_{0} \end{array} θqθ(t)(z)=θQ(t)(zz0)qθ(z0)dz0=Q(t)(zz0)qθ(z0)θlogqθ(z0)dz0

于是这个梯度公式是完全可以用蒙特卡洛计算的,即我们 q θ ( t ) ( z ) \displaystyle q^{(t)}_{\theta } (z) qθ(t)(z)的样本可以用MCMC得到,只需先采样 z o ∼ q θ ( z ) \displaystyle z_{o} \sim q_{\theta }( z) zoqθ(z),然后跑t次MCMC,得到 z ∼ Q ( t ) ( z ∣ z 0 ) \displaystyle z\sim Q^{(t)}( z|z_{0}) zQ(t)(zz0)就可以了。

然而这样做得话, θ \displaystyle \theta θ的梯度方差显然比重参数化技术来的要大,又因为,随着t增加,参数 θ \displaystyle \theta θ一定会跟采样的分布越来越独立,于是可以设置一个递减的参数C来解决这个问题。完整的算法如下:

Contrastive Divergence:一种结合变分推断与MCMC的方法_第1张图片

参考资料

A Contrastive Divergence for Combining Variational Inference and MCMC

你可能感兴趣的:(机器学习,人工智能)