我将看过的论文建了一个github库,方便各位阅读地址
传统的VAE,隐变量服从标准高斯分布(单峰),但有时候,单个高斯分布可能不能完全表达图像x的特征,比如MINIST数据集有0~9这10个数字,直觉上使用10个高斯分布来替代单个高斯分布更为合理,因此有学者将混合高斯分布模型(GMM)与VAE进行结合,其结果便是GMVAE。
FBI warning
本文为代码与论文结合进行理解的产物,如有错误,欢迎指出。本文不会进行ELBO的推导,将直接从论文给出的ELBO算式进行讲解。
损失函数由变分推断推导而来,由于论文遗漏了太多推导细节,本文将不会介绍这部分推导,将重点介绍损失函数的各个部分如何计算。
与VAE一样,GMVAE通过最大化ELBO来进行优化,ELBO的形式如下:
L E L B O = E q ( x ∣ y ) [ p θ ( y ∣ x ) ] − E q ( w ∣ y ) p ( z ∣ x , w ) [ K L ( q ϕ x ( x ∣ y ) ∣ ∣ p β ( x ∣ w , z ) ) ] − K L ( q ϕ x ( w ∣ y ) ∣ ∣ p ( w ) ) − E q ( x ∣ y ) q ( w ∣ y ) [ K L ( p β ( z ∣ x , w ) ∣ ∣ p ( z ) ) ] \begin{aligned} L_{ELBO}=&E_{q(x|y)}[p_\theta(y|x)]-E_{q(w|y)p(z|x,w)}[KL(q_{\phi_x}(x|y)||p_{\beta}(x|w,z))]\\ &-KL(q_{\phi_x}(w|y)||p(w))-E_{q_(x|y)q(w|y)}[KL(p_\beta(z|x,w)||p(z))] \end{aligned} LELBO=Eq(x∣y)[pθ(y∣x)]−Eq(w∣y)p(z∣x,w)[KL(qϕx(x∣y)∣∣pβ(x∣w,z))]−KL(qϕx(w∣y)∣∣p(w))−Eq(x∣y)q(w∣y)[KL(pβ(z∣x,w)∣∣p(z))]
ϕ x 、 θ 、 β \phi_x、\theta、\beta ϕx、θ、β表示待优化的参数,可以暂时忽视。
接下来我将介绍每一部分的计算方式
E q ( x ∣ y ) [ p θ ( y ∣ x ) ] E_{q(x|y)}[p_\theta(y|x)] Eq(x∣y)[pθ(y∣x)]表示重构误差,由于我们假定 p θ ( y ∣ x ) p_\theta(y|x) pθ(y∣x)服从高斯分布,所以与VAE一样,可以用均方误差进行计算。
对 E q ( w ∣ y ) p ( z ∣ x , w ) [ K L ( q ϕ x ( x ∣ y ) ∣ ∣ p β ( x ∣ w , z ) ) ] E_{q(w|y)p(z|x,w)}[KL(q_{\phi_x}(x|y)||p_{\beta}(x|w,z))] Eq(w∣y)p(z∣x,w)[KL(qϕx(x∣y)∣∣pβ(x∣w,z))]使用蒙特卡洛模拟,可得
1 M ∑ j = 1 M ∑ k = 1 K p β ( z k = 1 ∣ x ( j ) , w ( j ) ) K L ( q ϕ x ( x ∣ y ) ∣ ∣ p β ( x ∣ w ( j ) , z k = 1 ) ) (1.0) \frac{1}{M}\sum_{j=1}^M\sum_{k=1}^Kp_{\beta}(z_k=1|x^{(j)},w^{(j)})KL(q_{\phi_x}(x|y)||p_\beta(x|w^{(j)},z_k=1))\tag{1.0} M1j=1∑Mk=1∑Kpβ(zk=1∣x(j),w(j))KL(qϕx(x∣y)∣∣pβ(x∣w(j),zk=1))(1.0)
M M M采样的样本数,我们可以将其设置为1,则1.0可变化为
∑ k = 1 K p β ( z k = 1 ∣ x , w ) K L ( q ϕ x ( x ∣ y ) ∣ ∣ p β ( x ∣ w , z k = 1 ) ) = ∑ k = 1 K p β ( z k = 1 ∣ x , w ) E q ϕ x ( x ∣ y ) [ log q ϕ x ( x ∣ y ) p β ( x ∣ w , z k = 1 ) ] = ∑ k = 1 K p β ( z k = 1 ∣ x , w ) ∑ i = 1 N log q ϕ x ( x i ∣ y ) p β ( x i ∣ w , z k = 1 ) (2.0) \begin{aligned} &\sum_{k=1}^Kp_{\beta}(z_k=1|x,w)KL(q_{\phi_x}(x|y)||p_\beta(x|w,z_k=1))\\ =&\sum_{k=1}^Kp_{\beta}(z_k=1|x,w)E_{q_{\phi_x}(x|y)}[\log\frac{q_{\phi_x}(x|y)}{p_\beta(x|w,z_k=1)}]\\ =&\sum_{k=1}^Kp_{\beta}(z_k=1|x,w)\sum_{i=1}^N\log\frac{q_{\phi_x}(x_i|y)}{p_\beta(x_i|w,z_k=1)} \end{aligned}\tag{2.0} ==k=1∑Kpβ(zk=1∣x,w)KL(qϕx(x∣y)∣∣pβ(x∣w,zk=1))k=1∑Kpβ(zk=1∣x,w)Eqϕx(x∣y)[logpβ(x∣w,zk=1)qϕx(x∣y)]k=1∑Kpβ(zk=1∣x,w)i=1∑Nlogpβ(xi∣w,zk=1)qϕx(xi∣y)(2.0)
第三行式子利用蒙特卡洛模拟得到,同理,将N设置为1,式2.0可变为
∑ k = 1 K p β ( z k = 1 ∣ x , w ) log q ϕ x ( x ∣ y ) p β ( x ∣ w , z k = 1 ) = ∑ k = 1 K p β ( z k = 1 ∣ x , w ) log q ϕ x ( x ∣ y ) − ∑ k = 1 K p β ( z k = 1 ∣ x , w ) log p β ( x ∣ w , z k = 1 ) = log q ϕ x ( x ∣ y ) − ∑ k = 1 K p β ( z k = 1 ∣ x , w ) log p β ( x ∣ w , z k = 1 ) (3.0) \begin{aligned} &\sum_{k=1}^Kp_{\beta}(z_k=1|x,w)\log\frac{q_{\phi_x}(x|y)}{p_\beta(x|w,z_k=1)}\\ =&\sum_{k=1}^Kp_{\beta}(z_k=1|x,w)\log q_{\phi_x}(x|y)-\sum_{k=1}^Kp_{\beta}(z_k=1|x,w)\log p_\beta(x|w,z_k=1)\\ =&\log q_{\phi_x}(x|y)-\sum_{k=1}^Kp_{\beta}(z_k=1|x,w)\log p_\beta(x|w,z_k=1)\tag{3.0} \end{aligned} ==k=1∑Kpβ(zk=1∣x,w)logpβ(x∣w,zk=1)qϕx(x∣y)k=1∑Kpβ(zk=1∣x,w)logqϕx(x∣y)−k=1∑Kpβ(zk=1∣x,w)logpβ(x∣w,zk=1)logqϕx(x∣y)−k=1∑Kpβ(zk=1∣x,w)logpβ(x∣w,zk=1)(3.0)
K K K为混合高斯分布中高斯分布的个数,我们有如下假设:
则有
log p β ( x ∣ w , z k = 1 ) = log 1 2 π δ k β e − ( x − μ k β ) 2 2 ( δ k β ) 2 = log 1 2 π − log δ k β − ( x − μ k β ) 2 2 ( δ k β ) 2 (4.0) \begin{aligned} \log p_\beta(x|w,z_k=1)&=\log\frac{1}{\sqrt {2\pi}\delta^\beta_k}e^{-\frac{(x-\mu^\beta_k)^2}{2(\delta^\beta_k)^2}}\\ &=\log \frac{1}{\sqrt{2\pi}}-\log \delta_k^\beta-\frac{(x-\mu^\beta_k)^2}{2(\delta^\beta_k)^2}\tag{4.0} \end{aligned} logpβ(x∣w,zk=1)=log2πδkβ1e−2(δkβ)2(x−μkβ)2=log2π1−logδkβ−2(δkβ)2(x−μkβ)2(4.0)
log q ϕ x ( x ∣ y ) = log 1 2 π δ ϕ x e − ( x − μ ϕ x ) 2 2 ( δ ϕ x ) 2 = log 1 2 π δ ϕ x e − ( x − μ ϕ x ) 2 2 ( δ ϕ x ) 2 = log 1 2 π − log δ ϕ x − ( x − μ ϕ x ) 2 2 ( δ ϕ x ) 2 (5.0) \begin{aligned} \log q_{\phi_x}(x|y)&=\log \frac{1}{\sqrt {2\pi}\delta^{\phi_x}}e^{-\frac{(x-\mu^{\phi_x})^2}{2(\delta^{\phi_x})^2}}\\ &=\log \frac{1}{\sqrt {2\pi}\delta^{\phi_x}}e^{-\frac{(x-\mu^{\phi_x})^2}{2(\delta^{\phi_x})^2}}\\ &=\log \frac{1}{\sqrt{2\pi}}-\log \delta^{\phi_x}-\frac{(x-\mu^{\phi_x})^2}{2(\delta^{\phi_x})^2} \end{aligned}\tag{5.0} logqϕx(x∣y)=log2πδϕx1e−2(δϕx)2(x−μϕx)2=log2πδϕx1e−2(δϕx)2(x−μϕx)2=log2π1−logδϕx−2(δϕx)2(x−μϕx)2(5.0)
x x x是服从 q ϕ x ( x ∣ y ) q_{\phi_x}(x|y) qϕx(x∣y)分布的样本,可以通过VAE提出的reparameterization trick得到
K L ( q ϕ x ( w ∣ y ) ∣ ∣ p ( w ) ) KL(q_{\phi_x}(w|y)||p(w)) KL(qϕx(w∣y)∣∣p(w))有如下假设
则有
K L ( q ϕ x ( w ∣ y ) ∣ ∣ p ( w ) ) = 1 2 ∑ i = 1 n ( ( μ i ϕ w ) 2 + ( δ i ϕ w ) 2 − 1 − log ( δ i ϕ w ) 2 ) (6.0) \begin{aligned} KL(q_{\phi_x}(w|y)||p(w))=\frac{1}{2}\sum_{i=1}^n((\mu_i^{\phi_w})^2+(\delta_i^{\phi_w})^2-1-\log (\delta_i^{\phi_w})^2) \end{aligned}\tag{6.0} KL(qϕx(w∣y)∣∣p(w))=21i=1∑n((μiϕw)2+(δiϕw)2−1−log(δiϕw)2)(6.0)
同理,对 E q ( x ∣ y ) q ( w ∣ y ) [ K L ( p β ( z ∣ x , w ) ∣ ∣ p ( z ) ) ] E_{q_(x|y)q(w|y)}[KL(p_\beta(z|x,w)||p(z))] Eq(x∣y)q(w∣y)[KL(pβ(z∣x,w)∣∣p(z))]使用蒙特卡洛模拟,可得
1 M ∑ i = 1 M K L ( p β ( z ∣ x i , w i ) ∣ ∣ p ( z ) ) \frac{1}{M}\sum_{i=1}^MKL(p_\beta(z|x_i,w_i)||p(z)) M1i=1∑MKL(pβ(z∣xi,wi)∣∣p(z))
我们有如下假设
将M设置为1,则有
K L ( p β ( z ∣ x , w ) ∣ ∣ p ( z ) ) = ∑ k = 1 K p β ( z k = 1 ∣ x , w ) log p β ( z k = 1 ∣ x , w ) p ( z k = 1 ) = ∑ k = 1 K p β ( z k = 1 ∣ x , w ) [ log p β ( z k = 1 ∣ x , w ) + log K ] \begin{aligned} KL(p_\beta(z|x,w)||p(z))&=\sum_{k=1}^Kp_\beta(z_k=1|x,w)\log \frac{p_\beta(z_k=1|x,w)}{p(z_k=1)}\\ &=\sum_{k=1}^Kp_\beta(z_k=1|x,w)[\log p_\beta(z_k=1|x,w)+\log K] \end{aligned} KL(pβ(z∣x,w)∣∣p(z))=k=1∑Kpβ(zk=1∣x,w)logp(zk=1)pβ(zk=1∣x,w)=k=1∑Kpβ(zk=1∣x,w)[logpβ(zk=1∣x,w)+logK]
本节结构为博主阅读代码后所得,博主没有复现GMVAE,故仅供参考。
代码地址
图像生成的结构如下
如果您想了解更多有关深度学习、机器学习基础知识,或是java开发、大数据相关的知识,欢迎关注我们的公众号,我将在公众号上不定期更新深度学习、机器学习相关的基础知识,分享深度学习中有趣文章的阅读笔记。