# U-Net model for [Denoising Diffusion Probabilistic Models (DDPM)](index.html)
This is a [U-Net](../../unet/index.html) based model to predict noise
$\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$.
U-Net gets its name from the U shape in the model diagram.
It processes a given image by progressively lowering (halving) the feature map resolution and then
increasing the resolution.
There are pass-through connections at each resolution.

This implementation contains a bunch of modifications to original U-Net (residual blocks, multi-head attention)
and also adds time-step embeddings $t$.
import math
from typing import Optional, Tuple, Union, List
import torch
from torch import nn
from labml_helpers.module import Module
class Swish(Module):
    r"""
    ### Swish activation function

    $$x \cdot \sigma(x)$$
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        * `x` is the input tensor; the activation is applied element-wise
        """
        return x * torch.sigmoid(x)
class TimeEmbedding(nn.Module):
    r"""
    ### Embeddings for $t$
    """

    def __init__(self, n_channels: int):
        """
        * `n_channels` is the number of dimensions in the embedding
        """
        # `nn.Module` must be initialized before sub-modules are assigned
        super().__init__()
        self.n_channels = n_channels
        # First linear layer
        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
        # Activation
        self.act = Swish()
        # Second linear layer
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor):
        r"""
        * `t` is indexed as `t[:, None]`, so it is a 1-D tensor of time steps

        Returns the output of the two-layer MLP, whose last dimension is `n_channels`.
        """
        # Create sinusoidal position embeddings
        # [same as those from the transformer](../../transformers/positional_encoding.html)
        #
        # \begin{align}
        # PE^{(1)}_{t,i} &= sin\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg) \\
        # PE^{(2)}_{t,i} &= cos\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg)
        # \end{align}
        #
        # where $d$ is `half_dim`
        half_dim = self.n_channels // 8
        emb = math.log(10_000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        # Concatenating sine and cosine gives `2 * half_dim = n_channels // 4` features,
        # which is the input size of `lin1`
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)

        # Transform with the MLP.
        # `lin1` expands the sinusoidal features to `n_channels` dimensions
        emb = self.act(self.lin1(emb))
        emb = self.lin2(emb)

        return emb
class ResidualBlock(Module):
    """
    ### Residual block

    A residual block has two convolution layers with group normalization.
    Each resolution is processed with two residual blocks.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int,
                 n_groups: int = 32, dropout: float = 0.1):
        """
        * `in_channels` is the number of input channels
        * `out_channels` is the number of output channels
        * `time_channels` is the number channels in the time step ($t$) embeddings
        * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
        * `dropout` is the dropout rate
        """
        super().__init__()
        # Group normalization and the first convolution layer
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = Swish()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

        # Group normalization and the second convolution layer
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = Swish()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

        # If the number of input channels is not equal to the number of output channels we have to
        # project the shortcut connection; otherwise the input is passed through unchanged
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
        else:
            self.shortcut = nn.Identity()

        # Linear layer for time embeddings
        self.time_emb = nn.Linear(time_channels, out_channels)
        self.time_act = Swish()

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size, time_channels]`
        """
        # First convolution layer
        h = self.conv1(self.act1(self.norm1(x)))
        # Add time embeddings; the two trailing `None` axes broadcast them over height and width
        h += self.time_emb(self.time_act(t))[:, :, None, None]
        # Second convolution layer
        h = self.conv2(self.dropout(self.act2(self.norm2(h))))

        # Add the shortcut connection and return
        return h + self.shortcut(x)
