GMVAE(GAUSSIAN MIXTURE VARIATIONAL AUTOENCODERS)高斯混合变分自编码器公式推导

GMM

高斯混合模型:
p ( x ) = ∑ z p ( c ) p ( x ∣ c ) = ∑ k = 0 K − 1 π k N ( x ∣ μ k , σ k ) \begin{aligned} p(x) = &\sum_{z}p(c)p(x|c) \\ = &\sum_{k=0}^{K-1} \pi_{k}N(x|\mu_{k}, \sigma_{k}) \end{aligned} p(x)==zp(c)p(xc)k=0K1πkN(xμk,σk)
其中 ∑ k π k = 1 , c 是 一 个 K 维 的 o n e − h o t 向 量 , p ( c k = 1 ) = π k \sum_{k} \pi_{k}=1, c是一个K维的one-hot向量,p(c_{k}=1)=\pi_{k} kπk=1cKonehotp(ck=1)=πk,这里的c其实服从类别分布,也就是 c ∼ C a t ( π ) c\sim Cat(\pi) cCat(π)

VAE回顾

变分自编码器的ELBO为:
E L B O = E q ϕ ( z ∣ x ) [ l o g p θ ( x ∣ z ) ] − K L [ q ϕ ( z ∣ x ) ∣ ∣ p θ ( z ) ] (*) \tag{*} ELBO = E_{q_{\phi}(z|x)}[logp_{\theta}(x|z)]-KL[q_{\phi}(z|x)||p_{\theta}(z)] ELBO=Eqϕ(zx)[logpθ(xz)]KL[qϕ(zx)pθ(z)](*)
不熟悉的可以看变分自编码器回顾。在实际使用时,常常假设隐变量 z z z的变分后验 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(zx)和先验 p θ ( z ) p_{\theta}(z) pθ(z)是高斯分布。但有的时候使用单个高斯分布可能使隐变量 z z z不能充分的学习到 x x x的特征,比如MINIST数据集有0~9这10个数字,直觉上使用10个高斯分布来替代单个高斯分布更为合理。所以就有学者将GMM和VAE结合起来,提出了GMVAE模型。

GMVAE

GMVAE的概率图模型如下图所示:
GMVAE(GAUSSIAN MIXTURE VARIATIONAL AUTOENCODERS)高斯混合变分自编码器公式推导_第1张图片
其中, x x x是观测数据, c 和 z c和z cz分别是离散和连续的潜变量, c c c可以理解为指示模型选择哪个高斯的变量(比如第i个高斯), z z z可以理解为选择的第i个高斯产生的变量。实线是生成过程(decoer 过程),虚线是训练过程(encoder 过程),从生成过程中可以看出, x x x依赖于 z z z z z z依赖于 c c c

以minist数据集解释一下生成过程,假设我们已经训练好了GMVAE模型(意味着得到了10个高斯分布,第 i i i个高斯分布充分学习到了数字 i i i的特征)。

  1. 选择一个高斯分布, c ∼ C a t ( π ) c \sim Cat(\pi) cCat(π)
  2. 从1选择的高斯分布采样一个 z z z z ∼ N ( μ c , σ c 2 I ) z\sim N(\mu_c,\sigma_c^{2}I) zN(μcσc2I)
  3. 生成样本 x x x
    3.1 计算 x x x对应的均值 μ x \mu_{x} μx和方差 σ x 2 \sigma_{x}^{2} σx2
    [ μ x , l o g σ x 2 ] = f ( z ; θ ) [\mu_{x}, log\sigma_{x}^{2}] = f(z;\theta) [μx,logσx2]=f(z;θ)
    3.2 生成样本 x x x x ∼ N ( μ x , σ x 2 I ) x\sim N(\mu_{x}, \sigma_{x}^{2}I) xN(μx,σx2I)

其中, f f f是参数为 θ \theta θ的神经网络,输入时 z z z, N ( μ x , σ x 2 I ) N(\mu_{x}, \sigma_{x}^{2}I) N(μx,σx2I)时多元高斯分布。

熟悉变分自编码器的同学都知道, E L B O ELBO ELBO的推导其实是从观测数据 x x x的对数似然函数推导来的。
所以上图表示的 x x x的对数似然函数是:
l o g p ( x ) = ∫ z ∑ c p ( z , c ) p ( x ∣ z ) d z (1) \tag{1} logp(x) = \int_{z} \sum_{c}p(z,c)p(x|z)dz logp(x)=zcp(z,c)p(xz)dz(1)
其中, p ( z , c ) p(z,c) p(z,c)表示的就是高斯混合,同时也是先验分布。 p ( z , c ) p(z,c) p(z,c)也可以表示为 p ( z , c ) = p ( c ) p ( z ∣ c ) (2) \tag{2}p(z,c)=p(c)p(z|c) p(z,c)=p(c)p(zc)(2)
直觉上, p ( c ) p(c) p(c)就是类别分布, p ( z ∣ c ) p(z|c) p(zc)是一个多元高斯分布。

