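This post does not explain the theory of diffusion models in detail. It takes a "know what it does" approach and goes straight to the code.
Repository: DenoisingDiffusionProbabilityModel-ddpm-
The main code lives in two packages, Diffusion and DiffusionFreeGuidence. The former is the most basic diffusion model implementation. The latter adds the most widely used and most effective trick, classifier-free guidance ("Free Guidence" in the folder name).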
├─Diffusion
│ │ Diffusion.py
│ │ Model.py
│ │ Train.py
│ │ __init__.py
│
├─DiffusionFreeGuidence
│ DiffusionCondition.py
│ ModelCondition.py
│ TrainCondition.py
│ __init__.py
Diffusion Package
First, import the required packages and define the helper function extract.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# ``extract`` picks the entries of the 1-D sequence ``v`` at indices ``t``
# and reshapes them so they broadcast against data of shape ``x_shape``
def extract(v, t, x_shape):
    """
    Extract some coefficients at specified timesteps, then reshape to
    [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
    """
    device = t.device
    # For how ``torch.gather`` works, see the first comment under https://zhuanlan.zhihu.com/p/352877584
    # In every call in this project ``v`` is 1-D, so this is plain indexing, equivalent to v[t];
    # ``t`` has shape [batch_size]
    out = torch.gather(v, index=t, dim=0).float().to(device)
    # Reshape the gathered values to [batch_size, 1, 1, ...], with as many dims as x_shape
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
```
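The short sketch below checks what `extract` returns. A 1-D coefficient table is indexed by a batch of timesteps, and the result is reshaped so it broadcasts over a 4-D image batch. The table values and tensor shapes are made up for illustration. The function body is copied from the code above.

```python
import torch


def extract(v, t, x_shape):
    # same as extract() above: gather v[t], then reshape to [batch, 1, 1, ...]
    out = torch.gather(v, index=t, dim=0).float().to(t.device)
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))


# illustrative coefficient table for T = 5 steps (arbitrary values)
v = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5], dtype=torch.float64)
t = torch.tensor([0, 3, 4])          # one timestep per sample, batch of 3
x = torch.ones(3, 2, 4, 4)           # dummy batch: 3 two-channel 4x4 images

coef = extract(v, t, x.shape)
print(coef.shape)                    # torch.Size([3, 1, 1, 1])
print(coef.flatten())                # the same values as v[t]
scaled = coef * x                    # broadcasts over channel and spatial dims
```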
The necessary formulas, with a brief derivation:
$$q(x_t \mid x_{t-1}) \sim \mathcal{N}\left(\sqrt{1-\beta_t}\,x_{t-1},\ \beta_t\right)$$
$$x_t = \sqrt{1-\beta_t}\,x_{t-1} + \sqrt{\beta_t}\cdot z_t \tag{1}$$
In words: the conditional distribution $q(x_t \mid x_{t-1})$ is a normal (Gaussian) distribution with mean $\sqrt{1-\beta_t}\,x_{t-1}$ and variance $\beta_t$, where $\beta_t$ is a constant. Written as a recurrence, this is Eq. (1), where $z_t$ is standard normal (mean 0, standard deviation 1).
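To make Eq. (1) concrete, here is a small pure-Python sketch (no PyTorch) that applies one noising step to many samples of a unit-variance signal. The mean is scaled by $\sqrt{1-\beta_t}$ and the added noise has variance $\beta_t$. The total variance therefore stays at $(1-\beta_t)\cdot 1 + \beta_t = 1$, which is why this kind of forward process is often called variance-preserving. The value of `beta_t` and the sample count are illustrative assumptions.

```python
import math
import random


def forward_step(x_prev, beta_t, rng):
    # Eq. (1): x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * z_t, with z_t ~ N(0, 1)
    z = rng.gauss(0.0, 1.0)
    return math.sqrt(1.0 - beta_t) * x_prev + math.sqrt(beta_t) * z


rng = random.Random(0)
beta_t = 0.3                                   # illustrative, far larger than a real beta_t
xs_prev = [rng.gauss(0.0, 1.0) for _ in range(200_000)]   # unit-variance "data"
xs = [forward_step(x, beta_t, rng) for x in xs_prev]

mean = sum(xs) / len(xs)
var = sum((x - mean) ** 2 for x in xs) / len(xs)
print(round(mean, 2), round(var, 2))           # roughly 0.0 and 1.0
```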
Let $\alpha_t = 1 - \beta_t$. Then
$$q(x_t \mid x_{t-1}) \sim \mathcal{N}\left(\sqrt{\alpha_t}\,x_{t-1},\ 1-\alpha_t\right)$$
$$x_t = \sqrt{\alpha_t}\,x_{t-1} + \sqrt{1-\alpha_t}\cdot z_t \tag{2}$$
In words: writing $\alpha_t$ for $1-\beta_t$, the conditional distribution $q(x_t \mid x_{t-1})$ is normal with mean $\sqrt{\alpha_t}\,x_{t-1}$ and variance $1-\alpha_t$.
$$q(x_t \mid x_0) \sim \mathcal{N}\left(\sqrt{\bar\alpha_t}\,x_0,\ 1-\bar\alpha_t\right)$$
where $\bar\alpha_t = \prod_{i=1}^{t}\alpha_i$.
$$x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\cdot z_t \tag{3}$$
In words: unrolling the recurrence (2) lets us compute $x_t$ directly from $x_0$. Here $\bar\alpha_t$ is the running product of the $\alpha_i$. This works because a sum of independent Gaussian noise terms is again Gaussian. All the noise injected in steps $1,\dots,t$ therefore merges into a single standard normal $z_t$ scaled by $\sqrt{1-\bar\alpha_t}$.
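The unrolling from Eq. (2) to Eq. (3) can be checked without any sampling. Track only the coefficient of $x_0$ and the accumulated noise variance through the recurrence, and they reproduce $\sqrt{\bar\alpha_t}$ and $1-\bar\alpha_t$. The sketch below assumes a linear $\beta$ schedule from 1e-4 to 0.02 over 1000 steps, the values the code below uses. It needs only the Python standard library.

```python
import math

T, beta_1, beta_T = 1000, 1e-4, 0.02
# linear schedule, same spacing as torch.linspace(beta_1, beta_T, T)
betas = [beta_1 + (beta_T - beta_1) * i / (T - 1) for i in range(T)]
alphas = [1.0 - b for b in betas]

coef_x0, noise_var = 1.0, 0.0          # start from x_0 = 1 * x_0 + (no noise)
alpha_bar = 1.0
max_err = 0.0
for a in alphas:
    # one step of Eq. (2): scale the previous state by sqrt(a), add noise of variance 1 - a
    coef_x0 *= math.sqrt(a)
    noise_var = a * noise_var + (1.0 - a)
    alpha_bar *= a                     # running product, like torch.cumprod
    # compare with the closed form of Eq. (3)
    max_err = max(max_err, abs(coef_x0 - math.sqrt(alpha_bar)),
                  abs(noise_var - (1.0 - alpha_bar)))

print(max_err)                         # floating-point round-off only
print(alpha_bar)                       # alpha_bar_T is tiny, so x_T is almost pure noise
```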
```python
# ``GaussianDiffusionTrainer`` contains the forward (noising) process and the training objective
class GaussianDiffusionTrainer(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        """
        Initialize the forward process.
        Args:
            model: the backbone network; the mainstream choice is a U-Net with attention
            beta_1: starting value of beta; 1e-4 in this example
            beta_T: value of beta at t=T; 0.02 in this example
            T: number of time steps; 1000 in this example
        """
        super().__init__()
        # store the arguments
        self.model = model
        self.T = T
        # T evenly spaced beta values from beta_1 to beta_T, stored as a buffer (accessible as ``self.betas``)
        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        # following the formulas, alphas = 1 - betas
        alphas = 1. - self.betas
        # running product of the alphas, stored as alphas_bar.
        # ``torch.cumprod`` multiplies each element by all elements before it and returns a sequence
        # of the same length, e.g.
        #   a = torch.tensor([2, 3, 1, 4])
        #   torch.cumprod(a, dim=0)  # tensor([2, 2*3, 2*3*1, 2*3*1*4]) = tensor([2, 6, 6, 24])
        alphas_bar = torch.cumprod(alphas, dim=0)
        # calculations for diffusion q(x_t | x_0) and others:
        # sqrt(alphas_bar) and sqrt(1 - alphas_bar) are the mean coefficient and the standard deviation
        # of Eq. (3); access them as ``self.sqrt_alphas_bar`` and ``self.sqrt_one_minus_alphas_bar``
        self.register_buffer('sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer('sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))

    def forward(self, x_0):
        """
        Algorithm 1.
        """
        # draw batch_size random time indices in 0 ~ T-1
        t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
        # reparameterization trick: sample standard Gaussian noise, then multiply by the std and add the mean
        noise = torch.randn_like(x_0)
        x_t = (
            extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
            extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
        # the model is trained to predict the added noise (z_t in the formulas); the loss is kept per element
        loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')
        return loss
```
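The snippet below is a sketch of how the trainer is used. It runs one optimization step of Algorithm 1 on random data. `TinyEpsModel` is a hypothetical stand-in for the U-Net: a single convolution that ignores `t`. `noise_loss` is also a hypothetical name. It repeats the logic of `GaussianDiffusionTrainer.forward` so that the block runs on its own. The reduction `.sum() / 1000.` mirrors what `Train.py` does.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
T = 1000
betas = torch.linspace(1e-4, 0.02, T).double()
alphas_bar = torch.cumprod(1. - betas, dim=0)
sqrt_ab = torch.sqrt(alphas_bar)
sqrt_1m_ab = torch.sqrt(1. - alphas_bar)


class TinyEpsModel(nn.Module):
    # hypothetical stand-in for the U-Net: predicts noise from x_t and ignores t
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3, padding=1)

    def forward(self, x_t, t):
        return self.conv(x_t)


def noise_loss(model, x_0):
    # Algorithm 1: random t, closed-form noising (Eq. 3), MSE against the true noise
    t = torch.randint(T, size=(x_0.shape[0],))
    noise = torch.randn_like(x_0)
    a = sqrt_ab[t].float().view(-1, 1, 1, 1)      # same as extract(sqrt_alphas_bar, t, x_0.shape)
    s = sqrt_1m_ab[t].float().view(-1, 1, 1, 1)
    x_t = a * x_0 + s * noise
    return F.mse_loss(model(x_t, t), noise, reduction='none')


model = TinyEpsModel()
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
x_0 = torch.rand(4, 3, 8, 8) * 2 - 1              # fake images scaled to [-1, 1]

per_elem = noise_loss(model, x_0)                 # same shape as x_0
loss = per_elem.sum() / 1000.
opt.zero_grad()
loss.backward()
opt.step()
print(per_elem.shape, loss.item())
```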
For the formulas and derivation of the reverse diffusion process, I quote the content of this article:
Reverse diffusion amounts to flipping the conditional probability, i.e. finding $q(x_{t-1} \mid x_t)$.
By Bayes' rule, $q(x_{t-1} \mid x_t) = q(x_t \mid x_{t-1})\cdot\dfrac{q(x_{t-1})}{q(x_t)}$.
The marginals $q(x_{t-1})$ and $q(x_t)$ are not tractable, but in the forward process every state is computed from $x_0$. So we also condition on $x_0$, and the equation becomes:
$$q(x_{t-1} \mid x_t, x_0) = q(x_t \mid x_{t-1}, x_0)\cdot\frac{q(x_{t-1} \mid x_0)}{q(x_t \mid x_0)}$$
Expand the three terms on the right using the forward-process equations (2) and (3). The first equality holds because the forward process is Markov.
$$q(x_t \mid x_{t-1}, x_0) = q(x_t \mid x_{t-1}) \sim \mathcal{N}\left(\sqrt{\alpha_t}\,x_{t-1},\ 1-\alpha_t\right) \Rightarrow x_t = \sqrt{\alpha_t}\,x_{t-1} + \sqrt{1-\alpha_t}\cdot z_t$$
$$q(x_t \mid x_0) \sim \mathcal{N}\left(\sqrt{\bar\alpha_t}\,x_0,\ 1-\bar\alpha_t\right) \Rightarrow x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\cdot z_t$$
$$q(x_{t-1} \mid x_0) \sim \mathcal{N}\left(\sqrt{\bar\alpha_{t-1}}\,x_0,\ 1-\bar\alpha_{t-1}\right) \Rightarrow x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,x_0 + \sqrt{1-\bar\alpha_{t-1}}\cdot z_{t-1}$$
Here $z_0, z_1, \dots, z_{t-1}, z_t$ are all standard normal.
Plug these three Gaussian densities into $q(x_{t-1} \mid x_t, x_0)$. Inside the exponent, products become sums and the quotient becomes a difference, which gives:
$$q(x_{t-1} \mid x_t, x_0) \propto \exp\left(-\frac12\left(\frac{(x_t-\sqrt{\alpha_t}\,x_{t-1})^2}{1-\alpha_t} + \frac{(x_{t-1}-\sqrt{\bar\alpha_{t-1}}\,x_0)^2}{1-\bar\alpha_{t-1}} - \frac{(x_t-\sqrt{\bar\alpha_t}\,x_0)^2}{1-\bar\alpha_t}\right)\right)$$
Expanding and collecting the terms in $x_{t-1}$ gives the expression below, where $C(x_t, x_0)$ collects all terms that do not involve $x_{t-1}$:
$$\exp\left(-\frac12\left(\left(\frac{\alpha_t}{1-\alpha_t}+\frac{1}{1-\bar\alpha_{t-1}}\right)x_{t-1}^2 - \left(\frac{2\sqrt{\alpha_t}}{\beta_t}x_t + \frac{2\sqrt{\bar\alpha_{t-1}}}{1-\bar\alpha_{t-1}}x_0\right)x_{t-1} + C(x_t, x_0)\right)\right)$$
Now expand the exponent of the standard normal density:
$$\exp\left(-\frac12\cdot\frac{(x-\mu)^2}{\sigma^2}\right) = \exp\left(-\frac12\left(\frac{1}{\sigma^2}x^2 - \frac{2\mu}{\sigma^2}x + \frac{\mu^2}{\sigma^2}\right)\right)$$
Matching the two expressions term by term gives:
$$\frac{1}{\sigma^2} = \frac{\alpha_t}{1-\alpha_t} + \frac{1}{1-\bar\alpha_{t-1}}$$
$$\frac{2\mu}{\sigma^2} = \frac{2\sqrt{\alpha_t}}{\beta_t}x_t + \frac{2\sqrt{\bar\alpha_{t-1}}}{1-\bar\alpha_{t-1}}x_0$$
Simplifying with $\beta_t = 1-\alpha_t$ and $\bar\alpha_t = \alpha_t\cdot\bar\alpha_{t-1}$ gives:
$$\sigma^2 = \frac{\beta_t\cdot(1-\bar\alpha_{t-1})}{1-\bar\alpha_t} \tag{4}$$
$$\mu = \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}x_t + \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}x_0 \tag{5}$$
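The algebra from the matched coefficients to Eqs. (4) and (5) is easy to get wrong, so here is a small numeric check in plain Python. It uses an assumed linear schedule (1e-4 to 0.02, 1000 steps) and covers every $t \ge 2$. Three things should hold:
- Inverting $\frac{\alpha_t}{1-\alpha_t}+\frac{1}{1-\bar\alpha_{t-1}}$ should give Eq. (4).
- $\sigma^2\cdot\frac{\sqrt{\alpha_t}}{\beta_t}$ should equal the coefficient of $x_t$ in Eq. (5).
- $\sigma^2\cdot\frac{\sqrt{\bar\alpha_{t-1}}}{1-\bar\alpha_{t-1}}$ should equal the coefficient of $x_0$ in Eq. (5).

At $t=1$, $\bar\alpha_0 = 1$ makes $\frac{1}{1-\bar\alpha_0}$ infinite, so that step is skipped.

```python
import math

T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]   # beta_1 ... beta_T
alphas = [1.0 - b for b in betas]
alpha_bars, acc = [], 1.0
for a in alphas:
    acc *= a
    alpha_bars.append(acc)

max_rel_err = 0.0
for i in range(1, T):                 # list index i is timestep t = i + 1, so t >= 2
    a, b = alphas[i], betas[i]
    ab, ab_prev = alpha_bars[i], alpha_bars[i - 1]
    # sigma^2 from the matched x_{t-1}^2 coefficient
    sigma2 = 1.0 / (a / (1.0 - a) + 1.0 / (1.0 - ab_prev))
    sigma2_eq4 = b * (1.0 - ab_prev) / (1.0 - ab)                     # Eq. (4)
    # coefficients of x_t and x_0 in mu, from 2*mu/sigma^2 = ...
    cx_t = sigma2 * math.sqrt(a) / b
    cx_0 = sigma2 * math.sqrt(ab_prev) / (1.0 - ab_prev)
    cx_t_eq5 = math.sqrt(a) * (1.0 - ab_prev) / (1.0 - ab)            # Eq. (5), x_t term
    cx_0_eq5 = math.sqrt(ab_prev) * b / (1.0 - ab)                    # Eq. (5), x_0 term
    for got, want in [(sigma2, sigma2_eq4), (cx_t, cx_t_eq5), (cx_0, cx_0_eq5)]:
        max_rel_err = max(max_rel_err, abs(got - want) / abs(want))

print(max_rel_err)                    # round-off level
```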
Since $x_0$ is exactly what we are trying to recover, it has to be replaced by known quantities. From the forward process, $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,z_t$,
so $x_0 = \dfrac{x_t - \sqrt{1-\bar\alpha_t}\,z_t}{\sqrt{\bar\alpha_t}}$. Substituting this into Eq. (5) gives:
$$\mu = \frac{1}{\sqrt{\alpha_t}}x_t - \frac{1}{\sqrt{\alpha_t}}\cdot\frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}z_t \tag{6}$$
To make the computation convenient in code, the mean $\mu$ is split into two coefficients, $\mu = c_1\cdot x_t - c_2\cdot z_t$. At sampling time $z_t$ is unknown, so it is replaced by the noise predicted by the model. The coefficients $c_1, c_2$ are:
$$c_1 = \frac{1}{\sqrt{\alpha_t}} \tag{7}$$
$$c_2 = c_1\cdot\frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}} \tag{8}$$
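Eq. (6) and the split into $c_1$, $c_2$ can be checked numerically in the same way. The check picks arbitrary scalar values for $x_t$ and $z_t$ and recovers $x_0$ from Eq. (3). It then plugs $x_0$ into Eq. (5) and compares the result with $c_1 x_t - c_2 z_t$. The schedule and the chosen $t$ are illustrative assumptions.

```python
import math

T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alphas = [1.0 - b for b in betas]
alpha_bars, acc = [], 1.0
for a in alphas:
    acc *= a
    alpha_bars.append(acc)

i = 499                                  # list index of an arbitrary timestep (t = 500)
a, b, ab, ab_prev = alphas[i], betas[i], alpha_bars[i], alpha_bars[i - 1]
x_t, z_t = 0.7, -1.3                     # arbitrary scalar values

# x_0 recovered from Eq. (3)
x_0 = (x_t - math.sqrt(1.0 - ab) * z_t) / math.sqrt(ab)
# Eq. (5): posterior mean expressed with x_t and x_0
mu_eq5 = (math.sqrt(a) * (1.0 - ab_prev) / (1.0 - ab)) * x_t \
    + (math.sqrt(ab_prev) * b / (1.0 - ab)) * x_0
# Eqs. (7) and (8): the coefficients the sampler stores as coeff1 / coeff2
c1 = 1.0 / math.sqrt(a)
c2 = c1 * (1.0 - a) / math.sqrt(1.0 - ab)
mu_eq6 = c1 * x_t - c2 * z_t

print(mu_eq5, mu_eq6)                    # equal up to round-off
```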
```python
# ``GaussianDiffusionSampler`` contains the reverse process of the diffusion model, i.e. inference (sampling)
class GaussianDiffusionSampler(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        """
        All arguments mean the same as in ``GaussianDiffusionTrainer`` (the forward process)
        """
        super().__init__()
        self.model = model
        self.T = T
        # betas, alphas and alphas_bar are computed exactly as in the forward process
        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        # alphas_bar_{t-1}, for convenience below: prepend a 1 (alpha_bar_0 = 1),
        # which shifts every element one position to the right, and drop the last element
        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]
        # coefficients of the reverse-process mean, Eqs. (7) and (8)
        self.register_buffer('coeff1', torch.sqrt(1. / alphas))
        self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))
        # variance of the reverse process (posterior variance), Eq. (4)
        self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))

    def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
        """
        Computes the mean of the reverse-process distribution q(x_{t-1}|x_t)
        Args:
            x_t: the image at the current step
            t: the current step
            eps: the noise predicted by the model, i.e. z_t
        Returns:
            the mean of x_{t-1}: mean = coeff1 * x_t - coeff2 * eps
        """
        assert x_t.shape == eps.shape
        return (
            extract(self.coeff1, t, x_t.shape) * x_t -
            extract(self.coeff2, t, x_t.shape) * eps
        )

    def p_mean_variance(self, x_t, t):
        """
        Computes the mean and variance of the reverse-process distribution q(x_{t-1}|x_t)
        Args:
            x_t: the image at the current step
            t: the current step
        Returns:
            xt_prev_mean: the mean
            var: the variance
        """
        # below: only log_variance is used in the KL computations
        # This follows the "fixedlarge" variance choice of the original DDPM code: sigma_t^2 = beta_t
        # instead of the posterior variance of Eq. (4); the DDPM paper reports that both choices give
        # similar results. The first entry is taken from posterior_var[1] because posterior_var[0] is
        # exactly 0 (alphas_bar_prev[0] = 1), and its log would be -inf in the original KL computation.
        var = torch.cat([self.posterior_var[1:2], self.betas[1:]])
        var = extract(var, t, x_t.shape)
        # forward pass of the model to predict eps (i.e. z_t)
        eps = self.model(x_t, t)
        # compute the mean
        xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)
        return xt_prev_mean, var

    def forward(self, x_T):
        """
        Algorithm 2.
        """
        # reverse diffusion: iterate from x_T down to x_0
        x_t = x_T
        for time_step in reversed(range(self.T)):
            print(time_step)
            # t = [1, 1, ...] * time_step, of length batch_size
            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
            # mean and variance of q(x_{t-1}|x_t)
            mean, var = self.p_mean_variance(x_t=x_t, t=t)
            # no noise when t == 0
            # (in my view keeping the noise at the last step would not matter much here,
            # since the variance at t = 0 is already very small)
            if time_step > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = 0
            x_t = mean + torch.sqrt(var) * noise
            assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
        x_0 = x_t
        # ``torch.clip(x_0, -1, 1)`` clamps the values of x_0 to [-1, 1]
        return torch.clip(x_0, -1, 1)
```
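The sketch below shows why `p_mean_variance` patches the first variance entry. It rebuilds the same buffers, assuming the linear schedule 1e-4 to 0.02 over 1000 steps, and compares the posterior variance of Eq. (4) with $\beta_t$. The first posterior variance is exactly zero. For larger $t$ the ratio $(1-\bar\alpha_{t-1})/(1-\bar\alpha_t)$ approaches 1, so the two variance choices only differ noticeably during the first few steps.

```python
import torch
import torch.nn.functional as F

T = 1000
betas = torch.linspace(1e-4, 0.02, T).double()
alphas = 1. - betas
alphas_bar = torch.cumprod(alphas, dim=0)
# alphas_bar_{t-1}: prepend 1 and drop the last element, as in the sampler
alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]
posterior_var = betas * (1. - alphas_bar_prev) / (1. - alphas_bar)    # Eq. (4)

ratio = posterior_var / betas
print(posterior_var[0].item())              # 0.0, whose log would be -inf
print(ratio[1].item(), ratio[-1].item())    # well below 1 early on, close to 1 at the end

# the variance the sampler actually uses ("fixedlarge")
var_used = torch.cat([posterior_var[1:2], betas[1:]])
```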
The network (Model.py) is explained in the code comments, so please read them directly:
```python
import math
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F


class Swish(nn.Module):
    """
    Swish activation function, see https://blog.csdn.net/bblingbbling/article/details/107105648
    """
    def forward(self, x):
        return x * torch.sigmoid(x)


class TimeEmbedding(nn.Module):
    """
    The time-embedding module
    """
    def __init__(self, T, d_model, dim):
        """
        The initial time embedding is a table of sine and cosine values at several frequencies.
        Row t (t = 0, ..., T-1) of the table is
            [sin(w_0*t), cos(w_0*t), sin(w_1*t), cos(w_1*t), ..., sin(w_{d/2-1}*t), cos(w_{d/2-1}*t)],
        so the table has shape T * d_model.
        The positions t run over 0 ~ T-1, and the d_model // 2 frequencies w_i = 10000^(-2i/d_model)
        decrease geometrically from 1 towards 1e-4. Interleaving sin and cos gives d_model values per row.
        Args:
            T: int, total number of steps; T=1000 in this example
            d_model: input dimension (number of channels / length of the initial embedding)
            dim: output dimension (number of channels)
        """
        assert d_model % 2 == 0
        super().__init__()
        # the first two lines compute the d_model // 2 frequencies (64 of them when d_model = 128)
        emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
        emb = torch.exp(-emb)
        # the T time positions
        pos = torch.arange(T).float()
        # the outer product gives a T x (d_model // 2) matrix; assert its shape
        emb = pos[:, None] * emb[None, :]
        assert list(emb.shape) == [T, d_model // 2]
        # take sin and cos of every entry, check the shape, and reshape to T x d_model
        emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
        assert list(emb.shape) == [T, d_model // 2, 2]
        emb = emb.view(T, d_model)
        # an MLP turns the fixed sinusoidal code into the final embedding:
        # two linear layers, Swish after the first and no activation after the second
        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(emb),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )
        self.initialize()

    def initialize(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)

    def forward(self, t):
        emb = self.timembedding(t)
        return emb


class DownSample(nn.Module):
    """
    Downsampling with a stride-2 convolution
    """
    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        x = self.main(x)
        return x


class UpSample(nn.Module):
    """
    Upsampling with nearest-neighbor interpolation followed by a convolution
    """
    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        x = F.interpolate(
            x, scale_factor=2, mode='nearest')
        x = self.main(x)
        return x


class AttnBlock(nn.Module):
    """
    Self-attention block; all linear layers are implemented as 1x1 convolutions
    """
    def __init__(self, in_ch):
        # ``self.proj_q``, ``self.proj_k``, ``self.proj_v`` learn the query, key and value projections
        # ``self.proj`` is the output projection after self-attention
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.initialize()

    def initialize(self):
        for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
            init.xavier_uniform_(module.weight)
            init.zeros_(module.bias)
        init.xavier_uniform_(self.proj.weight, gain=1e-5)

    def forward(self, x):
        B, C, H, W = x.shape
        # group norm, then the 1x1 projections give query, key and value
        h = self.group_norm(x)
        q = self.proj_q(h)
        k = self.proj_k(h)
        v = self.proj_v(h)
        # similarity weights w between queries and keys via a batched matrix product.
        # ``torch.bmm`` keeps dim 0 as the batch and multiplies the matrices formed by dims 1 and 2,
        # e.g. a.shape=(_n, _h, _m), b.shape=(_n, _m, _w) --> torch.bmm(a, b).shape=(_n, _h, _w)
        # The weights are divided by sqrt(C) (scaled dot-product attention) so that their magnitude
        # does not grow with the number of channels
        q = q.permute(0, 2, 3, 1).view(B, H * W, C)
        k = k.view(B, C, H * W)
        w = torch.bmm(q, k) * (int(C) ** (-0.5))
        assert list(w.shape) == [B, H * W, H * W]
        w = F.softmax(w, dim=-1)
        # weight the values with w, again with one batched matrix product
        v = v.permute(0, 2, 3, 1).view(B, H * W, C)
        h = torch.bmm(w, v)
        assert list(h.shape) == [B, H * W, C]
        h = h.view(B, H, W, C).permute(0, 3, 1, 2)
        # output projection; adding the input x forms a residual (skip) connection
        h = self.proj(h)
        return x + h


class ResBlock(nn.Module):
    """
    Residual block
    """
    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        """
        Args:
            in_ch: int, number of input channels
            out_ch: int, number of output channels
            tdim: int, length (dimension) of the time embedding
            dropout: float, dropout rate
            attn: bool, whether to append a self-attention block
        """
        super().__init__()
        # block 1: gn -> swish -> conv
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        # time-embedding projection: swish -> fc
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        # block 2: gn -> swish -> dropout -> conv
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        # if the input and output channel counts differ, add a 1x1 ``shortcut`` conv, otherwise do nothing
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
        else:
            self.shortcut = nn.Identity()
        # if attention is requested, add an ``AttnBlock``, otherwise do nothing
        if attn:
            self.attn = AttnBlock(out_ch)
        else:
            self.attn = nn.Identity()
        self.initialize()

    def initialize(self):
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)
        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)

    def forward(self, x, temb):
        h = self.block1(x)                               # encode the input with block 1
        h += self.temb_proj(temb)[:, :, None, None]      # add the time embedding
        h = self.block2(h)                               # encode the mixed features further with block 2
        h = h + self.shortcut(x)                         # residual connection
        h = self.attn(h)                                 # self-attention (if attn=True)
        return h


class UNet(nn.Module):
    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        """
        Args:
            T: int, total number of diffusion steps; T=1000 in this example
            ch: int, channels after the first convolution; resolution level i uses ch * ch_mult[i] channels; ch=128 here
            ch_mult: list, channel multiplier of each resolution level; ch_mult=[1,2,3,4] in this example
            attn: list, indices of the resolution levels (entries of ch_mult) that use self-attention
            num_res_blocks: int, number of residual blocks per level in the encoder (the decoder uses one more)
            dropout: float, dropout rate
        """
        super().__init__()
        # make sure every attention index refers to an existing resolution level
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        # the time embedding starts as a sinusoidal code of length ch and is projected to tdim = ch * 4
        tdim = ch * 4
        # the time-embedding layer
        self.time_embedding = TimeEmbedding(T, ch, tdim)
        # the head convolution
        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        # the U-Net encoder (downsampling path); each level consists of ``num_res_blocks`` residual blocks.
        # chs records the channel count of every skip connection, now_ch is the current channel count
        self.downblocks = nn.ModuleList()
        chs = [ch]  # record output channel when dowmsample for upsample
        now_ch = ch
        for i, mult in enumerate(ch_mult):  # i is the index into ch_mult, mult is ch_mult[i]
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)
        # the bridge between encoder and decoder: two residual blocks. Only the first has attention,
        # so the sequence is ResBlock -> self-attention -> ResBlock, which matches the middle block
        # of the original DDPM U-Net
        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])
        # the U-Net decoder mirrors the encoder, except that each level has one more residual block.
        # The reason is the skip connections: per level the encoder pushes num_res_blocks + 1 feature maps
        # (one per ResBlock plus the map that entered the level: the head output or a DownSample output),
        # and every decoder ResBlock pops exactly one of them and concatenates it with the current features
        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0
        # tail: gn -> swish -> conv, projecting back to the 3 image channels
        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )
        # only the head and tail are initialized here; the other modules initialize themselves
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        # Timestep embedding
        temb = self.time_embedding(t)
        # Downsampling
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)
        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb)
        # Upsampling
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)
        assert len(hs) == 0
        return h


if __name__ == '__main__':
    batch_size = 8
    model = UNet(
        T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],
        num_res_blocks=2, dropout=0.1)
    x = torch.randn(batch_size, 3, 32, 32)
    t = torch.randint(1000, (batch_size, ))
    y = model(x, t)
    print(y.shape)
```
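The channel bookkeeping in `UNet.__init__` (the `chs` stack) is the part most likely to break when `ch_mult` or `num_res_blocks` change. Below is a pure-Python sketch that replays the pushes and pops, assuming the same construction rules as the code above. `unet_channel_plan` is a hypothetical helper, not part of the repository. It returns the number of skip connections and the `(in_ch, out_ch)` of every decoder ResBlock. It also shows that each decoder level consumes `num_res_blocks + 1` skip connections.

```python
def unet_channel_plan(ch, ch_mult, num_res_blocks):
    """Replay the channel logic of UNet.__init__ (a sketch, not part of the repo)."""
    chs = [ch]                       # the head output is the first skip connection
    now_ch = ch
    for i, mult in enumerate(ch_mult):
        out_ch = ch * mult
        for _ in range(num_res_blocks):
            now_ch = out_ch
            chs.append(now_ch)       # every encoder ResBlock output is a skip connection
        if i != len(ch_mult) - 1:
            chs.append(now_ch)       # so is every DownSample output
    n_skips = len(chs)
    up_plan = []                     # (in_ch, out_ch) of each decoder ResBlock
    for i, mult in reversed(list(enumerate(ch_mult))):
        out_ch = ch * mult
        for _ in range(num_res_blocks + 1):
            up_plan.append((chs.pop() + now_ch, out_ch))   # concat skip + current features
            now_ch = out_ch
    assert not chs                   # every skip connection is consumed exactly once
    return n_skips, up_plan


n_skips, plan = unet_channel_plan(ch=128, ch_mult=[1, 2, 2, 2], num_res_blocks=2)
print(n_skips)       # 4 levels * (2 + 1) = 12
print(plan[0])       # first decoder block: 256 (skip) + 256 (current) -> 256, i.e. (512, 256)
```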
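Training and evaluation (Train.py) are not very different from ordinary model training and validation. I have added comments at the important places, so please read the code comments directly: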
```python
import os
from typing import Dict

import torch
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.utils import save_image

from Diffusion import GaussianDiffusionSampler, GaussianDiffusionTrainer
from Diffusion.Model import UNet
from Scheduler import GradualWarmupScheduler


def train(modelConfig: Dict):
    device = torch.device(modelConfig["device"])
    # dataset: random horizontal flips, pixel values normalized to [-1, 1]
    dataset = CIFAR10(
        root='./CIFAR10', train=True, download=True,
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))
    dataloader = DataLoader(
        dataset, batch_size=modelConfig["batch_size"], shuffle=True, num_workers=4, drop_last=True, pin_memory=True)
    # model setup
    net_model = UNet(T=modelConfig["T"], ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"],
                     num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)
    if modelConfig["training_load_weight"] is not None:
        net_model.load_state_dict(torch.load(os.path.join(
            modelConfig["save_weight_dir"], modelConfig["training_load_weight"]), map_location=device))
    optimizer = torch.optim.AdamW(
        net_model.parameters(), lr=modelConfig["lr"], weight_decay=1e-4)
    # learning-rate decay along half a period of a cosine, down to 0 (eta_min=0)
    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=modelConfig["epoch"], eta_min=0, last_epoch=-1)
    # gradual warm-up: the learning rate ramps up to multiplier * lr over the first 1/10 of the epochs,
    # after which it follows ``cosineScheduler``
    warmUpScheduler = GradualWarmupScheduler(
        optimizer=optimizer, multiplier=modelConfig["multiplier"], warm_epoch=modelConfig["epoch"] // 10, after_scheduler=cosineScheduler)
    # the training wrapper (forward process + loss)
    trainer = GaussianDiffusionTrainer(
        net_model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)
    # start training
    for e in range(modelConfig["epoch"]):
        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
            for images, labels in tqdmDataLoader:
                # train
                optimizer.zero_grad()                      # clear old gradients
                x_0 = images.to(device)                    # move the input images to the device
                loss = trainer(x_0).sum() / 1000.          # forward pass; sum the per-element loss and scale it
                loss.backward()                            # backpropagate
                torch.nn.utils.clip_grad_norm_(
                    net_model.parameters(), modelConfig["grad_clip"])  # clip gradients to prevent explosion
                optimizer.step()                           # update the parameters
                tqdmDataLoader.set_postfix(ordered_dict={
                    "epoch": e,
                    "loss: ": loss.item(),
                    "img shape: ": x_0.shape,
                    "LR": optimizer.state_dict()['param_groups'][0]["lr"]
                })                                         # progress-bar contents
        warmUpScheduler.step()                             # update the learning rate once per epoch
        torch.save(net_model.state_dict(), os.path.join(
            modelConfig["save_weight_dir"], 'ckpt_' + str(e) + "_.pt"))  # save a checkpoint every epoch


def eval(modelConfig: Dict):
    # load model and evaluate
    with torch.no_grad():
        # build the model and load the weights
        device = torch.device(modelConfig["device"])
        model = UNet(T=modelConfig["T"], ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"],
                     num_res_blocks=modelConfig["num_res_blocks"], dropout=0.)
        ckpt = torch.load(os.path.join(
            modelConfig["save_weight_dir"], modelConfig["test_load_weight"]), map_location=device)
        model.load_state_dict(ckpt)
        print("model load weight done.")
        # build the reverse-diffusion sampler (``.to(device)`` also moves the wrapped model)
        model.eval()
        sampler = GaussianDiffusionSampler(
            model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)
        # Sampled from standard normal distribution
        # generate random Gaussian noise images and save them
        noisyImage = torch.randn(
            size=[modelConfig["batch_size"], 3, 32, 32], device=device)
        saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1)
        save_image(saveNoisy, os.path.join(
            modelConfig["sampled_dir"], modelConfig["sampledNoisyImgName"]), nrow=modelConfig["nrow"])
        # run the reverse diffusion and save the output images
        sampledImgs = sampler(noisyImage)
        sampledImgs = sampledImgs * 0.5 + 0.5  # [0 ~ 1]
        save_image(sampledImgs, os.path.join(
            modelConfig["sampled_dir"], modelConfig["sampledImgName"]), nrow=modelConfig["nrow"])
```
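`GradualWarmupScheduler` comes from the repository's own `Scheduler.py`, which this post does not show. The pure-Python sketch below is an illustration only of the learning-rate shape the comments describe. It assumes a linear warm-up from `lr` to `multiplier * lr` over the first `epochs // 10` epochs, followed by half a cosine period down to 0. `warmup_cosine_lr` is a hypothetical helper, and the base rate and multiplier are illustrative. The repository's scheduler may differ in details such as where the warm-up starts.

```python
import math


def warmup_cosine_lr(epoch, base_lr, multiplier, total_epochs):
    # hypothetical helper: linear warm-up for total_epochs // 10 epochs, then cosine decay to 0
    warm = total_epochs // 10
    peak = base_lr * multiplier
    if epoch < warm:
        return base_lr + (peak - base_lr) * epoch / warm
    progress = (epoch - warm) / max(1, total_epochs - warm)
    return 0.5 * peak * (1.0 + math.cos(math.pi * progress))


lrs = [warmup_cosine_lr(e, base_lr=1e-4, multiplier=2.0, total_epochs=200) for e in range(201)]
print(lrs[0], lrs[20], lrs[200])   # start, peak, end
```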
DiffusionFreeGuidence Package
Paper: "Classifier-Free Diffusion Guidance"
This part is best read side by side with the code above.
```python
# DiffusionFreeGuidence/DiffusionCondition.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def extract(v, t, x_shape):
    """
    Extract some coefficients at specified timesteps, then reshape to
    [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
    """
    device = t.device
    out = torch.gather(v, index=t, dim=0).float().to(device)
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))


class GaussianDiffusionTrainer(nn.Module):
    """
    The forward noising process is almost identical to ``GaussianDiffusionTrainer`` in ``Diffusion.Diffusion.py``.
    The difference is the model input: besides ``x_t`` and ``t``, the condition ``labels`` is also passed in.
    """
    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()
        self.model = model
        self.T = T
        self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        # calculations for diffusion q(x_t | x_0) and others
        self.register_buffer(
            'sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer(
            'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))

    def forward(self, x_0, labels):
        """
        Algorithm 1.
        """
        t = torch.randint(self.T, size=(x_0.shape[0],), device=x_0.device)
        noise = torch.randn_like(x_0)
        x_t = extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 + \
            extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise
        loss = F.mse_loss(self.model(x_t, t, labels), noise, reduction='none')  # the difference: the model also receives ``labels``
        return loss


class GaussianDiffusionSampler(nn.Module):
    """
    The reverse process is mostly the same as ``GaussianDiffusionSampler`` in ``Diffusion.Diffusion.py``,
    so only the differences are explained here
    """
    def __init__(self, model, beta_1, beta_T, T, w=0.):
        super().__init__()
        self.model = model
        self.T = T
        # In the classifier-free guidance paper, w is the key to controlling the guidance.
        # w = 0 and label = 0 means no guidance.
        # w > 0 and label > 0 means guidance; the bigger w, the stronger the guidance.
        # Difference 1: the constructor takes a weight ``w`` that controls how strongly the condition acts
        self.w = w
        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]
        self.register_buffer('coeff1', torch.sqrt(1. / alphas))
        self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))
        self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))

    def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            extract(self.coeff1, t, x_t.shape) * x_t -
            extract(self.coeff2, t, x_t.shape) * eps
        )

    def p_mean_variance(self, x_t, t, labels):
        # below: only log_variance is used in the KL computations
        var = torch.cat([self.posterior_var[1:2], self.betas[1:]])
        var = extract(var, t, x_t.shape)
        # Difference 2: at inference time the model is run twice, conditionally and unconditionally
        # (label 0), and the two outputs are combined with the weight ``self.w``
        eps = self.model(x_t, t, labels)
        nonEps = self.model(x_t, t, torch.zeros_like(labels))
        # see Eq. (6) of the paper: eps~ = (1 + w) * eps(x_t, c) - w * eps(x_t)
        eps = (1. + self.w) * eps - self.w * nonEps
        xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)
        return xt_prev_mean, var

    def forward(self, x_T, labels):
        """
        Algorithm 2.
        """
        x_t = x_T
        for time_step in reversed(range(self.T)):
            print(time_step)
            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
            # apart from the extra ``labels`` input, everything is the same as the plain diffusion model
            mean, var = self.p_mean_variance(x_t=x_t, t=t, labels=labels)
            if time_step > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = 0
            x_t = mean + torch.sqrt(var) * noise
            assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
        x_0 = x_t
        return torch.clip(x_0, -1, 1)
```
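The guidance rule in `p_mean_variance` can be rewritten as $\tilde\epsilon = \epsilon_u + (1+w)(\epsilon_c - \epsilon_u)$. In words, it starts from the unconditional prediction and moves $(1+w)$ times along the direction that the label adds. The small torch sketch below uses random tensors as stand-ins for the two model outputs. `guided_eps` is a hypothetical helper, and $w = 1.8$ is an arbitrary illustrative value.

```python
import torch


def guided_eps(eps_cond, eps_uncond, w):
    # classifier-free guidance as written in p_mean_variance: (1 + w) * eps_c - w * eps_u
    return (1. + w) * eps_cond - w * eps_uncond


torch.manual_seed(0)
eps_c = torch.randn(2, 3, 4, 4)      # stand-in for model(x_t, t, labels)
eps_u = torch.randn(2, 3, 4, 4)      # stand-in for model(x_t, t, zeros_like(labels))

g0 = guided_eps(eps_c, eps_u, w=0.)  # w = 0: the purely conditional prediction
g = guided_eps(eps_c, eps_u, w=1.8)
# equivalent form: start at eps_u and step (1 + w) times along (eps_c - eps_u)
g_alt = eps_u + (1. + 1.8) * (eps_c - eps_u)
print(torch.allclose(g0, eps_c), torch.allclose(g, g_alt, atol=1e-5))
```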
```python
# DiffusionFreeGuidence/ModelCondition.py
import math
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F


def drop_connect(x, drop_ratio):
    """
    This function is not used anywhere in the project, so it is not discussed here
    """
    keep_ratio = 1.0 - drop_ratio
    mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
    mask.bernoulli_(p=keep_ratio)
    x.div_(keep_ratio)
    x.mul_(mask)
    return x


class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class TimeEmbedding(nn.Module):
    """
    Same as ``TimeEmbedding`` in ``Diffusion.Model``, except that the sinusoidal table is trainable
    (``freeze=False``) and there is no custom ``initialize()``
    """
    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
        emb = torch.exp(-emb)
        pos = torch.arange(T).float()
        emb = pos[:, None] * emb[None, :]
        assert list(emb.shape) == [T, d_model // 2]
        emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
        assert list(emb.shape) == [T, d_model // 2, 2]
        emb = emb.view(T, d_model)
        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(emb, freeze=False),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )

    def forward(self, t):
        emb = self.timembedding(t)
        return emb


class ConditionalEmbedding(nn.Module):
    """
    Condition-encoding module: maps the condition (a class label) to an embedding.
    Apart from the initial embedding table it is the same as the time embedding.
    """
    def __init__(self, num_labels, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        # Note a detail of the embedding table: ``num_embeddings=num_labels + 1``, i.e. 10 + 1 = 11.
        # The condition here is the CIFAR10 label, 10 classes numbered 0~9, so 10 embeddings would seem enough,
        # but the ``unconditional`` case also needs an embedding; in this example it is represented by ``0``,
        # and the 10 class labels are shifted by one to 1~10 (done in ``TrainCondition.py``), hence 11 embeddings.
        # With ``padding_idx=0`` the embedding of index 0 is a fixed zero vector that receives no gradient.
        self.condEmbedding = nn.Sequential(
            nn.Embedding(num_embeddings=num_labels + 1, embedding_dim=d_model, padding_idx=0),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )

    def forward(self, labels):
        cemb = self.condEmbedding(labels)
        return cemb


class DownSample(nn.Module):
    """
    Compared with ``Diffusion.Model.DownSample``, this module adds a 5x5 conv with stride=2;
    the output is the sum of the 3x3 and 5x5 convolutions. I don't know why it is done this way,
    perhaps to fuse information from more scales. The paper (Section 4, Experiments, lines 3~4) says its model
    is the same as the one in "Diffusion Models Beat GANs on Image Synthesis", but that paper's code
    does not downsample this way; it simply uses a 3x3 conv or avg_pool.
    """
    def __init__(self, in_ch):
        super().__init__()
        self.c1 = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
        self.c2 = nn.Conv2d(in_ch, in_ch, 5, stride=2, padding=2)

    def forward(self, x, temb, cemb):
        x = self.c1(x) + self.c2(x)
        return x


class UpSample(nn.Module):
    """
    Compared with ``Diffusion.Model.UpSample``, this module uses a transposed convolution instead of
    nearest-neighbor interpolation (kernel 5, stride 2, padding 2, output_padding 1 maps H to 2H).
    As with ``DownSample`` I don't know the reason; either approach should work, it is a matter of preference.
    """
    def __init__(self, in_ch):
        super().__init__()
        self.c = nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, padding=1)
        self.t = nn.ConvTranspose2d(in_ch, in_ch, kernel_size=5, stride=2, padding=2, output_padding=1)

    def forward(self, x, temb, cemb):
        x = self.t(x)
        x = self.c(x)
        return x


class AttnBlock(nn.Module):
    """
    Same forward pass as ``AttnBlock`` in ``Diffusion.Model``, but without the custom ``initialize()``
    """
    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.group_norm(x)
        q = self.proj_q(h)
        k = self.proj_k(h)
        v = self.proj_v(h)
        q = q.permute(0, 2, 3, 1).view(B, H * W, C)
        k = k.view(B, C, H * W)
        w = torch.bmm(q, k) * (int(C) ** (-0.5))
        assert list(w.shape) == [B, H * W, H * W]
        w = F.softmax(w, dim=-1)
        v = v.permute(0, 2, 3, 1).view(B, H * W, C)
        h = torch.bmm(w, v)
        assert list(h.shape) == [B, H * W, C]
        h = h.view(B, H, W, C).permute(0, 3, 1, 2)
        h = self.proj(h)
        return x + h


class ResBlock(nn.Module):
    """
    Compared with ``Diffusion.Model.ResBlock``, this residual block adds a condition projection ``self.cond_proj``.
    It can be regarded as a second time embedding; both take part in training in exactly the same way.
    Note that ``attn`` defaults to True here.
    """
    def __init__(self, in_ch, out_ch, tdim, dropout, attn=True):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.cond_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
        else:
            self.shortcut = nn.Identity()
        if attn:
            self.attn = AttnBlock(out_ch)
        else:
            self.attn = nn.Identity()

    def forward(self, x, temb, cemb):
        h = self.block1(x)
        h += self.temb_proj(temb)[:, :, None, None]  # add the time embedding
        h += self.cond_proj(cemb)[:, :, None, None]  # add the conditional embedding
        h = self.block2(h)                           # feature fusion
        h = h + self.shortcut(x)
        h = self.attn(h)
        return h


class UNet(nn.Module):
    """
    Compared with ``Diffusion.Model.UNet``, this UNet adds a ``cond_embedding``.
    The other change is where self-attention is used: there is no ``attn`` argument. The encoder ResBlocks
    are created without ``attn`` and therefore use the default ``attn=True`` (every encoder block has
    self-attention), the decoder ResBlocks pass ``attn=False``, and the middle block uses attention once.
    """
    def __init__(self, T, num_labels, ch, ch_mult, num_res_blocks, dropout):
        super().__init__()
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)
        self.cond_embedding = ConditionalEmbedding(num_labels, ch, tdim)
        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # record output channel when dowmsample for upsample
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(in_ch=now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)
        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])
        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=False))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0
        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )

    def forward(self, x, t, labels):
        # Timestep and condition embeddings
        temb = self.time_embedding(t)
        cemb = self.cond_embedding(labels)
        # Downsampling
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb, cemb)
            hs.append(h)
        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb, cemb)
        # Upsampling
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb, cemb)
        h = self.tail(h)
        assert len(hs) == 0
        return h


if __name__ == '__main__':
    batch_size = 8
    model = UNet(
        T=1000, num_labels=10, ch=128, ch_mult=[1, 2, 2, 2],
        num_res_blocks=2, dropout=0.1)
    x = torch.randn(batch_size, 3, 32, 32)
    t = torch.randint(1000, size=[batch_size])
    labels = torch.randint(1, 11, size=[batch_size])  # class labels shifted to 1~10; 0 means "unconditional"
    # resB = ResBlock(128, 256, 64, 0.1)
    # x = torch.randn(batch_size, 128, 32, 32)
    # t = torch.randn(batch_size, 64)
    # labels = torch.randn(batch_size, 64)
    # y = resB(x, t, labels)
    y = model(x, t, labels)
    print(y.shape)
```
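`TrainCondition.py` is not shown in this post. As a hedged sketch, labels can be prepared for classifier-free guidance training under the convention described in `ConditionalEmbedding`. Class indices 0~9 are shifted to 1~10. With some probability a label is then replaced by 0, so that the same network also learns the unconditional prediction. The helper name `prepare_cfg_labels` and the per-sample drop probability `p_uncond` are assumptions for illustration. The repository's own training script may drop labels differently, for example for a whole batch at once.

```python
import torch


def prepare_cfg_labels(labels, p_uncond=0.1, generator=None):
    # hypothetical helper: shift classes to 1..num_classes, then drop some to 0 ("unconditional")
    shifted = labels + 1
    drop = torch.rand(labels.shape, generator=generator) < p_uncond
    return torch.where(drop, torch.zeros_like(shifted), shifted)


labels = torch.tensor([0, 3, 9, 5])                 # CIFAR10 class indices
print(prepare_cfg_labels(labels, p_uncond=0.0))     # every label shifted by one
print(prepare_cfg_labels(labels, p_uncond=1.0))     # every label replaced by 0
```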
To be continued…