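Diffusion models go by different names across fields and in the literature. Common ones include denoising diffusion probabilistic models (DDPM), score-based generative models, and generative diffusion processes. Some people also call them energy-based models (EBMs); technically, diffusion models can be classified as a special case of that model family. The most fitting description, however, is probably that they build on the idea of **score-based generative models** (models trained with score matching) and use a formalism based on stochastic differential equations (SDEs).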
The core API of Diffusers is divided into three main components: pipelines, models, and schedulers.
from diffusers import DDPMPipeline
from PIL import Image


def make_grid(images, size=64):
    """Given a list of PIL images, stack them together into a line for easy viewing"""
    output_im = Image.new("RGB", (size * len(images), size))
    for i, im in enumerate(images):
        output_im.paste(im.resize((size, size)), (i * size, 0))
    return output_im

# Load the butterfly pipeline
butterfly_pipeline = DDPMPipeline.from_pretrained(
# Create 8 images
images = butterfly_pipeline(batch_size=8).images
# View the result
make_grid(images)
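Here is a very simple way to control how much noise is added. We define the formula $(1 - \text{amount}) \cdot x + \text{amount} \cdot \text{noise}$. If amount = 0, we get the input back unchanged. If amount reaches 1, we get pure noise with no trace of the input x. For a value in between, the result sits between the two, much like opacity when blending colors, so this formula is an easy way to add a given percentage of noise.
Mixing the input with noise in this way keeps the output in the same range (0 to 1).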
import torch


def corrupt(x, amount):
    """Corrupt the input `x` by mixing it with noise according to `amount`"""
    noise = torch.rand_like(x)
    amount = amount.view(-1, 1, 1, 1)  # Reshape so broadcasting works
    return x * (1 - amount) + noise * amount

import matplotlib.pyplot as plt
import torchvision

# Plotting the input data
fig, axs = plt.subplots(2, 1, figsize=(12, 5))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')
# Adding noise
amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
noised_x = corrupt(x, amount)
# Plotting the noised version
axs[1].set_title('Corrupted data (-- amount increases -->)')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap='Greys');
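What does amount do? We use torch.linspace to generate 8 evenly spaced values from 0 to 1, one per input image, so we can watch the images go from clean, through progressively more noise, to pure noise. amount has shape (8,) and value tensor([0.0000, 0.1429, 0.2857, 0.4286, 0.5714, 0.7143, 0.8571, 1.0000]). That shape does not match the shape of x, so we rely on broadcasting: inside corrupt, amount is reshaped to (8, 1, 1, 1) so that each value scales one whole image.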
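To make the broadcasting step concrete, here is a minimal sketch. It uses NumPy as a stand-in for torch (an assumption made only so the snippet runs without a GPU stack), and `corrupt_np` and the random batch `x` are hypothetical names for illustration, not part of the original code.

```python
import numpy as np

rng = np.random.default_rng(0)


def corrupt_np(x, amount):
    """NumPy stand-in for `corrupt`: blend each image in `x` with uniform noise."""
    noise = rng.random(x.shape)           # uniform in [0, 1), like torch.rand_like
    amount = amount.reshape(-1, 1, 1, 1)  # (8,) -> (8, 1, 1, 1) so it broadcasts
    return x * (1 - amount) + noise * amount


x = rng.random((8, 1, 28, 28))          # hypothetical batch of 8 single-channel images
amount = np.linspace(0, 1, x.shape[0])  # 8 evenly spaced values from 0 to 1
noised = corrupt_np(x, amount)

print(np.round(amount, 4))
print(noised.shape)                  # (8, 1, 28, 28)
print(np.allclose(noised[0], x[0]))  # True: amount = 0 leaves the image unchanged
```

Because every output pixel is a weighted average of two values in [0, 1), the result stays in that range, which is the property mentioned earlier.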
import torch
from torch import nn


class BasicUNet(nn.Module):
    """A minimal UNet implementation."""

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down_layers = torch.nn.ModuleList([
            nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
        ])
        self.up_layers = torch.nn.ModuleList([
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
        ])
        self.act = nn.SiLU()  # The activation function
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)

    def forward(self, x):
        h = []
        for i, l in enumerate(self.down_layers):
            x = self.act(l(x))  # Through the layer and the activation function
            if i < 2:  # For all but the third (final) down layer:
                h.append(x)  # Storing output for skip connection
                x = self.downscale(x)  # Downscale ready for the next layer

        for i, l in enumerate(self.up_layers):
            if i > 0:  # For all except the first up layer
                x = self.upscale(x)  # Upscale
                x += h.pop()  # Fetching stored output (skip connection)
            x = self.act(l(x))  # Through the layer and the activation function

        return x

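The example code here is a simple model that takes a 28px single-channel image. Let x = torch.rand(8, 1, 28, 28); the tensor shapes at each stage of the forward pass are then:
Down path:
Layer 1 input: torch.Size([8, 1, 28, 28])
Layer 1 after convolution: torch.Size([8, 32, 28, 28])
Layer 1 after activation: torch.Size([8, 32, 28, 28])
Layer 1 after downsampling: torch.Size([8, 32, 14, 14])
Layer 2 input: torch.Size([8, 32, 14, 14])
Layer 2 after convolution: torch.Size([8, 64, 14, 14])
Layer 2 after activation: torch.Size([8, 64, 14, 14])
Layer 2 after downsampling: torch.Size([8, 64, 7, 7])
Layer 3 input: torch.Size([8, 64, 7, 7])
Layer 3 after convolution: torch.Size([8, 64, 7, 7])
Layer 3 after activation: torch.Size([8, 64, 7, 7])
Up path:
Layer 1 input: torch.Size([8, 64, 7, 7])
Layer 1 after convolution: torch.Size([8, 64, 7, 7])
Layer 1 after activation: torch.Size([8, 64, 7, 7])
Layer 2 input: torch.Size([8, 64, 7, 7])
Layer 2 after upsampling: torch.Size([8, 64, 14, 14])
Layer 2 after skip connection: torch.Size([8, 64, 14, 14])
Layer 2 after convolution: torch.Size([8, 32, 14, 14])
Layer 2 after activation: torch.Size([8, 32, 14, 14])
Layer 3 input: torch.Size([8, 32, 14, 14])
Layer 3 after upsampling: torch.Size([8, 32, 28, 28])
Layer 3 after skip connection: torch.Size([8, 32, 28, 28])
Layer 3 after convolution: torch.Size([8, 1, 28, 28])
Layer 3 after activation: torch.Size([8, 1, 28, 28])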
The number of parameters in the network is given by sum([p.numel() for p in net.parameters()]).
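The count can also be worked out by hand from the layer sizes, which is a useful sanity check. The sketch below assumes the six 5x5 convolutions of BasicUNet with default `in_channels=1` and `out_channels=1`; the helper names `conv2d_params` and `conv2d_out_size` are made up for this example. The activation, pooling, and upsampling layers have no learnable parameters, so only the convolutions contribute.

```python
def conv2d_params(in_ch, out_ch, kernel_size):
    """Weights (in_ch * out_ch * k * k) plus one bias per output channel."""
    return in_ch * out_ch * kernel_size * kernel_size + out_ch


def conv2d_out_size(size, kernel_size, padding, stride=1):
    """Spatial output size of a convolution (standard formula)."""
    return (size + 2 * padding - kernel_size) // stride + 1


# (in_channels, out_channels) of the six 5x5 convolutions in BasicUNet
layers = [(1, 32), (32, 64), (64, 64),   # down path
          (64, 64), (64, 32), (32, 1)]   # up path

per_layer = [conv2d_params(i, o, 5) for i, o in layers]
total = sum(per_layer)

print(per_layer)  # [832, 51264, 102464, 102464, 51232, 801]
print(total)      # 309057
print(conv2d_out_size(28, kernel_size=5, padding=2))  # 28: padding=2 keeps the size
```

The last line also explains the shape trace: with kernel size 5 and padding 2 a convolution keeps the spatial size, so only the pooling and upsampling steps change it.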
import torch
from torch import nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Use the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataloader (you can mess with batch size)
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# How many runs through the data should we do?
n_epochs = 3

# Create the network and move it to the same device as the data
net = BasicUNet()
net.to(device)

# Our loss function
loss_fn = nn.MSELoss()

# The optimizer
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

# Keeping a record of the losses for later viewing
losses = []

# The training loop
for epoch in range(n_epochs):
    for x, y in train_dataloader:
        # Get some data and prepare the corrupted version
        x = x.to(device)  # Data on the GPU
        noise_amount = torch.rand(x.shape[0]).to(device)  # Pick random noise amounts
        noisy_x = corrupt(x, noise_amount)  # Create our noisy x

        # Get the model prediction
        pred = net(noisy_x)

        # Calculate the loss
        loss = loss_fn(pred, x)  # How close is the output to the true 'clean' x?

        # Backprop and update the params:
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Store the loss for later
        losses.append(loss.item())

    # Print out the average of the loss values for this epoch:
    avg_loss = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')

# View the loss curve
plt.plot(losses)
plt.ylim(0, 0.1);
# Visualizing model predictions on noisy inputs:
# Fetch some data
x, y = next(iter(train_dataloader))
x = x[:8] # Only using the first 8 for easy plotting
# Corrupt with a range of amounts
amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
noised_x = corrupt(x, amount)
# Get the model predictions
with torch.no_grad():
    preds = net(noised_x.to(device)).detach().cpu()
# Plot
fig, axs = plt.subplots(3, 1, figsize=(12, 7))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap='Greys')
axs[1].set_title('Corrupted data')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap='Greys')
axs[2].set_title('Network Predictions')
axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap='Greys');
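One way to picture this is as learning something gradually. Before any study, the topic is completely fuzzy (fuzziness 1). After the first round of study we have grasped $\frac{1}{5}$ of it, so the fuzziness drops from 1 to $\frac{4}{5}$ of what it was.
What we carry into the next round is therefore $x_1 = \frac{4}{5}x + \frac{1}{5}pred$. What does this mean? Knowledge accumulates: what we learned in the first round is still used in later rounds, and the $\frac{4}{5}$ we do not yet understand is still left to work on.
In the second round we start from this new state, and our stock of knowledge becomes $x_2 = \frac{3}{4}x_1 + \frac{1}{4}pred_1$; the material may be harder at this point. Expanding $x_1$ shows that $\frac{2}{5}$ is now understood and $\frac{3}{5}$ is still unclear.
We carry on like this up to the fifth round. $x_5$ may still be somewhat fuzzy, but it is far clearer than it was at the first round.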
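The bookkeeping behind this schedule can be checked with exact fractions. The sketch below assumes the sampling rule `x = x*(1 - mix_factor) + pred*mix_factor` with `mix_factor = 1/(n_steps - i)`, and it only tracks how much of the initial random x is left after each step; it does not run the network.

```python
from fractions import Fraction

n_steps = 5
noise_weight = Fraction(1)  # share of the initial random x still present in x
weights = []

for i in range(n_steps):
    mix_factor = Fraction(1, n_steps - i)  # 1/5, 1/4, 1/3, 1/2, 1
    noise_weight *= 1 - mix_factor         # effect of x <- x*(1 - mix) + pred*mix
    weights.append(noise_weight)
    print(f"step {i + 1}: mix_factor = {mix_factor}, "
          f"initial-x weight = {noise_weight}, prediction weight = {1 - noise_weight}")
```

Although the mix factor grows from step to step, the weight of the initial noise falls in equal steps (4/5, 3/5, 2/5, 1/5, 0), so after the last step x consists entirely of model predictions.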
# Sampling strategy: Break the process into 5 steps and move 1/5'th of the way there each time:
n_steps = 5
x = torch.rand(8, 1, 28, 28).to(device)  # Start from random
step_history = [x.detach().cpu()]
pred_output_history = []

for i in range(n_steps):
    with torch.no_grad():  # No need to track gradients during inference
        pred = net(x)  # Predict the denoised x0
    pred_output_history.append(pred.detach().cpu())  # Store model output for plotting
    mix_factor = 1 / (n_steps - i)  # How much we move towards the prediction
    x = x * (1 - mix_factor) + pred * mix_factor  # Move part of the way there
    step_history.append(x.detach().cpu())  # Store step for plotting

fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4), sharex=True)
axs[0, 0].set_title('x (model input)')
axs[0, 1].set_title('model prediction')
for i in range(n_steps):
    axs[i, 0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0, 1), cmap='Greys')
    axs[i, 1].imshow(torchvision.utils.make_grid(pred_output_history[i])[0].clip(0, 1), cmap='Greys')
The UNet2DModel in Diffusers is more advanced than our BasicUNet and improves on the basic UNet above in many ways:
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=28,  # the target image resolution
    in_channels=1,  # the number of input channels, 3 for RGB images
    out_channels=1,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(32, 64, 64),  # Roughly matching our basic unet example
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",  # a regular ResNet upsampling block
    ),
)
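down_block_types corresponds to the downsampling blocks (shown in green in the figure below), and up_block_types corresponds to the upsampling blocks (shown in red in the figure below):
Figure from the DDPM paper (https://arxiv.org/abs/2006.11239).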
The DDPM paper describes a corruption process that adds a small amount of noise at every "timestep". Given $x_{t-1}$ at some timestep, we can obtain a slightly noisier $x_t$:
$$q(\mathbf{x}_t \vert \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t}\,\mathbf{x}_{t-1}, \beta_t\mathbf{I}) \quad q(\mathbf{x}_{1:T} \vert \mathbf{x}_0) = \prod^T_{t=1} q(\mathbf{x}_t \vert \mathbf{x}_{t-1})$$
The first expression, $q(\mathbf{x}_t \vert \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t}\,\mathbf{x}_{t-1}, \beta_t\mathbf{I})$, says that, given the state $x_{t-1}$ at the previous timestep, the state $x_t$ at the current timestep follows a multivariate normal distribution with mean $\sqrt{1 - \beta_t}\,\mathbf{x}_{t-1}$ and covariance $\beta_t\mathbf{I}$. Here $\beta_t$ is the parameter that sets the noise level at step $t$.
The second expression, $q(\mathbf{x}_{1:T} \vert \mathbf{x}_0) = \prod^T_{t=1} q(\mathbf{x}_t \vert \mathbf{x}_{t-1})$, says that the joint distribution of the whole state sequence $\mathbf{x}_{1:T}$, given the initial state $x_0$, factorizes into the product of the conditional distributions at each timestep. This is exactly the Markov chain property: the current state depends only on the previous state.
We scale $x_{t-1}$ by $\sqrt{1 - \beta_t}$ and then add noise scaled by $\sqrt{\beta_t}$, so that the added noise has variance $\beta_t$. This $\beta$ is defined for every t by the scheduler and determines how much noise is added at each timestep.
With the formula above, however, there is a lot to compute: we have to go from $x_1$ to $x_2$ and so on, one step at a time, all the way to $x_t$. So we use a formula that gets there in a single step:
$$q(\mathbf{x}_t \vert \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0, (1 - \bar{\alpha}_t)\mathbf{I})$$ where $\bar{\alpha}_t = \prod_{i=1}^{t} \alpha_i$ and $\alpha_i = 1 - \beta_i$.
This formula gives the probability distribution of the state $x_t$ at timestep t, given the initial state $x_0$. Specifically, $x_t$ follows a multivariate normal distribution with mean $\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0$ and covariance $(1 - \bar{\alpha}_t)\mathbf{I}$, that is, standard deviation $\sqrt{1 - \bar{\alpha}_t}$. Here $\bar{\alpha}_t$ is a cumulative parameter over time, obtained as the product of all $\alpha_i$ up to timestep t, and each $\alpha_i$ is tied to $\beta_i$ by $\alpha_i = 1 - \beta_i$.
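For readers who want to see where the one-step formula comes from, here is a short derivation sketch. It assumes the usual reparameterized way of sampling from $q(\mathbf{x}_t \vert \mathbf{x}_{t-1})$, with independent standard normal noise $\boldsymbol{\epsilon}_t$ at each step, and uses the fact that the sum of two independent zero-mean Gaussians is a Gaussian whose variance is the sum of the two variances.

```latex
\begin{aligned}
\mathbf{x}_t &= \sqrt{\alpha_t}\,\mathbf{x}_{t-1} + \sqrt{1-\alpha_t}\,\boldsymbol{\epsilon}_t \\
&= \sqrt{\alpha_t}\left(\sqrt{\alpha_{t-1}}\,\mathbf{x}_{t-2} + \sqrt{1-\alpha_{t-1}}\,\boldsymbol{\epsilon}_{t-1}\right) + \sqrt{1-\alpha_t}\,\boldsymbol{\epsilon}_t \\
&= \sqrt{\alpha_t\alpha_{t-1}}\,\mathbf{x}_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\,\bar{\boldsymbol{\epsilon}}
\qquad \text{since } \alpha_t(1-\alpha_{t-1}) + (1-\alpha_t) = 1-\alpha_t\alpha_{t-1} \\
&\;\;\vdots \\
&= \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\,\boldsymbol{\epsilon},
\qquad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})
\end{aligned}
```

Reading off the last line gives the mean $\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0$ and the covariance $(1-\bar{\alpha}_t)\mathbf{I}$.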
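Over the course of the iterations, $\sqrt{\bar{\alpha}_t}$ gets smaller and smaller while $\sqrt{1 - \bar{\alpha}_t}$ gets larger and larger. In other words, the signal from $x_0$ fades and the noise takes over.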
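A few lines of plain Python are enough to see this trend in numbers. The sketch below assumes a linear beta schedule from 1e-4 to 0.02 over 1000 steps (the values used in the DDPM paper); the schedule is an assumption made for illustration, and other schedules behave the same way qualitatively. It also checks that applying the single-step formula repeatedly gives the same mean coefficient and variance as the one-step formula.

```python
import math

# Assumption: linear beta schedule, 1e-4 -> 0.02 over T = 1000 steps
T = 1000
betas = [1e-4 + (0.02 - 1e-4) * t / (T - 1) for t in range(T)]

signal_scale, noise_scale = [], []  # sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t)
alpha_bar = 1.0
mean_coef, var = 1.0, 0.0           # step-by-step mean coefficient and variance
for beta in betas:
    alpha = 1.0 - beta
    alpha_bar *= alpha              # cumulative product of the alphas
    mean_coef *= math.sqrt(alpha)   # one more application of q(x_t | x_{t-1})
    var = alpha * var + beta        # variance after that application
    signal_scale.append(math.sqrt(alpha_bar))
    noise_scale.append(math.sqrt(1.0 - alpha_bar))

for t in (0, 249, 499, 999):
    print(f"t={t:4d}  sqrt(alpha_bar)={signal_scale[t]:.4f}  "
          f"sqrt(1-alpha_bar)={noise_scale[t]:.4f}")

# The step-by-step recursion agrees with the closed form
print(math.isclose(mean_coef, math.sqrt(alpha_bar)),
      math.isclose(var, 1.0 - alpha_bar))
```

By the last timestep the signal scale is close to 0 and the noise scale is close to 1, so $x_T$ is almost pure Gaussian noise.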
# Visualize the DDPM noising process for different timesteps:
# Noise a batch of images to view the effect
fig, axs = plt.subplots(3, 1, figsize=(16, 10))
xb, yb = next(iter(train_dataloader))
xb = xb.to(device)[:8]
xb = xb * 2. - 1. # Map to (-1, 1)
print('X shape', xb.shape)
# Show clean inputs
axs[0].imshow(torchvision.utils.make_grid(xb[:8])[0].detach().cpu(), cmap='Greys')
axs[0].set_title('Clean X')
# Add noise with scheduler
timesteps = torch.linspace(0, 999, 8).long().to(device)
noise = torch.randn_like(xb) # << NB: randn not rand
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
print('Noisy X shape', noisy_xb.shape)
# Show noisy version (with and without clipping)
axs[1].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu().clip(-1, 1), cmap='Greys')
axs[1].set_title('Noisy X (clipped to (-1, 1))')
axs[2].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu(), cmap='Greys')
axs[2].set_title('Noisy X');
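The DDPM version draws its noise from a Gaussian distribution (mean 0 and variance 1, from torch.randn), rather than from the uniform distribution between 0 and 1 (torch.rand) used in the corrupt function above. torch.rand generates random numbers in the range [0, 1).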
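The difference between the two kinds of noise is easy to see from summary statistics. The sketch below uses Python's standard library as a stand-in: `random.random()` plays the role of torch.rand and `random.gauss(0, 1)` the role of torch.randn. That substitution is an assumption made so the snippet runs without torch.

```python
import random
import statistics

random.seed(0)
n = 10_000
uniform = [random.random() for _ in range(n)]          # uniform on [0, 1)
gaussian = [random.gauss(0.0, 1.0) for _ in range(n)]  # mean 0, variance 1

for name, values in (("uniform ", uniform), ("gaussian", gaussian)):
    print(f"{name}: min={min(values):+.3f} max={max(values):+.3f} "
          f"mean={statistics.fmean(values):+.3f} "
          f"var={statistics.pvariance(values):.3f}")
```

Uniform noise never leaves [0, 1), while Gaussian noise is centered on 0 and takes negative values, which is one reason the DDPM code maps images to (-1, 1) before adding noise.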
import numpy as np
import torchvision
from PIL import Image


def show_images(x):
    """Given a batch of images x, make a grid and convert to PIL"""
    x = x * 0.5 + 0.5  # Map from (-1, 1) back to (0, 1)
    grid = torchvision.utils.make_grid(x)
    grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255
    grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
    return grid_im

from diffusers import DDPMScheduler
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
timesteps = torch.linspace(0, 999, 8).long().to(device)
noise = torch.randn_like(xb)
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
print("Noisy X shape", noisy_xb.shape)
show_images(noisy_xb).resize((8 * 64, 64), resample=Image.NEAREST)
import torch.nn.functional as F

noise = torch.randn_like(xb)  # randn, not rand!
noisy_x = noise_scheduler.add_noise(xb, noise, timesteps)
model_prediction = model(noisy_x, timesteps).sample
loss = F.mse_loss(model_prediction, noise)  # noise as the target
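Why is it enough to predict the noise rather than the clean image? Because the one-step noising formula can be inverted: once the noise is known, the clean input follows. The toy sketch below shows this with three "pixels" in plain Python; `alpha_bar_t`, the pixel values, and the helper names are hypothetical and chosen only for illustration.

```python
import math
import random

random.seed(0)
alpha_bar_t = 0.5      # hypothetical noise level at some timestep t
x0 = [0.2, -0.7, 1.0]  # toy "clean image" with three pixels in (-1, 1)
noise = [random.gauss(0.0, 1.0) for _ in x0]

# One-step forward process: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise
x_t = [math.sqrt(alpha_bar_t) * a + math.sqrt(1 - alpha_bar_t) * e
       for a, e in zip(x0, noise)]


def mse(pred, target):
    """Mean squared error between two equally long lists."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def x0_from_noise(x_t, noise_pred, alpha_bar_t):
    """Invert the forward formula given an estimate of the noise."""
    return [(xt - math.sqrt(1 - alpha_bar_t) * e) / math.sqrt(alpha_bar_t)
            for xt, e in zip(x_t, noise_pred)]


perfect = list(noise)      # a perfect noise prediction
lazy = [0.0] * len(noise)  # a model that always predicts "no noise"

print(mse(perfect, noise), mse(lazy, noise))
recovered = x0_from_noise(x_t, perfect, alpha_bar_t)
print(recovered)           # matches x0 up to floating-point error
```

A lower noise-prediction loss therefore means a better implied estimate of the clean image.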
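UNet2DModel takes both x and the timestep. The timestep is turned into an embedding and fed into the model in several places. Giving the model information about the noise level lets it do its job better. A model can be trained without this timestep conditioning, but in some cases it does seem to help performance, and most implementations include it, at least in the current literature. (A bit of black magic, admittedly.)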
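Below is the classic optimization loop in PyTorch: we feed the data through batch by batch and update the model parameters step by step with the optimizer, in this example AdamW with a learning rate of 0.0004. For each batch, the model parameters are updated with loss.backward() and optimizer.step().
# Set the noise scheduler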
import torch.nn.functional as F

noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2"
)

# Training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)

losses = []

for epoch in range(30):
    for step, batch in enumerate(train_dataloader):
        clean_images = batch["images"].to(device)
        # Sample noise to add to the images
        noise = torch.randn(clean_images.shape).to(clean_images.device)
        bs = clean_images.shape[0]

        # Sample a random timestep for each image (one per item in the batch)
        timesteps = torch.randint(
            0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device
        )

        # Add noise to the clean images according to the noise magnitude at each timestep
        noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

        # Get the model prediction
        noise_pred = model(noisy_images, timesteps, return_dict=False)[0]

        # Calculate the loss
        loss = F.mse_loss(noise_pred, noise)
        loss.backward()
        losses.append(loss.item())

        # Update the model parameters with the optimizer
        optimizer.step()
        optimizer.zero_grad()

    if (epoch + 1) % 5 == 0:
        loss_last_epoch = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
        print(f"Epoch:{epoch+1}, loss: {loss_last_epoch}")

# 1. Create a pipeline:
from diffusers import DDPMPipeline
image_pipe = DDPMPipeline(unet=model, scheduler=noise_scheduler)
pipeline_output = image_pipe()
# 2. Write a sampling loop
# Random starting point (8 random images):
sample = torch.randn(8, 3, 32, 32).to(device)

for i, t in enumerate(noise_scheduler.timesteps):
    # Get model pred
    with torch.no_grad():
        residual = model(sample, t).sample

    # Update sample with step
    sample = noise_scheduler.step(residual, t, sample).prev_sample