p θ ( z , c , x ) = p θ ( c ) p θ ( z ∣ c ) p θ ( x ∣ z ) p_{\theta}(z,c,x)=p_{\theta}(c)p_{\theta}(z|c)p_{\theta}(x|z) pθ(z,c,x)=pθ(c)pθ(zc)pθ(xz)表示 z , c , x z,c,x z,c,x的联合概率分布,参数为 θ \theta θ,则(1)式就是对 z , c z,c z,c进行了边缘化。

l o g p θ ( x ) = l o g ∫ z ∑ c p θ ( z , c , x ) d z = l o g ∫ z ∑ c p θ ( z , c , x ) q ϕ ( z , c ∣ x ) q ϕ ( z , c ∣ x ) d z = l o g ∫ z ∑ c p θ ( c , z ) p θ ( x ∣ z ) q ϕ ( z , c ∣ x ) q ϕ ( z , c ∣ x ) d z = l o g E q ϕ ( c , z ∣ x ) [ p θ ( c , z ) p θ ( x ∣ z ) q ϕ ( z , c ∣ x ) ] ≥ E q ϕ ( c , z ∣ x ) [ l o g p θ ( c , z ) p θ ( x ∣ z ) q ϕ ( z , c ∣ x ) ] \begin{aligned} logp_{\theta}(x)= & log\int_{z}\sum_{c}p_{\theta}(z,c,x)dz \\ = & log \int_{z}\sum_{c} {p_{\theta}(z,c,x)q_{\phi}(z,c|x) \over q_{\phi}(z,c|x)}dz \\ = & log \int_{z}\sum_{c} {p_{\theta}(c,z)p_{\theta}(x|z)q_{\phi}(z,c|x) \over q_{\phi}(z,c|x)}dz \\ = & logE_{q_{\phi}(c,z|x)}[ {p_{\theta}(c,z)p_{\theta}(x|z) \over q_{\phi}(z,c|x)}] \\ \ge & E_{q_{\phi}(c,z|x)}[log {p_{\theta}(c,z)p_{\theta}(x|z) \over q_{\phi}(z,c|x)}] \end{aligned} logpθ(x)====logzcpθ(z,c,x)dzlogzcqϕ(z,cx)pθ(z,c,x)qϕ(z,cx)dzlogzcqϕ(z,cx)pθ(c,z)pθ(xz)qϕ(z,cx)dzlogEqϕ(c,zx)[qϕ(z,cx)pθ(c,z)pθ(xz)]Eqϕ(c,zx)[logqϕ(z,cx)pθ(c,z)pθ(xz)]

E L B O = E q ϕ ( c , z ∣ x ) [ l o g p θ ( c , z ) p θ ( x ∣ z ) q ϕ ( z , c ∣ x ) ] = E q ϕ ( c , z ∣ x ) [ l o g p θ ( x ∣ z ) ] − K L [ q ϕ ( c , z ∣ x ) ∣ ∣ p θ ( c , z ) ] (3) \tag{3} \begin{aligned} ELBO = & E_{q_{\phi}(c,z|x)}[log {p_{\theta}(c,z)p_{\theta}(x|z) \over q_{\phi}(z,c|x)}] \\ =& E_{q_{\phi}(c,z|x)}[logp_{\theta}(x|z)] -KL[q_{\phi}(c,z|x)||p_{\theta}(c,z)] \end{aligned} ELBO==Eqϕ(c,zx)[logqϕ(z,cx)pθ(c,z)pθ(xz)]Eqϕ(c,zx)[logpθ(xz)]KL[qϕ(c,zx)pθ(c,z)](3)
(3)式中,第一项是重构误差,可第二项是变分后验 q ϕ ( c , z ∣ x ) q_{\phi}(c,z|x) qϕ(c,zx)与先验 p θ ( c , z ) p_{\theta}(c,z) pθ(c,z)的KL损失,也就是两个高斯混合的KL距离。在训练过程中,使用变分后验 q ϕ ( z , c ∣ x ) q_{\phi}(z,c|x) qϕ(z,cx)去近似真实的后验 p ( z , c ∣ x ) p_(z,c|x) p(z,cx),通过平均场理论,可以将 q ϕ ( z , c ∣ x ) q_{\phi}(z,c|x) qϕ(z,cx)分解为 q ϕ ( z , c ∣ x ) = q ϕ ( z ∣ x ) q ϕ ( c ∣ x ) (4) \tag{4}q_{\phi}(z,c|x)=q_{\phi}(z|x)q_{\phi}(c|x) qϕ(z,cx)=qϕ(zx)qϕ(cx)(4)