class AttentionBlock(Module):
### Attention block
This is similar to [transformer multi-head attention](../../transformers/mha.html).
# 定义类的初始化方法
def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):
* `n_channels` is the number of channels in the input
* `n_heads` is the number of heads in multi-head attention
* `d_k` is the number of dimensions in each head
* `n_groups` is the number of groups for group normalization
# 调用父类的初始化方法
# 如果没有指定`d_k`,则将其设置为与输入通道数相同
if d_k is None:
d_k = n_channels
# 创建一个组归一化层,用于对输入进行归一化处理
self.norm = nn.GroupNorm(n_groups, n_channels)
# 创建一个线性层,用于生成查询、键和值
self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
# 创建一个线性层,用于最后的变换
self.output = nn.Linear(n_heads * d_k, n_channels)
# 计算点积注意力的缩放因子
self.scale = d_k ** -0.5
self.n_heads = n_heads
self.d_k = d_k
def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
* `x` has shape `[batch_size, in_channels, height, width]`
* `t` has shape `[batch_size, time_channels]`
# `t` is not used, but it's kept in the arguments because for the attention layer function signature
# to match with `ResidualBlock`.
_ = t
# Get shape
batch_size, n_channels, height, width = x.shape
# Change `x` to shape `[batch_size, seq, n_channels]`
x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
# Get query, key, and values (concatenated) and shape it to `[batch_size, seq, n_heads, 3 * d_k]`
qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
# Split query, key, and values. Each of them will have shape `[batch_size, seq, n_heads, d_k]`、
#q, k, v = torch.chunk(qkv, 3, dim=-1)
q, k, v = torch.chunk(qkv, 3, dim=-1)
# Calculate scaled dot-product $\frac{Q K^\top}{\sqrt{d_k}}$
attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
# Softmax along the sequence dimension $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
attn = attn.softmax(dim=2)
# Multiply by values
res = torch.einsum('bijh,bjhd->bihd', attn, v)
# Reshape to `[batch_size, seq, n_heads * d_k]`
res = res.view(batch_size, -1, self.n_heads * self.d_k)
# Transform to `[batch_size, seq, n_channels]`
res = self.output(res)
# Add skip connection
res += x
# Change to shape `[batch_size, in_channels, height, width]`
res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)
return res
class DownBlock(Module):
    """
    ### Down block

    This combines `ResidualBlock` and `AttentionBlock`. These are used in the first half of U-Net at each resolution.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        """
        * `in_channels` is the number of input channels
        * `out_channels` is the number of output channels
        * `time_channels` is the number of channels in the time step embeddings
        * `has_attn` selects whether an attention block follows the residual block
        """
        super().__init__()
        self.res = ResidualBlock(in_channels, out_channels, time_channels)
        # Use attention only where requested; otherwise pass the features through unchanged
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` is the input feature map
        * `t` is the time step embedding passed to the residual block
        """
        x = self.res(x, t)
        x = self.attn(x)
        return x
class UpBlock(Module):
    """
    ### Up block

    This combines `ResidualBlock` and `AttentionBlock`. These are used in the second half of U-Net at each resolution.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        """
        * `in_channels` is the number of channels coming from the previous layer
        * `out_channels` is the number of output channels
        * `time_channels` is the number of channels in the time step embeddings
        * `has_attn` selects whether an attention block follows the residual block
        """
        super().__init__()
        # The input has `in_channels + out_channels` because we concatenate the output of the same resolution
        # from the first half of the U-Net
        self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
        # Use attention only where requested; otherwise pass the features through unchanged
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` is the input feature map (already concatenated with the skip connection)
        * `t` is the time step embedding passed to the residual block
        """
        x = self.res(x, t)
        x = self.attn(x)
        return x
