Variational Inference

11. Variational Inference

1. Background

$$
\begin{cases}
\text{Frequentist view: optimization problems}
\begin{cases}
\text{Regression}
\begin{cases}
\text{Model}\\
\text{Strategy}\\
\text{Algorithm}
\begin{cases}
\text{analytic solution}\\
\text{numerical solution}
\end{cases}
\end{cases}\\
\text{SVM}\\
\text{EM}\\
\text{etc.}
\end{cases}\\
\text{Bayesian view: integration problems}
\begin{cases}
\text{Bayesian inference (compute the posterior)}\\
P(\theta \mid x)=\dfrac{P(x \mid \theta)P(\theta)}{P(x)}\\
\text{Bayesian decision (prediction; ultimately still requires the posterior)}\\
P(\tilde{x} \mid x)=\int_{\theta}P(\tilde{x},\theta \mid x)\,d\theta=\int_{\theta}P(\tilde{x} \mid \theta)P(\theta \mid x)\,d\theta=E_{\theta \mid x}[P(\tilde{x} \mid \theta)]
\end{cases}
\end{cases}
$$

$$
\text{Inference}
\begin{cases}
\text{exact inference (posterior is simple)}\\
\text{approximate inference (parameter space / latent variables very complex)}
\begin{cases}
\text{deterministic approximation} \to \text{VI}\\
\text{stochastic approximation} \to \text{MCMC, MH, Gibbs}
\end{cases}
\end{cases}
$$

2. Derivation

$x$: observed data
$z$: latent variables + parameters
$(x,z)$: complete data

ELBO + KL decomposition:

$$\log P(x)= L(q)+KL(q\,\|\,p)$$

$$\hat q(z)=\arg \max_{q(z)} L(q) \;\Rightarrow\; \hat q(z) \approx p(z \mid x)$$
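As a quick numeric sanity check of this decomposition, here is a toy discrete model (invented for illustration, not from the notes): $z\in\{0,1\}$, a fixed observation $x$, and an arbitrary $q(z)$. The identity $\log P(x)=L(q)+KL(q\,\|\,p)$ holds exactly for any $q$:

```python
import numpy as np

# Toy joint: p(x, z=0) and p(x, z=1) for one fixed observed x.
p_xz = np.array([0.3, 0.1])
p_x = p_xz.sum()                      # evidence p(x)
p_z_given_x = p_xz / p_x              # true posterior p(z | x)

q = np.array([0.6, 0.4])              # an arbitrary variational distribution q(z)

elbo = np.sum(q * (np.log(p_xz) - np.log(q)))          # L(q) = E_q[log p(x,z) - log q(z)]
kl = np.sum(q * (np.log(q) - np.log(p_z_given_x)))     # KL(q || p(z|x)) >= 0

assert np.isclose(np.log(p_x), elbo + kl)              # log p(x) = L(q) + KL
```

Since $\log P(x)$ does not depend on $q$, maximizing the ELBO is exactly minimizing the KL term.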

Mean-field theory (borrowed from statistical physics):

$q(z)=\prod_{i=1}^M q_i(z_i)$; in each update we isolate one factor $q_j(z_j)$ and hold the others fixed.

$$L(q)=\int_z q(z) \log P(x,z)\,dz-\int_z q(z)\log q(z)\,dz$$

$$
\begin{aligned}
\int_z q(z) \log P(x,z)\,dz &=\int_z \prod_{i=1}^M q_i(z_i) \log P(x,z)\,dz\\
&=\int_{z_j} q_j(z_j)\left( \int \prod_{i \ne j} q_i(z_i) \log P(x,z) \prod_{i \ne j} dz_i \right) dz_j\\
&=\int_{z_j} q_j(z_j)\, E_{\prod_{i \ne j} q_i(z_i)}[\log P(x,z)]\, dz_j
\end{aligned}
$$

$$
\begin{aligned}
\int_z q(z) \log q(z)\,dz &=\int_z \prod_{i=1}^{M}q_i(z_i) \log \prod_{i=1}^{M}q_i(z_i)\,dz\\
&=\int_z \prod_{i=1}^{M}q_i(z_i) \sum_{i=1}^{M} \log q_i(z_i)\,dz\\
&=\int_z \prod_{i=1}^{M}q_i(z_i) \left[ \log q_1(z_1)+\log q_2(z_2)+\cdots+\log q_M(z_M) \right] dz\\
&=\sum_{i=1}^{M} \int_{z_i}q_i(z_i) \log q_i(z_i)\,dz_i\\
&=\int_{z_j}q_j(z_j) \log q_j(z_j)\,dz_j + C
\end{aligned}
$$

Defining $\log \hat p(x,z_j)=E_{\prod_{i \ne j} q_i(z_i)}[\log P(x,z)]$ and fixing all factors except $q_j$:

$$
\begin{aligned}
L(q)&=\int_{z_j} q_j(z_j)\log \frac{\hat p(x,z_j)}{q_j(z_j)}\,dz_j + C\\
&=-KL(q_j\,\|\,\hat p(x,z_j)) + C \;\le\; C
\end{aligned}
$$

so, with the other coordinates held fixed, $L(q)$ is maximized by $q_j(z_j)=\hat p(x,z_j)$.

This parallels the EM algorithm for maximum likelihood with latent variables: the log-likelihood decomposes into the ELBO plus a KL term involving the posterior, and the bound becomes tight when $q=p(z \mid x)$. In general the true posterior $p$ is too complex to compute directly, so we minimize the KL to obtain the best approximation $\hat q$; under the mean-field assumption we use the mutually independent factors $\prod_{i=1}^{M}q_i(z_i)$ to approximate the posterior $p$.

3. Looking Back

VI (mean field) $\to$ Classical VI

Assumption: $q(z)=\prod_{i=1}^{M}q_i(z_i)$

$$
\begin{aligned}
\log q_j(z_j) &=E_{\prod_{i \ne j} q_i(z_i)}[\log P_{\theta}(x^{(i)},z)]+C\\
&=\int_{z_1}\cdots\int_{z_{j-1}} \int_{z_{j+1}} \cdots \int_{z_M} q_1 \cdots q_{j-1}\,q_{j+1} \cdots q_M \,\log P_{\theta}(x^{(i)},z)\;dz_1\cdots dz_{j-1}\, dz_{j+1} \cdots dz_M
\end{aligned}
$$

Objective:

$$\hat q = \arg \min_q KL(q\,\|\,p)=\arg \max_q L(q)$$

Cycle through the coordinates, each time plugging in the most recent updates of the other factors (up to additive constants):

$$
\begin{aligned}
\log \hat q_1(z_1) &=\int_{z_2} \cdots \int_{z_M} q_2 \cdots q_M \,\log P_{\theta}(x^{(i)},z)\;dz_2 \cdots dz_M\\
\log \hat q_2(z_2) &=\int_{z_1} \int_{z_3}\cdots \int_{z_M} \hat q_1\, q_3 \cdots q_M \,\log P_{\theta}(x^{(i)},z)\;dz_1\, dz_3 \cdots dz_M\\
&\;\;\vdots\\
\log \hat q_M(z_M)&=\int_{z_1} \int_{z_2} \cdots \int_{z_{M-1}} \hat q_1 \hat q_2 \cdots \hat q_{M-1} \,\log P_{\theta}(x^{(i)},z)\;dz_1\, dz_2 \cdots dz_{M-1}
\end{aligned}
$$

This is analogous to coordinate ascent (a coordinate-wise counterpart of gradient ascent): update one factor at a time and stop at convergence.
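The coordinate-ascent scheme can be made concrete on a standard toy example (chosen for illustration, not from the notes): a mean-field Gaussian approximation to a correlated bivariate Gaussian "posterior" $\mathcal N(\mu, \Lambda^{-1})$, where each coordinate update has a known closed form that only moves the factor means (cf. Bishop, *PRML* §10.1.2):

```python
import numpy as np

# Hypothetical target posterior: correlated bivariate Gaussian N(mu, inv(Lam)).
mu = np.array([1.0, -1.0])
Lam = np.array([[2.0, 0.9],           # precision matrix Lambda
                [0.9, 2.0]])

# Mean-field q(z) = q1(z1) q2(z2), each factor Gaussian. The optimal
# coordinate updates (closed form for this model) only move the means:
#   m1 <- mu1 - (Lam12/Lam11) (m2 - mu2),  m2 <- mu2 - (Lam21/Lam22) (m1 - mu1)
m = np.array([5.0, 5.0])              # arbitrary initialization
for _ in range(50):                   # coordinate ascent until convergence
    m[0] = mu[0] - Lam[0, 1] / Lam[0, 0] * (m[1] - mu[1])
    m[1] = mu[1] - Lam[1, 0] / Lam[1, 1] * (m[0] - mu[0])

print(m)                              # converges to the true mean mu
```

Each sweep shrinks the error by the factor $(\Lambda_{12}/\Lambda_{11})(\Lambda_{21}/\Lambda_{22}) \approx 0.20$, so convergence is fast here; the factorized $q$ recovers the exact mean but, as is well known for mean field, underestimates the posterior spread.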

Problems with Classical VI:

  • The factorization assumption is too strong.
  • It can remain intractable (each update still requires an integral/expectation).

4. SGVI

Stochastic Gradient Variational Inference

Instead of solving for $q(z)$ itself, assume $q(z)$ belongs to some parametric family and optimize that family's parameters $\phi$.

ELBO:

$$
\begin{aligned}
L(\phi)&=E_{q_{\phi}(z)} \left[ \log \frac{P_{\theta}(x^{(i)},z)}{q_{\phi}(z)}\right]\\
\hat \phi &= \arg \max_{\phi} L(\phi)\\
\nabla_{\phi}L(\phi) &=\nabla_{\phi}E_{q_{\phi}(z)} \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right]\\
&=\nabla_{\phi}\int q_{\phi}(z) \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right] dz\\
&=\int \nabla_{\phi} q_{\phi}(z) \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right] dz +\int q_{\phi}(z) \nabla_{\phi} \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right] dz\\
&=\int q_{\phi}(z) \nabla_{\phi} \log q_{\phi}(z) \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right] dz-\int \nabla_{\phi} q_{\phi}(z)\,dz\\
&=E_{q_{\phi}(z)} \left[ \nabla_{\phi} \log q_{\phi}(z) \left( \log P_{\theta}(x^{(i)},z) - \log q_{\phi}(z) \right) \right]
\end{aligned}
$$

Here we used $\nabla_{\phi} q_{\phi}(z)=q_{\phi}(z)\nabla_{\phi}\log q_{\phi}(z)$, and the last term vanishes because $\int \nabla_{\phi} q_{\phi}(z)\,dz=\nabla_{\phi}\int q_{\phi}(z)\,dz=\nabla_{\phi}1=0$.
We can therefore use Monte Carlo: sample from $q_{\phi}(z)$ and, by the law of large numbers, approximate the expectation $E$ with the sample mean.

$$z^{(l)} \sim q_{\phi}(z),\quad l=1,2,\cdots,L$$

$$\nabla_{\phi}L(\phi) \approx \frac{1}{L} \sum_{l=1}^{L} \nabla_{\phi} \log q_{\phi}(z^{(l)})\left(\log P_{\theta}(x^{(i)},z^{(l)})-\log q_{\phi}(z^{(l)})\right)$$
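A minimal sketch of this score-function (REINFORCE-style) estimator, on a made-up one-dimensional problem: $q_{\phi}=\mathcal N(\phi,1)$, so $\nabla_{\phi}\log q_{\phi}(z)=z-\phi$, and a simple stand-in integrand $f(z)=z^2$ (standing in for $\log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z)$) whose exact gradient $\nabla_{\phi}E_{q_{\phi}}[z^2]=2\phi$ is known in closed form:

```python
import numpy as np

rng = np.random.default_rng(0)
phi = 1.5
L = 200_000                            # needs MANY samples: high-variance estimator

z = rng.normal(phi, 1.0, size=L)       # z^(l) ~ q_phi(z) = N(phi, 1)
score = z - phi                        # grad_phi log q_phi(z) for unit-variance Gaussian
f = z ** 2                             # stand-in for log p_theta(x,z) - log q_phi(z)

grad_est = np.mean(score * f)          # (1/L) sum_l score(z^(l)) * f(z^(l))
print(grad_est)                        # should approach the true gradient 2*phi = 3.0
```

Even on this trivial problem the per-sample values `score * f` are very spread out, which is exactly the variance problem discussed next.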

The problem:
  The term $\nabla_{\phi} \log q_{\phi}(z)$ is troublesome: when a sample lands where $q_{\phi}(z)$ is close to 0, the log changes very rapidly (the estimator is very sensitive, i.e. has high variance), so many samples are needed for a good approximation.
Since we approximate the gradient of the objective in $\phi$ by a sample mean, while our real target is the optimum $\hat \phi$, this noise can make the error very large.

Reparameterization Trick

$$\nabla_{\phi}L(\phi) =\nabla_{\phi}E_{q_{\phi}(z)} \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right]$$

  The expectation is taken with respect to $q_{\phi}(z)$, which depends on $\phi$, and the integrand also depends on $\phi$; this coupling makes the gradient hard to estimate. To simplify, we push the randomness into a fixed distribution $p(\varepsilon)$ that does not depend on $\phi$, so we can differentiate the integrand directly instead of differentiating through the expectation. For $z \sim q_{\phi}(z \mid x)$, the reparameterization trick decouples the randomness of $z$ from $\phi$.

Assume $z=g_{\phi}(\varepsilon, x^{(i)})$ with $\varepsilon \sim p(\varepsilon)$, a deterministic map from $\varepsilon$ to $z$. Since both densities integrate to 1, the change of variables gives:

$$\left| q_{\phi}(z \mid x^{(i)})\,dz \right| = \left| p(\varepsilon)\,d\varepsilon \right|$$

$$
\begin{aligned}
\nabla_{\phi}L(\phi) &=\nabla_{\phi}\int \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right] q_{\phi}(z)\,dz\\
&=\nabla_{\phi}\int \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right] p(\varepsilon)\,d\varepsilon\\
&=\nabla_{\phi} E_{p(\varepsilon)} \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right] \\
&=E_{p(\varepsilon)} \left[ \nabla_{\phi} \left( \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right) \right]\\
&=E_{p(\varepsilon)} \left[ \nabla_{z} \left( \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z \mid x^{(i)}) \right) \cdot \nabla_{\phi}z \right]\\
&=E_{p(\varepsilon)} \left[ \nabla_{z} \left( \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z \mid x^{(i)}) \right) \cdot \nabla_{\phi} g_{\phi}(\varepsilon, x^{(i)}) \right]
\end{aligned}
$$
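A matching sketch of the reparameterized (pathwise) estimator on the same toy problem as above ($q_{\phi}=\mathcal N(\phi,1)$, reparameterized as $z=g_{\phi}(\varepsilon)=\phi+\varepsilon$ with $\varepsilon \sim \mathcal N(0,1)$, stand-in integrand $f(z)=z^2$ with true gradient $2\phi$); it reaches the same gradient with far fewer samples because its variance is much lower than the score-function estimator's:

```python
import numpy as np

rng = np.random.default_rng(0)
phi = 1.5
L = 1_000                              # far fewer samples than the score-function demo

eps = rng.normal(0.0, 1.0, size=L)     # eps ~ p(eps) = N(0, 1), independent of phi
z = phi + eps                          # z = g_phi(eps) reparameterizes z ~ N(phi, 1)

# Pathwise gradient: grad_phi f(g_phi(eps)) = f'(z) * dz/dphi = 2z * 1
grad_samples = 2.0 * z
print(grad_samples.mean())             # close to the true gradient 2*phi = 3.0

# Score-function estimator evaluated on the SAME samples, for comparison:
score_samples = (z - phi) * z ** 2
print(grad_samples.var(), score_samples.var())  # pathwise variance is much smaller
```

The per-sample pathwise gradient has variance $4$ here, versus roughly $50$ for the score-function version; this variance reduction is why the reparameterization trick is the standard choice whenever $q_{\phi}$ admits such a map $g_{\phi}$ (e.g. in VAEs).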