重构误差项可以通过SGVB估计算出,KL项计算比较复杂,下面计算KL项。将(2)式和(4)式带入KL项得:
− K L [ q ϕ ( c , z ∣ x ) ∣ ∣ p θ ( c , z ) ] = E q ϕ ( c , z ∣ x ) [ l o g p θ ( c , z ) q ϕ ( c , z ∣ x ) ] = E q ϕ ( c , z ∣ x ) [ l o g p θ ( c ) p θ ( z ∣ c ) q ϕ ( c ∣ x ) q ϕ ( z ∣ x ) ] = E q ϕ ( c , z ∣ x ) [ l o g p θ ( c ) + l o g p θ ( z ∣ c ) − l o g q ϕ ( c ∣ x ) − l o g q ϕ ( z ∣ x ) ] (5) \begin{aligned} \tag{5} -KL[q_{\phi}(c,z|x)||p_{\theta}(c,z)] = & E_{q_{\phi}(c,z|x)}[log{p_{\theta}(c,z) \over q_{\phi}(c,z|x) }] \\ = &E_{q_{\phi}(c,z|x)}[log{p_{\theta}(c)p_{\theta}(z|c) \over q_{\phi}(c|x)q_{\phi}(z|x)}] \\ = &E_{q_{\phi}(c,z|x)}[logp_{\theta}(c)+logp_{\theta}(z|c)-logq_{\phi}(c|x)-logq_{\phi}(z|x)] \end{aligned} KL[qϕ(c,zx)pθ(c,z)]===Eqϕ(c,zx)[logqϕ(c,zx)pθ(c,z)]Eqϕ(c,zx)[logqϕ(cx)qϕ(zx)pθ(c)pθ(zc)]Eqϕ(c,zx)[logpθ(c)+logpθ(zc)logqϕ(cx)logqϕ(zx)](5)
(5)式中,c表示选择第几个高斯,是一个编号,而z和x是向量, p θ ( c ) 是 c 的 先 验 分 布 , 一 般 初 始 化 为 1 K , q ϕ ( c ∣ x ) 是 c 的 后 验 分 布 , p θ ( z ∣ c ) 是 z 的 先 验 分 布 , 是 一 个 高 斯 分 布 , q ϕ ( z ∣ x ) 是 z 的 后 验 分 布 , 也 是 一 个 高 斯 分 布 p_{\theta}(c)是c的先验分布,一般初始化为{1 \over K},q_{\phi}(c|x)是c的后验分布,p_{\theta}(z|c)是z的先验分布,是一个高斯分布,q_{\phi}(z|x)是z的后验分布,也是一个高斯分布 pθ(c)cK1qϕ(cx)cpθ(zc)zqϕ(zx)z
在论文Variational Deep Embedding: An Unsupervised and Generative Approach to Clustering中,将 q ϕ ( c ∣ x ) q_{\phi}(c|x) qϕ(cx)做了如下变化。
q ϕ ( c ∣ x ) = E q ( z ∣ x ) [ p ( c ∣ z ) ] p ( c ∣ z ) = p ( c ) p ( z ∣ c ) ∑ c ′ = 1 K p ( c ′ ) p ( z ∣ c ′ ) \begin{aligned} q_{\phi}(c|x)= & E_{q(z|x)}[p(c|z)] \\ p(c|z) = & {p(c)p(z|c) \over \sum_{c'=1}^{K}p(c')p(z|c')} \end{aligned} qϕ(cx)=p(cz)=Eq(zx)[p(cz)]c=1Kp(c)p(zc)p(c)p(zc)
下面分别求(5)中的每一项:
GMVAE(GAUSSIAN MIXTURE VARIATIONAL AUTOENCODERS)高斯混合变分自编码器公式推导_第2张图片
GMVAE(GAUSSIAN MIXTURE VARIATIONAL AUTOENCODERS)高斯混合变分自编码器公式推导_第3张图片
GMVAE(GAUSSIAN MIXTURE VARIATIONAL AUTOENCODERS)高斯混合变分自编码器公式推导_第4张图片
上面的积分号里是两个高斯分布的交叉熵,所以最后是下式:
在这里插入图片描述
GMVAE(GAUSSIAN MIXTURE VARIATIONAL AUTOENCODERS)高斯混合变分自编码器公式推导_第5张图片
上面的积分号里也是两个高斯分布的交叉熵。
在这里插入图片描述
GMVAE(GAUSSIAN MIXTURE VARIATIONAL AUTOENCODERS)高斯混合变分自编码器公式推导_第6张图片
GMVAE(GAUSSIAN MIXTURE VARIATIONAL AUTOENCODERS)高斯混合变分自编码器公式推导_第7张图片
这里也用到了SGVB的方法。

至此,(3)式的ELBO的每一项都可以求解。

你可能感兴趣的:(机器学习,NLP,变分贝叶斯系列)