Improved Denoising Diffusion Probabilistic Models

文章仅供学习,内容源自视频58、Improved Diffusion的PyTorch代码逐行深入讲解_哔哩哔哩_bilibili

部分内容引用自

IDDPM论文阅读_deep unsupervised learning using nonequilibrium th_zzfive的博客-CSDN博客

IDDPM与DDPM区别:

        在Forward Process中DDPM不含参,IDDPM含参;

        Xt-1的方差,在DDPM中使用β,IDDPM预测β和真实方差的线性加权的权重;

        DDPM使用MSELoss,IDDPM使用hybrid loss,也就是将MSELoss与KL loss相加;

        训练时,DDPM的每个t是均匀采样,IDDPM使用非均匀采样。基于每一步t重要性。

改进实现:

使用余弦方案生成β。

Improved Denoising Diffusion Probabilistic Models_第1张图片

得到βt = 1 - f(t)/f(t-1) 

    def betas(self,n_steps,max_beta=0.999):
        # 余弦加噪方案生成Beta
        betas = []
        alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
        for i in range(n_steps):
            t1 = i / n_steps
            t2 = (i + 1) / n_steps
            betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
        return torch.Tensor(betas)

 预测β和真实方差的线性加权的权重

在这里插入图片描述

在这里插入图片描述

Improved Denoising Diffusion Probabilistic Models_第2张图片

 上段部分引用于,文章仅用于学习IDDPM论文阅读_deep unsupervised learning using nonequilibrium th_zzfive的博客-CSDN博客

        简单总结一下,就是在DDPM中作者发现在扩散步数的增大,使用β作为方差和使用真实方差的区别不大。也就是说方差对样本质量并不重要,重要的是均值。

        但是对于前几个扩散步数就不一样了,因此还是要学习更好的方差以防止不稳定性。最后发现使用预测真实方差和β的线性加权权重是最好选择。为了能使权重得到学习,还需要重新设计损失函数。

代码实现:

eps_theta,module_var_values = torch.split(module_output,C,1) # 得到预测的分布和预测的系数
frac = ((module_var_values + 1) / 2)
model_log_variance = frac * torch.log2(beta) + (1 - frac) * torch.log2(beta_bar) # 线性相加
var = torch.exp(model_log_variance)

混合损失

        在上一部分也提到了混合损失,在源代码中1-T时刻用KL散度,0-1时刻用对数似然,但我这里全部都用了KL散度,因为对数似然那部分代码实在没看懂,概率论已经全部遗忘。。。有了解的朋友欢迎补充。

    def loss(self,x0,xt,t,noise):
        true_dis = self.q_posterior_sample(x0,xt,t)
        # 在源码中经过了两次模型预测。第一次得到p分布,第二次先将模型预测噪音部分的权重冻结,
        # 再重新传入得到预测的线性权重,这里改为一次得到,因此p_sample既返回预测的分布又返回预测的噪声
        pred_dis,eps_theta = self.p_sample(xt, t)
        # KL散度,这里没有像源代码一样使用L[0]的负对数似然
        kl_loss = dist.kl_divergence(true_dis, pred_dis)
        kl_loss = torch.mean(kl_loss) / torch.log(torch.tensor(2.0))
        batch_size = x0.shape[0]
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
        if not noise:
            noise = torch.randn_like(x0)
        # kl_loss即Lvlb,它包含模型,也包含预测的线性权重
        return F.mse_loss(noise,eps_theta)+1e-3*kl_loss

最后关于t根据之前时刻的重要性采样我还没有弄懂,如果明白了我会补上。本文过于粗糙望各位见谅。

你可能感兴趣的:(python,开发语言)