VAE详解及PyTorch代码

三大有名的生成模型VAE、GAN以及Diffusion Model

其余两篇

看了网上的一些博客,大多都写到了重点,也就是后面的公式推导部分,可是大部分只有推导过程,很少有讲明白为什么要这么假设,我看的时候内心不断有个疑问:这些所有推导的第一个式子是怎么来的?为什么所有的推导都是要证明第一个式子?下面我们从生成模型的源头来理解这个问题,就茅塞顿开了

什么是生成模型?

首先要明白什么是生成模型?比如我们有一堆数据点 X X X,他的真实分布是 P g t ( X ) P_{gt}(X) Pgt(X),生成模型的目的就是去学习一个模型 M M M,将一些随机采样的噪声(通常为高斯噪声)输入到此模型中,使得此模型的输出为 X X X中的数据,即模型 M M M的分布 P P P去尽可能的接近数据的真实分布 P g t ( X ) P_{gt}(X) Pgt(X),或者说让模型 M M M能够尽可能地生成数据点 X X X中的数据

VAE详解及PyTorch代码_第1张图片

生成模型要做什么

如上所述,生成模型就是要去接近数据点(数据集) X X X的真实分布,也就是说我们要最大化所建模的概率分布

P ( X ) = ∫ P ( X ∣ z ; θ ) P ( z ) d z = ∫ P ( X ∣ z ) P ( z ) d z P(X) = \int P(X|z;\theta)P(z)dz = \int P(X|z)P(z)dz P(X)=P(Xz;θ)P(z)dz=P(Xz)P(z)dz

这里的 θ \theta θ就是模型的参数, z z z就是随机采样的噪声, P ( X ∣ z ; θ ) P(X|z;\theta) P(Xz;θ) X X X的后验概率分布,通常来说 P ( X ∣ z ; θ ) P(X|z;\theta) P(Xz;θ) P ( z ) P(z) P(z)都是高斯分布

VAE的目标函数

实际上,对于大多数的 z z z P ( X ∣ z ) P(X|z) P(Xz)的取值都接近0,因此VAE的核心就是不断地进行采样,直到可能产生数据点 X X X为止,这显然会让生成过程变得困难/复杂

在真正实现时,作者构建了一个新的函数 Q ( z ∣ X ) Q(z|X) Q(zX),给定数据集中的一个数据 X X X Q ( z ∣ X ) Q(z|X) Q(zX)可以给我们一个 z z z的分布,这个给定的分布就更加容易产生出 X X X了,在 Q ( z ∣ X ) Q(z|X) Q(zX)后的 z z z的范围比先验 P ( z ) P(z) P(z)要小得多,因此我们将 E z ∼ Q P ( X ∣ z ) E_{z\sim Q}P(X|z) EzQP(Xz) P ( X ) P(X) P(X)关联起来,如下
E z ∼ Q P ( X ∣ z ) = ∫ P ( X ∣ z ) Q ( z ) d z E_{z\sim Q}P(X|z) = \int P(X|z)Q(z)dz EzQP(Xz)=P(Xz)Q(z)dz
这个式子是不是和上面那个很相似?只是讲 P ( z ) P(z) P(z)替换为了 Q ( z ) Q(z) Q(z),这种替换能够使得我们的随机采样变得不是那么的随机(即,相对于 P ( z ) P(z) P(z) Q ( z ) Q(z) Q(z)更加具体了采样的范围,而不是无脑采样…)

在原文中,也有提到 E z ∼ Q P ( X ∣ z ) E_{z\sim Q}P(X|z) EzQP(Xz)是求解 P ( X ) P(X) P(X)的关键: The relationship between E z ∼ Q P ( X ∣ z ) E_{z\sim Q}P(X|z) EzQP(Xz) and P ( X ) P(X) P(X) is one of the cornerstones of variational Bayesian methods.

既然上面提到了 Q ( z ) Q(z) Q(z)的获取对生成 X X X至关重要,那如何去获得呢,作者用KL散度来构建 Q ( z ) Q(z) Q(z) P ( z ∣ X ) P(z|X) P(zX)之间的关系(在文中KL散度用符号 D \mathcal{D} D来表示)
D [ Q ( z ) ∣ ∣ P ( z ∣ X ) ] = E z ∼ Q [ log ⁡ Q ( z ) − log ⁡ P ( z ∣ X ) ] \mathcal{D}[Q(z)||P(z|X)] = E_{z\sim Q}[\log Q(z) - \log P(z|X)] D[Q(z)∣∣P(zX)]=EzQ[logQ(z)logP(zX)]

