参考资料:
[1]【54、Probabilistic Diffusion Model概率扩散模型理论与完整PyTorch代码详细解读】 https://www.bilibili.com/video/BV1b541197HX/?share_source=copy_web&vd_source=7771b17ae75bc5131361e81a50a0c871
[2] https://t.bilibili.com/700526762586538024?spm_id_from=333.999.0.0
以下内容为对上述资料的补充理解,理解不对的地方,请多指教。
以下序号与资料中的章节序号一致。
扩散模型本质为生成模型,所以最本质的目标是最大化对数据分布真值的预测概率。
这里可以假设成一个分类问题,不同的类别表示不同的数据分布,其中包括与数据分布真值相近的和不相近的。模型会预测不同数据分布的概率。我们的目标是,使网络对数据分布真值对应的类别的预测概率最高。
用公式表示: m a x p θ ( x 0 ) max~p_{\theta}(x_0) max pθ(x0),其中, p θ ( x 0 ) p_{\theta}(x_0) pθ(x0)为模型对数据分布真值预测的概率分布(注意模型不只是网络,在扩散模型里,网络是模型的一部分,模型还包括对网络输出结果的后处理,因此网络输出值可能多种多样)。
但是 p θ ( x 0 ) p_{\theta}(x_0) pθ(x0)范围是 0 − 1 0-1 0−1,直接最大化不好计算,因此一般转化为最小化对数似然函数: − l o g p θ ( x 0 ) -log~p_{\theta}(x_0) −log pθ(x0)。直接最小化 − l o g p θ ( x 0 ) -log~p_{\theta}(x_0) −log pθ(x0)也不好求,所以扩散模型转而最小化 − l o g p θ ( x 0 ) -log~p_{\theta}(x_0) −log pθ(x0)的上界
,这个上界
就是 L V L B L_{VLB} LVLB(需要乘 q ( x 0 ) q(x_0) q(x0))。
下面的目标就是最小化 L V L B L_{VLB} LVLB。
L V L B L_{VLB} LVLB最终转化为 L V L B = E q [ L T + L t − 1 ] L_{VLB}=E_q[L_T+L_{t-1}] LVLB=Eq[LT+Lt−1]( L 0 L_0 L0与 L t − 1 L_{t-1} Lt−1合并到一起了),其中, L T L_T LT和 L t − 1 L_{t-1} Lt−1都是两个高斯分布的KL散度,结果只与两个高斯分布的均值和方差有关。 L T L_T LT中两个分布的均值和方差都是已知(在 x 0 x_0 x0分布已知的情况下已知)且不可优化的,因此直接去除。下面计算 L t − 1 L_{t-1} Lt−1,如下式(方差是设定的固定值,所以省略了):
其中, μ ~ ( x t , x 0 ) \tilde\mu(x_t, x_0) μ~(xt,x0)是 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t, x_0) q(xt−1∣xt,x0)高斯分布的均值, μ θ ( x t , t ) \mu_{\theta}(x_t,t) μθ(xt,t)是 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt−1∣xt)高斯分布的均值。
p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt−1∣xt)是模型的预测分布,也可以写成 p θ ( x t − 1 ∣ x t , t ) p_{\theta}(x_{t-1}|x_t, t) pθ(xt−1∣xt,t)。
对上式展开,其中 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t, x_0) q(xt−1∣xt,x0)的均值 μ ~ ( x t , x 0 ) \tilde\mu(x_t, x_0) μ~(xt,x0)已经在前面计算出来了,直接代入:
上式中 ϵ \epsilon ϵ与上文的 z z z一样,都是加的噪声。下面的问题是,我们要最小化 L t − 1 − C L_{t-1}-C Lt−1−C,网络在模型中扮演什么角色?可选择的是:
扩散模型的作者选择用网络来预测 ϵ \epsilon ϵ,这样, μ θ ( x t , t ) \mu_{\theta}(x_t,t) μθ(xt,t)的计算公式如下:
再简化 L t − 1 − C L_{t-1}-C Lt−1−C,如下:
到这里,网络的损失就确定了,即最小化预测的噪声
与实际添加的噪声
的差,网络输入是时刻t
和时刻t对应的xt
。
有了网络输出的噪声后,就可以通过 p θ ( x t − 1 ∣ x t , t ) p_{\theta}(x_{t-1}|x_t,t) pθ(xt−1∣xt,t)分布的均值 μ θ ( x t , t ) \mu_{\theta}(x_t,t) μθ(xt,t)和方差(方差是预定义的 β \beta β)来采样出 x t − 1 x_{t-1} xt−1,训练过程和反扩散过程的伪代码如下:
反扩散过程用到了重参数化采样,上图中的 σ t \sigma_t σt就是标准差 β t \sqrt{\beta_t} βt。
是为了让扩散后的数据分布接近正态分布而特意设计的。