class MiddleBlock(Module):
    """
    ### Middle block

    It combines a `ResidualBlock`, `AttentionBlock`, followed by another `ResidualBlock`.
    This block is applied at the lowest resolution of the U-Net.
    """

    def __init__(self, n_channels: int, time_channels: int):
        """
        * `n_channels` is the number of channels; the block keeps it unchanged
        * `time_channels` is the number of channels in the time step embeddings
        """
        super().__init__()
        self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
        self.attn = AttentionBlock(n_channels)
        self.res2 = ResidualBlock(n_channels, n_channels, time_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` is the input feature map
        * `t` is the time step embedding passed to both residual blocks
        """
        x = self.res1(x, t)
        x = self.attn(x)
        x = self.res2(x, t)
        return x
class Upsample(nn.Module):
    r"""
    ### Scale up the feature map by $2 \times$
    """

    def __init__(self, n_channels):
        """
        * `n_channels` is the number of channels; the layer keeps it unchanged
        """
        super().__init__()
        # Transposed convolution with a 4x4 kernel, stride 2 and padding 1
        self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` is the feature map to up-sample
        * `t` is ignored
        """
        # `t` is not used, but it's kept in the arguments so that the function signature
        # matches `ResidualBlock`.
        _ = t
        return self.conv(x)
class Downsample(nn.Module):
    r"""
    ### Scale down the feature map by $\frac{1}{2} \times$
    """

    def __init__(self, n_channels):
        """
        * `n_channels` is the number of channels; the layer keeps it unchanged
        """
        super().__init__()
        # Convolution arguments, in order:
        # `(3, 3)` is the kernel size in height and width,
        # `(2, 2)` is the stride in height and width,
        # `(1, 1)` is the padding in height and width
        self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` is the feature map to down-sample
        * `t` is ignored
        """
        # `t` is not used, but it's kept in the arguments so that the function signature
        # matches `ResidualBlock`.
        _ = t
        return self.conv(x)
class UNet(Module):
    """
    ## U-Net
    """

    def __init__(self, image_channels: int = 3, n_channels: int = 64,
                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
                 is_attn: Union[Tuple[bool, ...], List[bool]] = (False, False, True, True),
                 n_blocks: int = 2):
        """
        * `image_channels` is the number of channels in the image. $3$ for RGB.
        * `n_channels` is number of channels in the initial feature map that we transform the image into
        * `ch_mults` is a tuple or list of integer channel multipliers, one per resolution. The number of
          channels at a resolution is the number of channels at the previous resolution times `ch_mults[i]`
        * `is_attn` is a list of booleans that indicate whether to use attention at each resolution
        * `n_blocks` is the number of `UpDownBlocks` at each resolution
        """
        super().__init__()

        # Number of resolutions
        n_resolutions = len(ch_mults)

        # Project image into feature map
        self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))

        # Time embedding layer. Time embedding has `n_channels * 4` channels
        self.time_emb = TimeEmbedding(n_channels * 4)

        # #### First half of U-Net - decreasing resolution
        down = []
        # Number of channels
        out_channels = in_channels = n_channels
        # For each resolution
        for i in range(n_resolutions):
            # Number of output channels at this resolution
            out_channels = in_channels * ch_mults[i]
            # Add `n_blocks` down blocks. Each one is a residual block that mixes in the time
            # embedding, followed by attention where `is_attn[i]` is set
            for _ in range(n_blocks):
                down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
                in_channels = out_channels
            # Down sample at all resolutions except the last
            if i < n_resolutions - 1:
                down.append(Downsample(in_channels))

        # Combine the set of modules.
        # Modules placed in an `nn.ModuleList` are registered as part of the network,
        # so they take part in the forward and backward passes correctly
        self.down = nn.ModuleList(down)

        # Middle block
        self.middle = MiddleBlock(out_channels, n_channels * 4, )

        # #### Second half of U-Net - increasing resolution
        up = []
        # Number of channels
        in_channels = out_channels
        # For each resolution
        for i in reversed(range(n_resolutions)):
            # `n_blocks` at the same resolution
            out_channels = in_channels
            for _ in range(n_blocks):
                up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            # Final block to reduce the number of channels
            out_channels = in_channels // ch_mults[i]
            up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            in_channels = out_channels
            # Up sample at all resolutions except last
            if i > 0:
                up.append(Upsample(in_channels))

        # Combine the set of modules
        self.up = nn.ModuleList(up)

        # Final normalization and convolution layer
        self.norm = nn.GroupNorm(8, n_channels)
        self.act = Swish()
        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size]`
        """
        # Get time-step embeddings
        t = self.time_emb(t)

        # Get image projection
        x = self.image_proj(x)

        # `h` will store outputs at each resolution for skip connection
        h = [x]
        # First half of U-Net
        for m in self.down:
            x = m(x, t)
            h.append(x)

        # Middle (bottom)
        x = self.middle(x, t)

        # Second half of U-Net
        for m in self.up:
            if isinstance(m, Upsample):
                x = m(x, t)
            else:
                # Get the skip connection from first half of U-Net and concatenate.
                # It feeds the features of the matching layer in the first half back into the second half
                s = h.pop()
                # Tensors are `[batch_size, n_channels, height, width]`, so concatenating on
                # `dim=1` widens the channel dimension
                x = torch.cat((x, s), dim=1)
                x = m(x, t)

        # Final normalization and convolution
        return self.final(self.act(self.norm(x)))
# Denoising Diffusion Probabilistic Models (DDPM)
This is a [PyTorch](https://pytorch.org) implementation/tutorial of the paper
[Denoising Diffusion Probabilistic Models](https://papers.labml.ai/paper/2006.11239).
In simple terms, we get an image from data and add noise step by step.
Then we train a model to predict that noise at each step and use the model to
generate images.
The following definitions and derivations show how this works.
For details please refer to [the paper](https://papers.labml.ai/paper/2006.11239).
## Forward Process
The forward process adds noise to the data $x_0 \sim q(x_0)$, for $T$ timesteps.
\begin{align}
q(x_t | x_{t-1}) &= \mathcal{N}\big(x_t; \sqrt{1- \beta_t} x_{t-1}, \beta_t \mathbf{I}\big) \\
q(x_{1:T} | x_0) &= \prod_{t = 1}^{T} q(x_t | x_{t-1})
\end{align}
where $\beta_1, \dots, \beta_T$ is the variance schedule.
We can sample $x_t$ at any timestep $t$ with,
\begin{align}
q(x_t|x_0) &= \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)
\end{align}
where $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^t \alpha_s$
## Reverse Process
The reverse process removes noise starting at $p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
for $T$ time steps.
\begin{align}
\textcolor{lightgreen}{p_\theta}(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1};
\textcolor{lightgreen}{\mu_\theta}(x_t, t), \textcolor{lightgreen}{\Sigma_\theta}(x_t, t)\big) \\
\textcolor{lightgreen}{p_\theta}(x_{0:T}) &= \textcolor{lightgreen}{p_\theta}(x_T) \prod_{t = 1}^{T} \textcolor{lightgreen}{p_\theta}(x_{t-1} | x_t) \\
\textcolor{lightgreen}{p_\theta}(x_0) &= \int \textcolor{lightgreen}{p_\theta}(x_{0:T}) dx_{1:T}
\end{align}
$\textcolor{lightgreen}\theta$ are the parameters we train.
## Loss
We optimize the ELBO (from Jensen's inequality) on the negative log likelihood.
\begin{align}
\mathbb{E}[-\log \textcolor{lightgreen}{p_\theta}(x_0)]
&\le \mathbb{E}_q [ -\log \frac{\textcolor{lightgreen}{p_\theta}(x_{0:T})}{q(x_{1:T}|x_0)} ] \\
&= L
\end{align}
The loss can be rewritten as follows.
\begin{align}
L
&= \mathbb{E}_q [ -\log \frac{\textcolor{lightgreen}{p_\theta}(x_{0:T})}{q(x_{1:T}|x_0)} ] \\
&= \mathbb{E}_q [ -\log p(x_T) - \sum_{t=1}^T \log \frac{\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)}{q(x_t|x_{t-1})} ] \\
&= \mathbb{E}_q [
-\log \frac{p(x_T)}{q(x_T|x_0)}
-\sum_{t=2}^T \log \frac{\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)}{q(x_{t-1}|x_t,x_0)}
-\log \textcolor{lightgreen}{p_\theta}(x_0|x_1)] \\
&= \mathbb{E}_q [
D_{KL}(q(x_T|x_0) \Vert p(x_T))
+\sum_{t=2}^T D_{KL}(q(x_{t-1}|x_t,x_0) \Vert \textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t))
-\log \textcolor{lightgreen}{p_\theta}(x_0|x_1)]
\end{align}
$D_{KL}(q(x_T|x_0) \Vert p(x_T))$ is constant since we keep $\beta_1, \dots, \beta_T$ constant.
### Computing $L_{t-1} = D_{KL}(q(x_{t-1}|x_t,x_0) \Vert \textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t))$
The forward process posterior conditioned by $x_0$ is,
\begin{align}
q(x_{t-1}|x_t, x_0) &= \mathcal{N} \Big(x_{t-1}; \tilde\mu_t(x_t, x_0), \tilde\beta_t \mathbf{I} \Big) \\
\tilde\mu_t(x_t, x_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0
+ \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t \\
\tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t
\end{align}
The paper sets $\textcolor{lightgreen}{\Sigma_\theta}(x_t, t) = \sigma_t^2 \mathbf{I}$ where $\sigma_t^2$ is set to constants
$\beta_t$ or $\tilde\beta_t$.
$$\textcolor{lightgreen}{p_\theta}(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1}; \textcolor{lightgreen}{\mu_\theta}(x_t, t), \sigma_t^2 \mathbf{I} \big)$$
For given noise $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ using $q(x_t|x_0)$
\begin{align}
x_t(x_0, \epsilon) &= \sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon \\
x_0 &= \frac{1}{\sqrt{\bar\alpha_t}} \Big(x_t(x_0, \epsilon) - \sqrt{1-\bar\alpha_t}\epsilon\Big)
\end{align}
This gives,
\begin{align}
L_{t-1}
&= D_{KL}(q(x_{t-1}|x_t,x_0) \Vert \textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)) \\
&= \mathbb{E}_q \Bigg[ \frac{1}{2\sigma_t^2}
\Big \Vert \tilde\mu(x_t, x_0) - \textcolor{lightgreen}{\mu_\theta}(x_t, t) \Big \Vert^2 \Bigg] \\
&= \mathbb{E}_{x_0, \epsilon} \Bigg[ \frac{1}{2\sigma_t^2}
\bigg\Vert \frac{1}{\sqrt{\alpha_t}} \Big(
x_t(x_0, \epsilon) - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}} \epsilon
\Big) - \textcolor{lightgreen}{\mu_\theta}(x_t(x_0, \epsilon), t) \bigg\Vert^2 \Bigg]
\end{align}
Re-parameterizing with a model to predict noise
\begin{align}
\textcolor{lightgreen}{\mu_\theta}(x_t, t) &= \tilde\mu \bigg(x_t,
\frac{1}{\sqrt{\bar\alpha_t}} \Big(x_t -
\sqrt{1-\bar\alpha_t}\textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big) \bigg) \\
&= \frac{1}{\sqrt{\alpha_t}} \Big(x_t -
\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)
\end{align}
where $\epsilon_\theta$ is a learned function that predicts $\epsilon$ given $(x_t, t)$.
This gives,
\begin{align}
L_{t-1}
&= \mathbb{E}_{x_0, \epsilon} \Bigg[ \frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1 - \bar\alpha_t)}
\Big\Vert
\epsilon - \textcolor{lightgreen}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)
\Big\Vert^2 \Bigg]
\end{align}
That is, we are training to predict the noise.
### Simplified loss
$$L_{\text{simple}}(\theta) = \mathbb{E}_{t,x_0, \epsilon} \Bigg[ \bigg\Vert
\epsilon - \textcolor{lightgreen}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)
\bigg\Vert^2 \Bigg]$$
This minimizes $-\log \textcolor{lightgreen}{p_\theta}(x_0|x_1)$ when $t=1$ and $L_{t-1}$ for $t\gt1$, discarding the
weighting in $L_{t-1}$. Discarding the weights $\frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1 - \bar\alpha_t)}$
increases the weight given to higher $t$ (which have higher noise levels), therefore increasing the sample quality.
This file implements the loss calculation and a basic sampling method that we use to generate images during
training.
Here is the [UNet model](unet.html) that gives $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$ and
[training code](experiment.html).
[This file](evaluate.html) can generate samples and interpolations from a trained model.
from typing import Tuple, Optional
import torch
import torch.nn.functional as F
import torch.utils.data
from torch import nn
from labml_nn.diffusion.ddpm.utils import gather
class DenoiseDiffusion:
    """
    ## Denoise Diffusion
    """

    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
        r"""
        * `eps_model` is $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$ model
        * `n_steps` is $T$, the number of diffusion steps
        * `device` is the device to place constants on
        """
        super().__init__()
        self.eps_model = eps_model

        # Create $\beta_1, \dots, \beta_T$ linearly increasing variance schedule
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)

        # $\alpha_t = 1 - \beta_t$
        self.alpha = 1. - self.beta
        # $\bar\alpha_t = \prod_{s=1}^t \alpha_s$
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        # $T$
        self.n_steps = n_steps
        # $\sigma^2 = \beta$
        self.sigma2 = self.beta

    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        #### Get $q(x_t|x_0)$ distribution

        \begin{align}
        q(x_t|x_0) &= \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)
        \end{align}

        Returns the `(mean, variance)` pair of the distribution.
        """
        # [gather](utils.html) $\bar\alpha_t$ and compute $\sqrt{\bar\alpha_t} x_0$.
        # `gather` is defined in `utils`; presumably it picks the entry of `alpha_bar` for each
        # time step in `t` in a shape that broadcasts against `x0` (TODO confirm against `utils`).
        # The gathered value is square-rooted (`** 0.5`) and multiplied by `x0`.
        mean = gather(self.alpha_bar, t) ** 0.5 * x0
        # $(1-\bar\alpha_t) \mathbf{I}$
        var = 1 - gather(self.alpha_bar, t)
        return mean, var

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
        r"""
        #### Sample from $q(x_t|x_0)$

        This is the forward (noising) process.

        \begin{align}
        q(x_t|x_0) &= \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)
        \end{align}

        * `eps` is the noise to add; fresh Gaussian noise is drawn when it is `None`
        """
        # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
        if eps is None:
            eps = torch.randn_like(x0)

        # get $q(x_t|x_0)$
        mean, var = self.q_xt_x0(x0, t)
        # Sample from $q(x_t|x_0)$
        return mean + (var ** 0.5) * eps

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
        r"""
        #### Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$

        \begin{align}
        \textcolor{lightgreen}{p_\theta}(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1};
        \textcolor{lightgreen}{\mu_\theta}(x_t, t), \sigma_t^2 \mathbf{I} \big) \\
        \textcolor{lightgreen}{\mu_\theta}(x_t, t)
          &= \frac{1}{\sqrt{\alpha_t}} \Big(x_t -
            \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)
        \end{align}
        """
        # $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$
        eps_theta = self.eps_model(xt, t)
        # [gather](utils.html) $\bar\alpha_t$
        alpha_bar = gather(self.alpha_bar, t)
        # $\alpha_t$
        alpha = gather(self.alpha, t)
        # $\frac{\beta}{\sqrt{1-\bar\alpha_t}}$, using $\beta_t = 1 - \alpha_t$
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
        # $$\frac{1}{\sqrt{\alpha_t}} \Big(x_t -
        #      \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)$$
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
        # $\sigma^2$
        var = gather(self.sigma2, t)

        # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
        eps = torch.randn(xt.shape, device=xt.device)
        # Sample
        return mean + (var ** .5) * eps

    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
        r"""
        #### Simplified Loss

        $$L_{\text{simple}}(\theta) = \mathbb{E}_{t,x_0, \epsilon} \Bigg[ \bigg\Vert
        \epsilon - \textcolor{lightgreen}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)
        \bigg\Vert^2 \Bigg]$$

        * `noise` is the noise to add; fresh Gaussian noise is drawn when it is `None`
        """
        # Get batch size
        batch_size = x0.shape[0]
        # Get random $t$ for each sample in the batch
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)

        # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
        if noise is None:
            noise = torch.randn_like(x0)

        # Sample $x_t$ for $q(x_t|x_0)$
        xt = self.q_sample(x0, t, eps=noise)
        # Get $\textcolor{lightgreen}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)$,
        # i.e. use the image noised for `t` steps to predict the noise $\epsilon$ that was added to `x0`
        eps_theta = self.eps_model(xt, t)

        # MSE loss
        return F.mse_loss(noise, eps_theta)
title: Denoising Diffusion Probabilistic Models (DDPM) training
summary: >
Training code for
Denoising Diffusion Probabilistic Model.
# [Denoising Diffusion Probabilistic Models (DDPM)](index.html) training
This trains a DDPM based model on CelebA HQ dataset. You can find the download instruction in this
[discussion on fast.ai](https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3).
Save the images inside [`data/celebA` folder](#dataset_path).
The paper had used an exponential moving average of the model with a decay of $0.9999$. We have skipped this for
simplicity.
from typing import List
import torch
import torch.utils.data
import torchvision
from PIL import Image
from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs, option
from labml_helpers.device import DeviceConfigs
from labml_nn.diffusion.ddpm import DenoiseDiffusion
from labml_nn.diffusion.ddpm.unet import UNet
class Configs(BaseConfigs):
    """
    ## Configurations
    """
    # Device to train the model on.
    # [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs)
    # picks up an available CUDA device or defaults to CPU.
    #
    # The fields below use type annotations of the form `name: type = value`;
    # a field may also be declared with only its type and assigned later (see `init`).
    device: torch.device = DeviceConfigs()

    # U-Net model for $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$
    eps_model: UNet
    # [DDPM algorithm](index.html)
    diffusion: DenoiseDiffusion

    # Number of channels in the image. $3$ for RGB.
    image_channels: int = 3
    # Image size
    image_size: int = 32
    # Number of channels in the initial feature map
    n_channels: int = 64
    # The list of channel multipliers at each resolution, passed to `UNet` as `ch_mults`
    channel_multipliers: List[int] = [1, 2, 2, 4]
    # The list of booleans that indicate whether to use attention at each resolution
    is_attention: List[int] = [False, False, False, True]

    # Number of time steps $T$
    n_steps: int = 1_000
    # Batch size
    batch_size: int = 64
    # Number of samples to generate
    n_samples: int = 16
    # Learning rate
    learning_rate: float = 2e-5

    # Number of training epochs
    epochs: int = 1_000

    # Dataset
    dataset: torch.utils.data.Dataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        """
        Create the model, the diffusion process, the data loader and the optimizer.
        """
        # Create $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$ model
        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)

        # Create [DDPM class](index.html)
        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )

        # Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
        # Create optimizer
        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

        # Image logging: register an image tracker named "sample"
        # (the second argument presumably enables printing image statistics to the console — confirm in labml docs)
        tracker.set_image("sample", True)

    def sample(self):
        """
        ### Sample images
        """
        # No gradients are needed while sampling, which saves memory
        with torch.no_grad():
            # Start from pure noise $x_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ with shape
            # `[n_samples, image_channels, image_size, image_size]`, placed on `self.device`
            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
                            device=self.device)

            # Remove noise step by step for $T$ steps
            for t_ in monit.iterate('Sample', self.n_steps):
                # Current time step; counts down from `n_steps - 1` to `0`
                t = self.n_steps - t_ - 1
                # Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$.
                # `x.new_full` builds a tensor on the same device as `x`, with shape `(n_samples,)`,
                # filled with `t` and using the integer dtype given here
                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))

            # Log the sampled images
            tracker.save('sample', x)

    def train(self):
        """
        ### Train

        Runs one pass over the data loader. Each step computes the noise-prediction loss
        and updates the U-Net parameters.
        """
        # Iterate through the dataset; `monit.iterate` monitors the progress
        for data in monit.iterate('Train', self.data_loader):
            # Increment global step, used to track training progress
            tracker.add_global_step()
            # Move data to device
            data = data.to(self.device)

            # Make the gradients zero, because PyTorch accumulates gradients across backward passes
            self.optimizer.zero_grad()
            # Calculate loss
            loss = self.diffusion.loss(data)
            # Compute gradients by back-propagation
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """
        ### Training loop
        """
        # Each iteration is one training epoch
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # Sample some images
            self.sample()
            # New line in the console
            tracker.new_line()
            # Save the model checkpoint so that training can be resumed later
            experiment.save_checkpoint()
class CelebADataset(torch.utils.data.Dataset):
    """
    ### CelebA HQ dataset
    """
    # Translated note from the original author (figures are theirs, not checked here):
    # CelebA offers large diversity, quantity and rich annotations: 10,177 identities,
    # 202,599 face images, 5 landmark locations and 40 binary attribute annotations per image.
    # It can be used for face attribute recognition, face recognition, face detection,
    # landmark (facial part) localization, and face editing and synthesis.

    def __init__(self, image_size: int):
        """
        * `image_size` is the size passed to the resize transform
        """
        super().__init__()

        # CelebA images folder: the `celebA` directory under the labml data path
        folder = lab.get_data_path() / 'celebA'
        # List of files (all `.jpg` files, searched recursively)
        self._files = list(folder.glob('**/*.jpg'))

        # Transformations to resize the image and convert to tensor
        self._transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        """
        Size of the dataset
        """
        return len(self._files)

    def __getitem__(self, index: int):
        """
        Get an image
        """
        # Transform inside the `with` block so the file is read before it is closed
        with Image.open(self._files[index]) as img:
            return self._transform(img)
@option(Configs.dataset, 'CelebA')
def celeb_dataset(c: Configs):
    """
    Create CelebA dataset
    """
    return CelebADataset(c.image_size)
class MNISTDataset(torchvision.datasets.MNIST):
    """
    ### MNIST dataset
    """

    def __init__(self, image_size):
        """
        * `image_size` is the size passed to the resize transform
        """
        # Resize the image and convert it to a tensor
        transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)

    def __getitem__(self, item):
        """
        Return only the first element of the parent dataset's item
        """
        return super().__getitem__(item)[0]
@option(Configs.dataset, 'MNIST')
def mnist_dataset(c: Configs):
    """
    Create MNIST dataset
    """
    return MNISTDataset(c.image_size)
def main():
    """
    Create the experiment, configure it, and run the training loop.
    """
    # Create experiment
    experiment.create(name='diffuse', writers={'screen', 'labml'})

    # Create configurations
    configs = Configs()

    # Set configurations. You can override the defaults by passing the values in the dictionary.
    experiment.configs(configs, {
        'dataset': 'CelebA',  # 'MNIST'
        'image_channels': 3,  # 1,
        'epochs': 100,  # 5,
    })

    # Initialize
    configs.init()

    # Set models for saving and loading
    experiment.add_pytorch_models({'eps_model': configs.eps_model})

    # Start and run the training loop
    with experiment.start():
        configs.run()
# Run the experiment only when this file is executed as a script
if __name__ == '__main__':
    main()
