从EM算法到变分推断(变分贝叶斯推断)

在含有隐变量 Z Z Z的图模型中,在用EM迭代求解时,需要计算一个后验分布:
p ( Z ∣ Y , θ o l d ) p(Z|Y,\theta_{old}) p(ZY,θold),如果Z是离散变量,且分布比较简单则可以直接求出该后验分布的解析解,比如在三硬币模型中和高斯混合模型中。然而如果 Z Z Z服从一些复杂离散分布或者连续分布,后验分布 p ( Z ∣ Y , θ o l d ) p(Z|Y,\theta_{old}) p(ZY,θold)变得异常难求。

此时采用近似推断,有两类想法求该后验分布,一是随机化方法近似,比如采样法,二是确定性方法近似,比如变分推断求解LDA。

在变分推断中,经常用一个简单分布 q ( Z ; φ ) q(Z;\varphi) q(Z;φ)来近似复杂分布 p ( Z ∣ Y , θ o l d ) p(Z|Y,\theta_{old}) p(ZY,θold),从而得到一个局部最优,但是有确定解的近似后验分布(平均场思想)。
一般假设 q ( Z ) = ∏ i = 1 M q i ( Z i ) q(Z)=\prod\limits_{i=1}^{M}q_i(Z_i) q(Z)=i=1Mqi(Zi),也就是把 Z Z Z分成相互独立的 M M M个部分,而每个部分的分布 q i ( Z i ) q_i(Z_i) qi(Zi)都相对简单。此时利用变分法求 q i ( Z i ) q_i(Z_i) qi(Zi)
log ⁡ p ( Y ) = log ⁡ p ( Y , Z ) − log ⁡ p ( Z ∣ Y ) \log p(Y)=\log p(Y,Z)-\log p(Z|Y) logp(Y)=logp(Y,Z)logp(ZY)
log ⁡ p ( Y ) = log ⁡ p ( Y , Z ) q ( Z ) − log ⁡ p ( Z ∣ Y ) q ( Z ) \log p(Y)=\log \frac{p(Y,Z)}{q(Z)}-\log \frac{p(Z|Y)}{q(Z)} logp(Y)=logq(Z)p(Y,Z)logq(Z)p(ZY)
两边积分
∫ q ( Z ) log ⁡ p ( Y ) d Z = ∫ q ( Z ) log ⁡ p ( Y , Z ) q ( Z ) d Z − ∫ q ( Z ) log ⁡ p ( Z ∣ Y ) q ( Z ) d Z \int q(Z)\log p(Y)dZ=\int q(Z)\log \frac{p(Y,Z)}{q(Z)}dZ-\int q(Z)\log \frac{p(Z|Y)}{q(Z)}dZ q(Z)logp(Y)dZ=q(Z)logq(Z)p(Y,Z)dZq(Z)logq(Z)p(ZY)dZ
log ⁡ p ( Y ) = ∫ q ( Z ) log ⁡ p ( Y , Z ) q ( Z ) d Z − ∫ q ( Z ) log ⁡ p ( Z ∣ Y ) q ( Z ) d Z \log p(Y)=\int q(Z)\log \frac{p(Y,Z)}{q(Z)}dZ-\int q(Z)\log \frac{p(Z|Y)}{q(Z)}dZ logp(Y)=q(Z)logq(Z)p(Y,Z)dZq(Z)logq(Z)p(ZY)dZ
log ⁡ p ( Y ) = L ( q ) + K L ( q ∣ ∣ p ) \log p(Y)=L(q)+KL(q||p) logp(Y)=L(q)+KL(qp)
Evidence lower bound (ELBO):
L ( q ) = ∫ q ( Z ) log ⁡ p ( Y , Z ) q ( Z ) d Z L(q)=\int q(Z)\log \frac{p(Y,Z)}{q(Z)}dZ L(q)=q(Z)logq(Z)p(Y,Z)dZ
KL divergence:
K L ( q ∣ ∣ p ) = ∫ q ( Z ) log ⁡ q ( Z ) p ( Z ∣ Y ) d Z = − ∫ q ( Z ) log ⁡ p ( Z ∣ Y ) q ( Z ) d Z KL(q||p)=\int q(Z) \log \frac{q(Z)}{p(Z|Y)}dZ=-\int q(Z) \log \frac{p(Z|Y)}{q(Z)}dZ KL(qp)=q(Z)logp(ZY)q(Z)dZ=q(Z)logq(Z)p(ZY)dZ
因此最大化对数似然函数或者最大化ELBO: L ( q ) L(q) L(q)或者最小化KL divergence: K L ( q ∣ ∣ p ) KL(q||p) KL(qp)
在最大化 L ( q ) L(q) L(q)过程中,优化变量是函数 q ( Z ) q(Z) q(Z),因此采用变分法,所以叫变分推断。变分推断依然基于EM框架,属于一种特殊的高阶的EM算法。

你可能感兴趣的:(机器学习,统计)