$$
\begin{cases}
\text{Frequentist view: an optimization problem}
\begin{cases}
\text{Regression}
\begin{cases}
\text{Model}\\
\text{Strategy}\\
\text{Algorithm}
\begin{cases}
\text{Closed-form solution}\\
\text{Numerical solution}
\end{cases}
\end{cases}\\
\text{SVM}\\
\text{EM}\\
\text{etc.}
\end{cases}\\
\text{Bayesian view: an integration problem}
\begin{cases}
\text{Bayesian inference (compute the posterior)}\\
P(\theta \mid x)=\frac{P(x \mid \theta)P(\theta)}{P(x)}\\
\text{Bayesian decision (prediction; in the end this still requires the posterior)}\\
P(\tilde{x} \mid x)=\int_{\theta}P(\tilde{x},\theta \mid x)\, d\theta=\int_{\theta}P(\tilde{x} \mid \theta)P(\theta \mid x)\,d\theta=E_{\theta \mid x}[P(\tilde{x} \mid \theta)]
\end{cases}
\end{cases}
$$
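To make the Bayesian branch concrete, here is a minimal sketch under an assumed Beta-Bernoulli model (the model, the data, and the function names are illustrative assumptions, not part of the original notes). The posterior is available in closed form, and the predictive probability $P(\tilde{x}=1 \mid x)$ is the posterior expectation $E_{\theta \mid x}[\theta]$, which can also be approximated by sampling.

```python
import numpy as np

def beta_bernoulli_posterior(a, b, data):
    """Posterior Beta(a', b') after observing 0/1 data under a Beta(a, b) prior."""
    data = np.asarray(data)
    k = int(data.sum())          # number of ones
    n = data.size                # number of observations
    return a + k, b + n - k

def predictive_prob_one(a_post, b_post):
    """P(x_new = 1 | data) = E_{theta | data}[theta], the posterior mean."""
    return a_post / (a_post + b_post)

# Hypothetical observations: 6 ones out of 8 trials, uniform Beta(1, 1) prior.
data = [1, 0, 1, 1, 0, 1, 1, 1]
a_post, b_post = beta_bernoulli_posterior(1.0, 1.0, data)
exact = predictive_prob_one(a_post, b_post)

# The same expectation approximated by averaging posterior samples.
rng = np.random.default_rng(0)
theta_samples = rng.beta(a_post, b_post, size=200_000)
mc = float(theta_samples.mean())

print(f"exact predictive = {exact:.4f}, Monte Carlo estimate = {mc:.4f}")
```

This is the easy case where the integral is tractable; the rest of the notes deal with posteriors where it is not.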
$$
\text{Inference}
\begin{cases}
\text{Exact inference (simple posterior)}\\
\text{Approximate inference / approximate posterior expectation (very complex parameter space or latent variables)}
\begin{cases}
\text{Deterministic approximation} \to \text{VI}\\
\text{Stochastic approximation} \to \text{MCMC, MH, Gibbs}
\end{cases}
\end{cases}
$$
$x$: observed data
$z$: latent variables + parameters
$(x,z)$: complete data
ELBO + KL
$$
\log P(x)= L(q)+KL(q \| p)
$$

$$
\hat q(z)=\arg \max_{q(z)} L(q) \to \hat q(z) \approx p(z \mid x)
$$
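The decomposition above can be checked numerically. The sketch below assumes a toy discrete latent variable with three states and a made-up joint table `joint` (both are illustrative assumptions); it verifies that $L(q)+KL(q \| p)$ equals $\log P(x)$ for an arbitrary $q$, and that the ELBO reaches $\log P(x)$ when $q$ is the exact posterior.

```python
import numpy as np

# Hypothetical joint P(x, z) for one fixed observation x and z in {0, 1, 2}.
joint = np.array([0.10, 0.25, 0.05])
evidence = joint.sum()            # P(x) = sum_z P(x, z)
posterior = joint / evidence      # P(z | x)

def elbo(q):
    """L(q) = sum_z q(z) [log P(x, z) - log q(z)]."""
    return float(np.sum(q * (np.log(joint) - np.log(q))))

def kl(q, p):
    """KL(q || p) = sum_z q(z) [log q(z) - log p(z)]."""
    return float(np.sum(q * (np.log(q) - np.log(p))))

q = np.array([0.5, 0.3, 0.2])     # an arbitrary variational distribution
log_evidence = float(np.log(evidence))
decomposed = elbo(q) + kl(q, posterior)

print(log_evidence, decomposed)           # the two values agree
print(elbo(posterior), kl(posterior, posterior))
```

Because $KL \ge 0$, the ELBO never exceeds $\log P(x)$, which is why maximizing $L(q)$ pushes $q$ toward the posterior.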
Mean-field theory, borrowed from physics
$q(z)=\prod_{i=1}^M q_i(z_i)$; in the computation we single out one factor $q_j(z_j)$ and hold the others fixed.
$$
L(q)=\int_z q(z) \log P(x,z)\,dz-\int_z q(z)\log q(z)\,dz
$$
$$
\begin{aligned}
\int_z q(z) \log P(x,z)\,dz &=\int_z \prod_{i=1}^M q_i(z_i) \log P(x,z)\,dz\\
&=\int_{z_j} q_j(z_j) \left ( \int_{z_{i \ne j}} \prod_{i \ne j}^M q_i(z_i) \log P(x,z)\, dz_{i \ne j} \right ) dz_j\\
&=\int_{z_j} q_j(z_j)\, E_{\prod_{i \ne j}^{M} q_i(z_i)}[\log P(x,z)]\, dz_j\\
\int_z q(z) \log q(z)\, dz &=\int_z \prod_{i=1}^{M}q_i(z_i) \log \prod_{i=1}^{M}q_i(z_i)\,dz\\
&=\int_z \prod_{i=1}^{M}q_i(z_i) \sum_{i=1}^{M} \log q_i(z_i)\,dz\\
&=\int_z \prod_{i=1}^{M}q_i(z_i) [ \log q_1(z_1)+\log q_2(z_2)+\cdots+\log q_M(z_M) ]\, dz\\
&=\sum_{i=1}^{M} \int_{z_i}q_i(z_i) \log q_i(z_i)\,dz_i\\
&=\int_{z_j}q_j(z_j) \log q_j(z_j)\,dz_j + C\\
L(q)&=\int_{z_j} q_j(z_j)\log \frac{\hat p(x,z_j)}{q_j(z_j)}\,dz_j - C\\
&=-KL(q_j \| \hat p(x,z_j)) - C, \qquad -KL(q_j \| \hat p(x,z_j)) \le 0
\end{aligned}
$$

Here $\log \hat p(x,z_j)=E_{\prod_{i \ne j}^{M} q_i(z_i)}[\log P(x,z)]$, and $C$ collects the terms that do not depend on $q_j$. With the other factors fixed, $L(q)$ is therefore maximized when $q_j(z_j)$ equals $\hat p(x,z_j)$ up to normalization.
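EM solves maximum likelihood estimation with latent variables. The lower bound on the log-likelihood is a function of the posterior: the inequality $\log P(x) \ge L(q)$ becomes an equality when $q=p$. In general $p$ is too complex to obtain, so we minimize the KL divergence to get the optimal $\hat q$. Following mean-field theory, we approximate the posterior $p$ with a product of mutually independent factors $\prod_{i=1}^{M}q_i(z_i)$.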
VI (mean field) $\to$ Classical VI
Assumption: $q(z)=\prod_{i=1}^{M}q_i(z_i)$
$$
\begin{aligned}
\log q_j(z_j) &=E_{\prod_{i \ne j} q_i(z_i)}[\log P_{\theta}(x^{(i)},z)]+C\\
&=\int_{z_1}\cdots\int_{z_{j-1}} \int_{z_{j+1}} \cdots \int_{z_M}q_1 \cdots q_{j-1}q_{j+1} \cdots q_M \log P_{\theta}(x^{(i)},z)\,dz_1\cdots dz_{j-1}\, dz_{j+1} \cdots dz_M + C
\end{aligned}
$$
Objective:
$$
\begin{aligned}
&\hat q = \arg \min_q KL(q \| p)=\arg \max_q L(q)\\
&\log \hat q_1(z_1) =\int_{z_2} \cdots \int_{z_M}q_2 \cdots q_M \log P_{\theta}(x^{(i)},z)\,dz_2 \cdots dz_M + C\\
&\log \hat q_2(z_2) =\int_{z_1}\int_{z_3} \cdots \int_{z_M}\hat q_1 q_3 \cdots q_M \log P_{\theta}(x^{(i)},z)\,dz_1\, dz_3 \cdots dz_M + C\\
&\qquad \vdots\\
&\log \hat q_M(z_M)=\int_{z_1} \int_{z_2} \cdots \int_{z_{M-1}}\hat q_1 \hat q_2 \cdots \hat q_{M-1} \log P_{\theta}(x^{(i)},z)\,dz_1\, dz_2 \cdots dz_{M-1} + C
\end{aligned}
$$
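This is analogous to coordinate ascent: each factor is updated in turn with the others held fixed, and the iteration stops at convergence.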
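The coordinate-wise updates can be sketched on a toy problem. The example below assumes the target is a known two-dimensional Gaussian $N(\mu, \Lambda^{-1})$ standing in for the posterior (an illustrative assumption, as are the numbers and the name `cavi_gaussian`). For this target the mean-field optimum has the well-known closed form $q_j(z_j)=N(m_j, \Lambda_{jj}^{-1})$, with each mean $m_j$ depending on the current mean of the other factor, so the loop only has to alternate two scalar updates.

```python
import numpy as np

# Hypothetical target: a correlated 2-D Gaussian playing the role of p(z | x).
mu = np.array([1.0, -2.0])
Sigma = np.array([[1.0, 0.8],
                  [0.8, 2.0]])
Lam = np.linalg.inv(Sigma)        # precision matrix

def cavi_gaussian(mu, Lam, n_iters=100, tol=1e-10):
    """Coordinate ascent for q(z) = q1(z1) q2(z2) on a 2-D Gaussian target."""
    m = np.zeros(2)               # initial means of q1 and q2
    for it in range(n_iters):
        m_old = m.copy()
        # Update q1 with q2 fixed, then q2 with the new q1 fixed.
        m[0] = mu[0] - Lam[0, 1] * (m[1] - mu[1]) / Lam[0, 0]
        m[1] = mu[1] - Lam[1, 0] * (m[0] - mu[0]) / Lam[1, 1]
        if np.max(np.abs(m - m_old)) < tol:
            break                 # converged
    variances = 1.0 / np.diag(Lam)  # factor variances are fixed by Lam
    return m, variances, it + 1

m, variances, n_used = cavi_gaussian(mu, Lam)
print("means:", m, "variances:", variances, "iterations:", n_used)
print("true marginal variances:", np.diag(Sigma))
```

In this toy case the factor means converge to the true mean, while the factor variances come out smaller than the true marginal variances, a commonly cited side effect of the independence assumption.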
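Problems with classical VI:
Stochastic Gradient Variational Inference (SGVI)
Instead of solving for $q(z)$ itself, assume $q(z)$ belongs to some family of distributions and solve for the parameters $\phi$ of that distribution.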
ELBO
$$
\begin{aligned}
L(\phi)&=E_{q_{\phi}(z)} \left [ \log \frac{P_{\theta}(x^{(i)},z)}{q_{\phi}(z)}\right ]\\
\hat \phi &= \arg \max_{\phi} L(\phi)\\
\nabla_{\phi}L(\phi) &=\nabla_{\phi}E_{q_{\phi}(z)} \left [ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right ]\\
&=\nabla_{\phi}\int q_{\phi}(z) \left [ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right ] dz\\
&=\int \nabla_{\phi} q_{\phi}(z) \left [ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right ] dz \\
& \quad +\int q_{\phi}(z) \nabla_{\phi} \left [ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right ] dz\\
&=\int q_{\phi}(z) \nabla_{\phi} \log q_{\phi}(z) \left [ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right ] dz-\underbrace{\int \nabla_{\phi} q_{\phi}(z)\,dz}_{=\nabla_{\phi} 1 = 0}\\
&=E_{q_{\phi}(z)} \left [ \nabla_{\phi} \log q_{\phi}(z) \left ( \log P_{\theta}(x^{(i)},z) - \log q_{\phi}(z) \right ) \right ]
\end{aligned}
$$
We can therefore use Monte Carlo: draw samples from $q_{\phi}(z)$ and, by the law of large numbers, approximate the expectation $E$ with a sample mean.
$$
z^{(l)} \sim q_{\phi}(z),\quad l=1,2,\cdots,L
$$
$$
\nabla_{\phi}L(\phi) \approx \frac{1}{L} \sum_{l=1}^{L} \nabla_{\phi} \log q_{\phi}(z^{(l)}) \left ( \log P_{\theta}(x^{(i)},z^{(l)})-\log q_{\phi}(z^{(l)}) \right )
$$
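As a sanity check of this estimator, the sketch below assumes a one-dimensional toy model: prior $z \sim N(0,1)$, likelihood $x \mid z \sim N(z,1)$ with a single observation $x=2$, and $q_{\phi}(z)=N(\phi,1)$. All of these choices, and the function names, are illustrative assumptions. For this model the ELBO gradient is $x-2\phi$ in closed form, so the Monte Carlo estimate can be compared against it.

```python
import numpy as np

X_OBS = 2.0  # hypothetical single observation

def log_joint(z, x=X_OBS):
    """log P(x, z) = log N(z; 0, 1) + log N(x; z, 1)."""
    return -0.5 * z**2 - 0.5 * (x - z)**2 - np.log(2.0 * np.pi)

def log_q(z, phi):
    """log q_phi(z) with q_phi = N(phi, 1)."""
    return -0.5 * (z - phi)**2 - 0.5 * np.log(2.0 * np.pi)

def score_function_grad(phi, num_samples, rng):
    """Monte Carlo estimate of E_q[grad_phi log q * (log P - log q)]."""
    z = rng.normal(loc=phi, scale=1.0, size=num_samples)  # z^(l) ~ q_phi
    score = z - phi                                       # grad_phi log q_phi(z)
    weight = log_joint(z) - log_q(z, phi)
    return float(np.mean(score * weight))

rng = np.random.default_rng(0)
phi = 0.5
grad_mc = score_function_grad(phi, 200_000, rng)
grad_exact = X_OBS - 2.0 * phi   # closed-form ELBO gradient for this toy model

print(f"Monte Carlo gradient = {grad_mc:.3f}, exact gradient = {grad_exact:.3f}")
```

Note how many samples are used even for this one-dimensional problem; with only a handful of samples the estimate fluctuates noticeably around the exact value.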
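Problems:
The first issue is the term $\nabla_{\phi} \log q_{\phi}(z)$: when a sample falls where $q_{\phi}(z)$ is close to 0, the logarithm changes very rapidly (it is highly sensitive, so the variance is large), and many more samples are needed for a good approximation.
The second issue is that a sample estimate of the expectation is used to approximate the gradient, whereas what we are ultimately after is $\hat \phi$, so the error propagates and can become very large.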
Reparameterization Trick
$$
\nabla_{\phi}L(\phi) =\nabla_{\phi}E_{q_{\phi}(z)} \left [ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right ]
$$
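The expectation is taken with respect to $q_{\phi}(z)$, which depends on $\phi$, and the integrand depends on $\phi$ as well, which makes the gradient expensive to compute. To simplify the problem, we want the sampling distribution not to depend on $\phi$: replacing $q_{\phi}(z)$ with a fixed distribution $p(\varepsilon)$ lets us differentiate the integrand directly instead of differentiating through the expectation. With $z \sim q_{\phi}(z \mid x)$, the reparameterization trick decouples the randomness of $z$ from $\phi$.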
Assume $z=g_{\phi}(\varepsilon, x^{(i)})$ with $\varepsilon \sim p(\varepsilon)$. Since $z$ and $\varepsilon$ are related by a mapping and each density integrates to 1, the following holds:
$$
\left | q_{\phi}(z \mid x^{(i)})\,dz \right | = \left | p(\varepsilon)\,d\varepsilon \right |
$$
$$
\begin{aligned}
\nabla_{\phi}L(\phi) &=\nabla_{\phi}\int \left [ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right ] q_{\phi}(z)\, dz\\
&=\nabla_{\phi}\int \left [ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right ] p(\varepsilon)\, d\varepsilon\\
&=\nabla_{\phi} E_{p(\varepsilon)} \left [ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right ] \\
&=E_{p(\varepsilon)} \left [ \nabla_{\phi} \left ( \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right ) \right ]\\
&=E_{p(\varepsilon)} \left [ \nabla_{z} \left ( \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z \mid x^{(i)}) \right ) \cdot \nabla_{\phi}z \right ]\\
&=E_{p(\varepsilon)} \left [ \nabla_{z} \left ( \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z \mid x^{(i)}) \right ) \cdot \nabla_{\phi} g_{\phi}(\varepsilon, x^{(i)}) \right ]
\end{aligned}
$$
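The final expression can be tried out on a small example. The sketch below assumes a toy model with prior $z \sim N(0,1)$, likelihood $x \mid z \sim N(z,1)$, one observation $x=2$, and $q_{\phi}(z)=N(\phi,1)$ reparameterized as $z=g_{\phi}(\varepsilon)=\phi+\varepsilon$ with $\varepsilon \sim N(0,1)$; the model and all function names are illustrative assumptions. It evaluates the reparameterized gradient exactly as written in the last line of the derivation and, for comparison, the score-function form $\nabla_{\phi}\log q_{\phi}(z)\,(\log P_{\theta}-\log q_{\phi})$ on the same noise draws.

```python
import numpy as np

X_OBS = 2.0  # hypothetical single observation

def grad_z_log_joint(z, x=X_OBS):
    """d/dz [log N(z; 0, 1) + log N(x; z, 1)] = -z + (x - z)."""
    return x - 2.0 * z

def grad_z_log_q(z, phi):
    """d/dz log N(z; phi, 1)."""
    return -(z - phi)

def reparam_grad_samples(phi, eps):
    """Per-sample values of grad_z(log P - log q) * grad_phi g_phi(eps)."""
    z = phi + eps                 # z = g_phi(eps)
    dz_dphi = 1.0                 # grad_phi g_phi(eps) for this g
    return (grad_z_log_joint(z) - grad_z_log_q(z, phi)) * dz_dphi

def score_grad_samples(phi, eps):
    """Per-sample values of grad_phi log q * (log P - log q)."""
    z = phi + eps
    log_joint = -0.5 * z**2 - 0.5 * (X_OBS - z)**2 - np.log(2.0 * np.pi)
    log_q = -0.5 * (z - phi)**2 - 0.5 * np.log(2.0 * np.pi)
    return (z - phi) * (log_joint - log_q)

rng = np.random.default_rng(0)
phi = 0.5
eps = rng.standard_normal(100_000)     # eps ~ p(eps), independent of phi

g_rep = reparam_grad_samples(phi, eps)
g_sf = score_grad_samples(phi, eps)
grad_exact = X_OBS - 2.0 * phi         # closed-form ELBO gradient for this model

print(f"exact gradient         = {grad_exact:.3f}")
print(f"reparameterized        = {g_rep.mean():.3f} (per-sample variance {g_rep.var():.2f})")
print(f"score-function         = {g_sf.mean():.3f} (per-sample variance {g_sf.var():.2f})")
```

Both estimators target the same gradient in this toy setting, but the reparameterized one has a much smaller per-sample variance, which is the practical motivation for the trick.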