下面用Bayes公式来化简上述式子,将 P ( z ∣ X ) = P ( z ) P ( X ∣ z ) P ( X ) P(z|X)=\frac{P(z)P(X|z)}{P(X)} P(zX)=P(X)P(z)P(Xz)代替
D [ Q ( z ) ∣ ∣ P ( z ∣ X ) ] = E z ∼ Q [ log ⁡ Q ( z ) − log ⁡ P ( z ) P ( X ∣ z ) P ( X ) ] = E z ∼ Q [ log ⁡ Q ( z ) − log ⁡ P ( z ) − log ⁡ P ( X ∣ z ) + log ⁡ P ( X ) ] = E z ∼ Q [ log ⁡ Q ( z ) − log ⁡ P ( z ) − log ⁡ P ( X ∣ z ) ] + log ⁡ P ( X ) \begin{aligned} \mathcal{D}[Q(z)||P(z|X)] &= E_{z\sim Q}[\log Q(z) - \log \frac{P(z)P(X|z)}{P(X)}] \\ &= E_{z\sim Q}[\log Q(z) - \log P(z) - \log P(X|z) + \log P(X)] \\ &= E_{z\sim Q}[\log Q(z) - \log P(z) - \log P(X|z)] + \log P(X) \end{aligned} D[Q(z)∣∣P(zX)]=EzQ[logQ(z)logP(X)P(z)P(Xz)]=EzQ[logQ(z)logP(z)logP(Xz)+logP(X)]=EzQ[logQ(z)logP(z)logP(Xz)]+logP(X)

因为 log ⁡ P ( X ) \log P(X) logP(X)与变量 z z z无关,因此可以从期望中拿出,将上式继续整理得
log ⁡ P ( X ) − D [ Q ( z ) ∣ ∣ P ( z ∣ X ) ] = E z ∼ Q [ log ⁡ P ( z ) + log ⁡ P ( X ∣ z ) − log ⁡ Q ( z ) ] = E z ∼ Q [ log ⁡ P ( X ∣ z ) ] + E z ∼ Q [ log ⁡ P ( z ) − log ⁡ Q ( z ) ] = E z ∼ Q [ log ⁡ P ( X ∣ z ) ] − D [ Q ( z ) ∣ ∣ P ( z ) ] \begin{aligned} \log P(X) - \mathcal{D}[Q(z)||P(z|X)] &= E_{z\sim Q}[\log P(z) + \log P(X|z) - \log Q(z)] \\ &= E_{z\sim Q}[\log P(X|z)] + E_{z\sim Q}[\log P(z) - \log Q(z)] \\ &= E_{z\sim Q}[\log P(X|z)] - \mathcal{D}[Q(z) || P(z)] \end{aligned} logP(X)D[Q(z)∣∣P(zX)]=EzQ[logP(z)+logP(Xz)logQ(z)]=EzQ[logP(Xz)]+EzQ[logP(z)logQ(z)]=EzQ[logP(Xz)]D[Q(z)∣∣P(z)]

注意到这里的 Q ( z ) Q(z) Q(z)可以是任意的概率分布,但是为了让其有意义,使从其采样出的噪声更容易建模出 P ( X ) P(X) P(X),并最小化 D [ Q ( z ) ∣ ∣ P ( z ∣ X ) ] \mathcal{D}[Q(z)||P(z|X)] D[Q(z)∣∣P(zX)],我们让 Q ( z ) Q(z) Q(z)去依赖于 X X X,即得到下式
log ⁡ P ( X ) − D [ Q ( z ∣ X ) ∣ ∣ P ( z ∣ X ) ] = E z ∼ Q [ log ⁡ P ( X ∣ z ) ] − D [ Q ( z ∣ X ) ∣ ∣ P ( z ) ] \log P(X) - \mathcal{D}[Q(z|X)||P(z|X)] = E_{z\sim Q}[\log P(X|z)] - \mathcal{D}[Q(z|X) || P(z)] logP(X)D[Q(zX)∣∣P(zX)]=EzQ[logP(Xz)]D[Q(zX)∣∣P(z)]

上式就是VAE最核心的公式了

  • 等号左边: log ⁡ P ( X ) \log P(X) logP(X)是我们要最大化的项, P ( z ∣ X ) P(z|X) P(zX)描述了可能生成 X X X z z z的取值,当模型的建模能力足够时, D [ Q ( z ∣ X ) ∣ ∣ P ( z ∣ X ) ] \mathcal{D}[Q(z|X)||P(z|X)] D[Q(zX)∣∣P(zX)]可以看做0
  • 等号右边:可以通过梯度下降来优化,他更像是一个自编码器(AE),因为 Q ( z ∣ X ) Q(z|X) Q(zX) X X X编码到 z z z空间, P ( X ∣ z ) P(X|z) P(Xz) z z z解码重建出 X X X

未完…

你可能感兴趣的:(PyTorch,pytorch,python)