扩散模型原理记录

一 扩散模型原理记录

参考资料:

[1]【54、Probabilistic Diffusion Model概率扩散模型理论与完整PyTorch代码详细解读】 https://www.bilibili.com/video/BV1b541197HX/?share_source=copy_web&vd_source=7771b17ae75bc5131361e81a50a0c871

[2] https://t.bilibili.com/700526762586538024?spm_id_from=333.999.0.0

以下内容为对上述资料的补充理解,理解不对的地方,请多指教。

以下序号与资料中的章节序号一致。

七、目标数据分布的似然函数

扩散模型本质为生成模型,所以最本质的目标是最大化对数据分布真值的预测概率

这里可以假设成一个分类问题,不同的类别表示不同的数据分布,其中包括与数据分布真值相近的和不相近的。模型会预测不同数据分布的概率。我们的目标是,使网络对数据分布真值对应的类别的预测概率最高。

用公式表示: m a x   p θ ( x 0 ) max~p_{\theta}(x_0) max pθ(x0),其中, p θ ( x 0 ) p_{\theta}(x_0) pθ(x0)为模型对数据分布真值预测的概率分布(注意模型不只是网络,在扩散模型里,网络是模型的一部分,模型还包括对网络输出结果的后处理,因此网络输出值可能多种多样)。

但是 p θ ( x 0 ) p_{\theta}(x_0) pθ(x0)范围是 0 − 1 0-1 01,直接最大化不好计算,因此一般转化为最小化对数似然函数 − l o g   p θ ( x 0 ) -log~p_{\theta}(x_0) log pθ(x0)。直接最小化 − l o g   p θ ( x 0 ) -log~p_{\theta}(x_0) log pθ(x0)也不好求,所以扩散模型转而最小化 − l o g   p θ ( x 0 ) -log~p_{\theta}(x_0) log pθ(x0)上界,这个上界就是 L V L B L_{VLB} LVLB(需要乘 q ( x 0 ) q(x_0) q(x0))。

下面的目标就是最小化 L V L B L_{VLB} LVLB

L V L B L_{VLB} LVLB最终转化为 L V L B = E q [ L T + L t − 1 ] L_{VLB}=E_q[L_T+L_{t-1}] LVLB=Eq[LT+Lt1] L 0 L_0 L0 L t − 1 L_{t-1} Lt1合并到一起了),其中, L T L_T LT L t − 1 L_{t-1} Lt1都是两个高斯分布的KL散度,结果只与两个高斯分布的均值和方差有关。 L T L_T LT中两个分布的均值和方差都是已知(在 x 0 x_0 x0分布已知的情况下已知)且不可优化的,因此直接去除。下面计算 L t − 1 L_{t-1} Lt1,如下式(方差是设定的固定值,所以省略了):

扩散模型原理记录_第1张图片

其中, μ ~ ( x t , x 0 ) \tilde\mu(x_t, x_0) μ~(xt,x0) q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t, x_0) q(xt1xt,x0)高斯分布的均值, μ θ ( x t , t ) \mu_{\theta}(x_t,t) μθ(xt,t) p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt)高斯分布的均值。

p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt)是模型的预测分布,也可以写成 p θ ( x t − 1 ∣ x t , t ) p_{\theta}(x_{t-1}|x_t, t) pθ(xt1xt,t)

对上式展开,其中 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t, x_0) q(xt1xt,x0)的均值 μ ~ ( x t , x 0 ) \tilde\mu(x_t, x_0) μ~(xt,x0)已经在前面计算出来了,直接代入:

扩散模型原理记录_第2张图片

上式中 ϵ \epsilon ϵ与上文的 z z z一样,都是加的噪声。下面的问题是,我们要最小化 L t − 1 − C L_{t-1}-C Lt1C,网络在模型中扮演什么角色?可选择的是:

  • 预测 μ θ ( x t , t ) \mu_{\theta}(x_t,t) μθ(xt,t),使其逼近 μ ~ ( x t , x 0 ) \tilde\mu(x_t, x_0) μ~(xt,x0),即损失是他俩的差;
  • 预测 x 0 ′ x_0' x0,使其直接逼近 x 0 x_0 x0,损失是他俩的差;
  • 预测 ϵ \epsilon ϵ,这样 p θ ( x t − 1 ∣ x t , t ) p_{\theta}(x_{t-1}|x_t,t) pθ(xt1xt,t)分布的均值 μ θ ( x t , t ) \mu_{\theta}(x_t,t) μθ(xt,t)就与 q ( x t − 1 ∣ x t , x 0 ′ ) q(x_{t-1}|x_t,x_0') q(xt1xt,x0)的均值公式一样,即下式。这样就可以逼近 μ ~ ( x t , x 0 ) \tilde\mu(x_t, x_0) μ~(xt,x0),即损失是他俩的差(可以简化计算);
扩散模型原理记录_第3张图片

扩散模型的作者选择用网络来预测 ϵ \epsilon ϵ,这样, μ θ ( x t , t ) \mu_{\theta}(x_t,t) μθ(xt,t)的计算公式如下:

再简化 L t − 1 − C L_{t-1}-C Lt1C,如下:

扩散模型原理记录_第4张图片

到这里,网络的损失就确定了,即最小化预测的噪声实际添加的噪声的差,网络输入是时刻t时刻t对应的xt

有了网络输出的噪声后,就可以通过 p θ ( x t − 1 ∣ x t , t ) p_{\theta}(x_{t-1}|x_t,t) pθ(xt1xt,t)分布的均值 μ θ ( x t , t ) \mu_{\theta}(x_t,t) μθ(xt,t)和方差(方差是预定义的 β \beta β)来采样出 x t − 1 x_{t-1} xt1,训练过程和反扩散过程的伪代码如下:

扩散模型原理记录_第5张图片

反扩散过程用到了重参数化采样,上图中的 σ t \sigma_t σt就是标准差 β t \sqrt{\beta_t} βt

二 问题记录

2.1 正向扩散过程的高斯均值和方差为什么这么设计?

是为了让扩散后的数据分布接近正态分布而特意设计的。

你可能感兴趣的:(生成模型,扩散模型,生成模型)