Score-Based Diffusion Models with Stochastic Differential Equations (Score-Based Generative Modeling through Stochastic Differential Equations)
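A score-based diffusion model learns the score of the data distribution, i.e. the gradient of the log-density ∇ₓ log p(x). It can generate images whose quality rivals GANs without adversarial training. The method comes from: Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. "Score-Based Generative Modeling through Stochastic Differential Equations." International Conference on Learning Representations (ICLR), 2021.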
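Score-based diffusion is a landmark work that appeared after diffusion models took off: it unifies diffusion models and score-based generative models in a single SDE framework. The original diffusion models have a drawback: sampling is slow, typically requiring thousands of network evaluations to draw one sample. The SDE framework adds new samplers; in particular, solving the probability flow ODE with an adaptive black-box solver can produce samples with considerably fewer network evaluations.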
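There are many online introductions to the theory and applications of score-based diffusion, as well as paper walkthroughs, but very few explain the code. This post therefore provides a simple, runnable code example of a score-based diffusion model trained on MNIST.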
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import functools
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
import tqdm.notebook  # notebook progress bars (the code targets Colab/Jupyter)
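Strictly speaking there is no standard layer called a "projection layer"; the name here describes how the time step t is encoded. At initialization a weight vector W ~ N(0, scale²·I) is sampled once, and t is then mapped to the Gaussian random feature vector [sin(2πWt); cos(2πWt)]. Note that these weights are not trainable.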
class GaussianFourierProjection(nn.Module):
    """Gaussian random features for encoding time steps."""
    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        # Randomly sample weights during initialization. These weights are fixed
        # during optimization and are not trainable.
        self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
    def forward(self, x):
        # x: (batch,) time steps -> (batch, embed_dim) features
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
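To see what the embedding produces, here is a NumPy-only sketch of the same mapping. The function name, the seed, and the small embed_dim=8 are illustrative choices, not part of the original code. Every t becomes a vector of bounded sinusoids, and because scale = 30 gives large frequencies, nearby times tend to get clearly different features.

```python
import numpy as np

def gaussian_fourier_features(t, embed_dim=8, scale=30.0, seed=0):
    """NumPy sketch of the time embedding [sin(2*pi*W*t); cos(2*pi*W*t)].

    W is drawn once from N(0, scale^2) and then kept fixed (not trained).
    """
    rng = np.random.default_rng(seed)
    W = rng.standard_normal(embed_dim // 2) * scale   # fixed random frequencies
    proj = t[:, None] * W[None, :] * 2 * np.pi         # (batch, embed_dim // 2)
    return np.concatenate([np.sin(proj), np.cos(proj)], axis=-1)

t = np.array([0.0, 0.25, 1.0])
feats = gaussian_fourier_features(t)
print(feats.shape)  # (3, 8)
```

Each sin/cos pair lies on the unit circle, so the embedding is bounded regardless of t; the trainable nn.Linear that follows it in ScoreNet learns how to use these features.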
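The time embedding is needed because the score network is conditioned on the noise level. During training, a time t is sampled at random, the data x is perturbed with Gaussian noise whose standard deviation is determined by t, and the network receives both the perturbed x and t as input when the loss is computed.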
class Dense(nn.Module):
    """A fully connected layer that reshapes outputs to feature maps."""
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)
    def forward(self, x):
        # (batch, output_dim) -> (batch, output_dim, 1, 1), broadcastable over H and W
        return self.dense(x)[..., None, None]
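The time-dependent score-based model is a U-Net whose forward function takes the time t in addition to x. t is embedded with GaussianFourierProjection and injected into every resolution level through a Dense layer, and the network output is divided by marginal_prob_std(t), which keeps its scale in line with the magnitude of the true score (roughly 1/σ_t).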
class ScoreNet(nn.Module):
    """A time-dependent score-based model built upon the U-Net architecture."""
    def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256):
        """Initialize a time-dependent score-based network.
        Args:
            marginal_prob_std: A function that takes time t and gives the standard deviation
                of the perturbation kernel p_{0t}(x(t) | x(0)).
            channels: The number of channels for feature maps of each resolution.
            embed_dim: The dimensionality of the Gaussian random feature embedding,
                the same value passed to GaussianFourierProjection.
        """
        super().__init__()
        # Gaussian random feature embedding layer for time t
        self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim),
                                   nn.Linear(embed_dim, embed_dim))
        # Encoding layers where the resolution decreases
        self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
        self.dense1 = Dense(embed_dim, channels[0])
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
        self.dense2 = Dense(embed_dim, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
        self.dense3 = Dense(embed_dim, channels[2])
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
        self.dense4 = Dense(embed_dim, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])
        # Decoding layers where the resolution increases
        self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
        self.dense5 = Dense(embed_dim, channels[2])
        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
        self.tconv3 = nn.ConvTranspose2d(channels[2] + channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)
        self.dense6 = Dense(embed_dim, channels[1])
        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
        self.tconv2 = nn.ConvTranspose2d(channels[1] + channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)
        self.dense7 = Dense(embed_dim, channels[0])
        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
        self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)
        # The swish activation function
        self.act = lambda x: x * torch.sigmoid(x)
        self.marginal_prob_std = marginal_prob_std
    def forward(self, x, t):
        # Obtain the Gaussian random feature embedding for t
        embed = self.act(self.embed(t))
        # Encoding path
        h1 = self.conv1(x)
        ## Incorporate information from t
        h1 += self.dense1(embed)
        ## Group normalization
        h1 = self.gnorm1(h1)
        h1 = self.act(h1)
        h2 = self.conv2(h1)
        h2 += self.dense2(embed)
        h2 = self.gnorm2(h2)
        h2 = self.act(h2)
        h3 = self.conv3(h2)
        h3 += self.dense3(embed)
        h3 = self.gnorm3(h3)
        h3 = self.act(h3)
        h4 = self.conv4(h3)
        h4 += self.dense4(embed)
        h4 = self.gnorm4(h4)
        h4 = self.act(h4)
        # Decoding path
        h = self.tconv4(h4)
        h += self.dense5(embed)
        h = self.tgnorm4(h)
        h = self.act(h)
        ## Skip connection from the encoding path
        h = self.tconv3(torch.cat([h, h3], dim=1))
        h += self.dense6(embed)
        h = self.tgnorm3(h)
        h = self.act(h)
        h = self.tconv2(torch.cat([h, h2], dim=1))
        h += self.dense7(embed)
        h = self.tgnorm2(h)
        h = self.act(h)
        h = self.tconv1(torch.cat([h, h1], dim=1))
        # Normalize the output by the std of the perturbation kernel
        h = h / self.marginal_prob_std(t)[:, None, None, None]
        return h
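The skip connections only work if each decoder output has the same spatial size as the encoder feature map it is concatenated with. The standard output-size formulas for Conv2d and ConvTranspose2d (dilation 1) let us check this for 28×28 MNIST inputs. The sketch is plain Python with illustrative helper names, not part of the original code; it shows why tconv3 and tconv2 need output_padding=1.

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    # Output size of nn.Conv2d along one spatial dimension (dilation = 1)
    return (size + 2 * padding - kernel) // stride + 1

def tconv_out(size, kernel=3, stride=1, padding=0, output_padding=0):
    # Output size of nn.ConvTranspose2d along one spatial dimension (dilation = 1)
    return (size - 1) * stride - 2 * padding + kernel + output_padding

h1 = conv_out(28, stride=1)   # conv1
h2 = conv_out(h1, stride=2)   # conv2
h3 = conv_out(h2, stride=2)   # conv3
h4 = conv_out(h3, stride=2)   # conv4
d4 = tconv_out(h4, stride=2)                    # tconv4, concatenated with h3
d3 = tconv_out(d4, stride=2, output_padding=1)  # tconv3, concatenated with h2
d2 = tconv_out(d3, stride=2, output_padding=1)  # tconv2, concatenated with h1
d1 = tconv_out(d2, stride=1)                    # tconv1, final output
print([h1, h2, h3, h4], [d4, d3, d2, d1])  # [26, 12, 5, 2] [5, 12, 26, 28]
```

Without output_padding=1, tconv3 would produce 11 instead of 12 and the concatenation with h2 would fail.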
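The SDE perturbs the data distribution p_0 into p_T (here T = 1). This example uses dx = σ^t dw for t ∈ [0, 1], whose perturbation kernel is p_{0t}(x(t) | x(0)) = N(x(t); x(0), (σ^{2t} − 1)/(2 ln σ) · I). Two key functions describe it: marginal_prob_std, already used by ScoreNet, computes the standard deviation of p_{0t}(x(t) | x(0)) (the mean is simply x(0)); diffusion_coeff computes the diffusion coefficient g(t) = σ^t of the SDE.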
device = 'cuda' #@param ['cuda', 'cpu'] {'type':'string'}
def marginal_prob_std(t, sigma):
    """Compute the standard deviation of p_{0t}(x(t) | x(0)); the mean is x(0).
    Args:
        t: A vector of time steps.
        sigma: The sigma in our SDE.
    Returns:
        The standard deviation.
    """
    t = torch.as_tensor(t, device=device)
    return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma))
def diffusion_coeff(t, sigma):
    """Compute the diffusion coefficient g(t) = sigma**t of our SDE.
    Args:
        t: A vector of time steps.
        sigma: The sigma in our SDE.
    Returns:
        The vector of diffusion coefficients.
    """
    return sigma**torch.as_tensor(t, device=device)
sigma = 25.0 #@param {'type':'number'}
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)
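A quick numerical check of these two functions, in NumPy rather than PyTorch: the closed-form variance should equal ∫_0^t g(s)² ds with g(s) = σ^s, and simulating the forward SDE from x(0) = 0 should give an empirical standard deviation close to marginal_prob_std(t). The choice t = 0.7, the helper name marginal_std, and the sample sizes are arbitrary assumptions for illustration.

```python
import numpy as np

SIGMA = 25.0  # same sigma as in the post

def marginal_std(t):
    # Closed form used by marginal_prob_std: sqrt((SIGMA**(2t) - 1) / (2 ln SIGMA))
    return np.sqrt((SIGMA ** (2 * t) - 1.0) / (2.0 * np.log(SIGMA)))

t = 0.7
# 1) The variance equals the integral of g(s)^2 = SIGMA**(2s) over [0, t] (trapezoid rule)
s = np.linspace(0.0, t, 100001)
vals = SIGMA ** (2 * s)
integral = np.sum((vals[1:] + vals[:-1]) / 2 * np.diff(s))

# 2) Simulate the forward SDE dx = SIGMA**t dw with Euler-Maruyama from x(0) = 0
rng = np.random.default_rng(0)
n, num_steps = 20000, 1000
dt = t / num_steps
x = np.zeros(n)
for k in range(num_steps):
    x += SIGMA ** (k * dt) * np.sqrt(dt) * rng.standard_normal(n)

print(integral, marginal_std(t) ** 2)  # both about 13.9
print(x.std(), marginal_std(t))        # both about 3.7
```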
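The loss is denoising score matching: perturb x with noise z scaled by σ_t = marginal_prob_std(t) at a random time t, and train the network so that σ_t · s_θ(x(t), t) matches −z. The formula looks complicated, but its implementation is short: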
def loss_fn(model, x, marginal_prob_std, eps=1e-5):
    """The loss function for training score-based generative models.
    Args:
        model: A PyTorch model that represents a time-dependent score-based model.
        x: A mini-batch of training data.
        marginal_prob_std: A function that gives the standard deviation of
            the perturbation kernel.
        eps: A tolerance value for numerical stability.
    """
    # Sample t uniformly from [eps, 1) to avoid sigma_t = 0 at t = 0
    random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
    z = torch.randn_like(x)
    std = marginal_prob_std(random_t)
    perturbed_x = x + z * std[:, None, None, None]
    score = model(perturbed_x, random_t)
    # sigma_t * score should match -z; sum over C, H, W and average over the batch
    loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1, 2, 3)))
    return loss
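Why this particular form? With the Gaussian perturbation kernel above, the conditional score is available in closed form, and choosing the weighting λ(t) = σ_t² in denoising score matching gives exactly the expression computed by loss_fn (which averages over the mini-batch and sums over pixels). In the notation of this post, σ_t = marginal_prob_std(t), s_θ is the network, and ε is the eps argument of loss_fn:

```latex
\begin{aligned}
x(t) &= x(0) + \sigma_t z, \qquad z \sim \mathcal{N}(0, I), \qquad \sigma_t = \sqrt{\frac{\sigma^{2t}-1}{2\ln\sigma}} \\
\nabla_{x(t)} \log p_{0t}\big(x(t)\mid x(0)\big) &= -\frac{x(t)-x(0)}{\sigma_t^{2}} = -\frac{z}{\sigma_t} \\
\mathcal{L}(\theta) &= \mathbb{E}_{t\sim\mathcal{U}(\epsilon,1)}\,\mathbb{E}_{x(0)}\,\mathbb{E}_{z}\Big[\lambda(t)\,\Big\|s_\theta\big(x(t),t\big)+\frac{z}{\sigma_t}\Big\|_2^2\Big] \\
\lambda(t) = \sigma_t^{2} \;&\Longrightarrow\; \mathcal{L}(\theta) = \mathbb{E}_{t}\,\mathbb{E}_{x(0)}\,\mathbb{E}_{z}\Big[\big\|\sigma_t\, s_\theta\big(x(t),t\big)+z\big\|_2^2\Big]
\end{aligned}
```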
score_model = torch.nn.DataParallel(ScoreNet(marginal_prob_std=marginal_prob_std_fn))
score_model = score_model.to(device)
n_epochs = 50 #@param {'type':'integer'}
## size of a mini-batch
batch_size = 32 #@param {'type':'integer'}
## learning rate
lr = 1e-4 #@param {'type':'number'}
dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
optimizer = Adam(score_model.parameters(), lr=lr)
tqdm_epoch = tqdm.notebook.trange(n_epochs)
for epoch in tqdm_epoch:
    avg_loss = 0.
    num_items = 0
    for x, y in data_loader:
        x = x.to(device)
        loss = loss_fn(score_model, x, marginal_prob_std_fn)
        # Backpropagate and update the parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]
    # Print the averaged training loss so far.
    tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
    # Update the checkpoint after each epoch of training.
    torch.save(score_model.state_dict(), 'ckpt.pth')
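Once the score model is trained, samples can be drawn with several solvers. The first is the Euler-Maruyama method applied to the reverse-time SDE dx = −σ^{2t} ∇ₓ log p_t(x) dt + σ^t dw̄, with the learned score s_θ(x, t) in place of ∇ₓ log p_t(x). Each step from t to t − Δt computes x_{t−Δt} = x_t + σ^{2t} s_θ(x_t, t) Δt + σ^t √Δt · z with z ~ N(0, I).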
## The number of sampling steps.
num_steps = 500 #@param {'type':'integer'}
def Euler_Maruyama_sampler(score_model,
                           marginal_prob_std,
                           diffusion_coeff,
                           batch_size=64,
                           num_steps=num_steps,
                           device='cuda',
                           eps=1e-3):
    """Generate samples from score-based models with the Euler-Maruyama solver.
    Args:
        score_model: A PyTorch model that represents the time-dependent score-based model.
        marginal_prob_std: A function that gives the standard deviation of
            the perturbation kernel.
        diffusion_coeff: A function that gives the diffusion coefficient of the SDE.
        batch_size: The number of samples to generate.
        num_steps: The number of sampling steps, i.e. the number of discretized time steps.
        device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
        eps: The smallest time step for numerical stability.
    Returns:
        Samples.
    """
    t = torch.ones(batch_size, device=device)
    # Start from the prior, approximately N(0, sigma_1^2 I)
    init_x = torch.randn(batch_size, 1, 28, 28, device=device) \
        * marginal_prob_std(t)[:, None, None, None]
    time_steps = torch.linspace(1., eps, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1]
    x = init_x
    with torch.no_grad():
        for time_step in tqdm.notebook.tqdm(time_steps):
            batch_time_step = torch.ones(batch_size, device=device) * time_step
            g = diffusion_coeff(batch_time_step)
            mean_x = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step) * step_size
            x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x)
    # Do not include any noise in the last sampling step.
    return mean_x
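To sanity-check the update rule without training a network, the sketch below runs the same Euler-Maruyama loop on a one-dimensional toy problem whose score is known in closed form: if the data are N(mu, s0²), then p_t = N(mu, s0² + σ_t²) and the score is −(x − mu)/(s0² + σ_t²). The toy data distribution, the function names, and the step count are assumptions made for illustration. With the exact score, the returned samples should have mean close to mu and standard deviation close to s0.

```python
import numpy as np

SIGMA = 25.0  # sigma of the SDE dx = SIGMA**t dw, as in the post

def marginal_var(t):
    # Variance of p_{0t}(x(t) | x(0)) for dx = SIGMA**t dw
    return (SIGMA ** (2 * t) - 1.0) / (2.0 * np.log(SIGMA))

def toy_score(x, t, mu, s0):
    # Exact score of p_t when the data are N(mu, s0^2): p_t = N(mu, s0^2 + var(t))
    return -(x - mu) / (s0 ** 2 + marginal_var(t))

def euler_maruyama_toy(n=20000, num_steps=1000, eps=1e-3, mu=2.0, s0=0.5, seed=0):
    rng = np.random.default_rng(seed)
    time_steps = np.linspace(1.0, eps, num_steps)
    dt = time_steps[0] - time_steps[1]
    # Start from the approximate prior N(0, var(1)), as the PyTorch sampler does
    x = rng.standard_normal(n) * np.sqrt(marginal_var(1.0))
    for t in time_steps:
        g = SIGMA ** t
        mean_x = x + g ** 2 * toy_score(x, t, mu, s0) * dt
        x = mean_x + np.sqrt(dt) * g * rng.standard_normal(n)
    return mean_x  # no noise on the last step

samples = euler_maruyama_toy()
print(samples.mean(), samples.std())  # expected to be near mu = 2.0 and s0 = 0.5
```

Starting from N(0, σ_1²) rather than the exact p_1 barely matters here: the reverse dynamics shrink the initial mismatch by roughly the factor (s0² + σ_0²)/(s0² + σ_1²).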
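The predictor-corrector (PC) sampler combines a numerical solver for the reverse-time SDE with Langevin MCMC. First, one step of the numerical SDE solver takes x_t to x_{t−Δt}; this is the "predictor" step. Next, several steps of Langevin MCMC refine x_{t−Δt} so that it becomes a more accurate sample from p_{t−Δt}(x); this is the "corrector" step, because MCMC helps reduce the error of the numerical SDE solver. The implementation below uses one corrector step per time step, applied before the predictor, and sets the Langevin step size from a target signal-to-noise ratio snr as ε = 2 (snr · ‖z‖ / ‖∇ₓ log p‖)².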
signal_to_noise_ratio = 0.16 #@param {'type':'number'}
## The number of sampling steps.
num_steps = 500 #@param {'type':'integer'}
def pc_sampler(score_model,
               marginal_prob_std,
               diffusion_coeff,
               batch_size=64,
               num_steps=num_steps,
               snr=signal_to_noise_ratio,
               device='cuda',
               eps=1e-3):
    """Generate samples from score-based models with the predictor-corrector method.
    Args:
        score_model: A PyTorch model that represents the time-dependent score-based model.
        marginal_prob_std: A function that gives the standard deviation
            of the perturbation kernel.
        diffusion_coeff: A function that gives the diffusion coefficient
            of the SDE.
        batch_size: The number of samples to generate.
        num_steps: The number of sampling steps, i.e. the number of discretized time steps.
        snr: The signal-to-noise ratio used to set the Langevin step size.
        device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
        eps: The smallest time step for numerical stability.
    Returns:
        Samples.
    """
    t = torch.ones(batch_size, device=device)
    init_x = torch.randn(batch_size, 1, 28, 28, device=device) * marginal_prob_std(t)[:, None, None, None]
    time_steps = np.linspace(1., eps, num_steps)
    step_size = time_steps[0] - time_steps[1]
    x = init_x
    with torch.no_grad():
        for time_step in tqdm.notebook.tqdm(time_steps):
            batch_time_step = torch.ones(batch_size, device=device) * time_step
            # Corrector step (Langevin MCMC)
            grad = score_model(x, batch_time_step)
            grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
            noise_norm = np.sqrt(np.prod(x.shape[1:]))
            langevin_step_size = 2 * (snr * noise_norm / grad_norm)**2
            x = x + langevin_step_size * grad + torch.sqrt(2 * langevin_step_size) * torch.randn_like(x)
            # Predictor step (Euler-Maruyama)
            g = diffusion_coeff(batch_time_step)
            x_mean = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step) * step_size
            x = x_mean + torch.sqrt(g**2 * step_size)[:, None, None, None] * torch.randn_like(x)
    # The last step does not include any noise
    return x_mean
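The same Gaussian toy problem can exercise the predictor-corrector loop. The sketch mirrors the structure of pc_sampler above (one Langevin corrector step whose size is set from snr, followed by one Euler-Maruyama predictor step), using the exact score of N(mu, s0²) data in dim = 16 dimensions so that the gradient and noise norms are meaningful. All names and settings are illustrative assumptions.

```python
import numpy as np

SIGMA = 25.0

def marginal_var(t):
    return (SIGMA ** (2 * t) - 1.0) / (2.0 * np.log(SIGMA))

def toy_score(x, t, mu, s0):
    # Exact score for data N(mu, s0^2 I), perturbed by dx = SIGMA**t dw
    return -(x - mu) / (s0 ** 2 + marginal_var(t))

def pc_sampler_toy(n=4000, dim=16, num_steps=500, snr=0.16, eps=1e-3, mu=2.0, s0=0.5, seed=0):
    rng = np.random.default_rng(seed)
    time_steps = np.linspace(1.0, eps, num_steps)
    dt = time_steps[0] - time_steps[1]
    x = rng.standard_normal((n, dim)) * np.sqrt(marginal_var(1.0))
    for t in time_steps:
        # Corrector: one Langevin step with step size set by the target snr
        grad = toy_score(x, t, mu, s0)
        grad_norm = np.linalg.norm(grad, axis=-1).mean()
        noise_norm = np.sqrt(dim)
        step = 2 * (snr * noise_norm / grad_norm) ** 2
        x = x + step * grad + np.sqrt(2 * step) * rng.standard_normal(x.shape)
        # Predictor: one reverse-time Euler-Maruyama step
        g = SIGMA ** t
        x_mean = x + g ** 2 * toy_score(x, t, mu, s0) * dt
        x = x_mean + np.sqrt(g ** 2 * dt) * rng.standard_normal(x.shape)
    return x_mean

pc_samples = pc_sampler_toy()
print(pc_samples.mean(), pc_samples.std())  # expected to be near 2.0 and 0.5
```

For this target the score norm is about √dim/√v(t), so the rule gives a step of roughly 2·snr²·v(t): the corrector automatically takes large steps at high noise levels and small ones near t = 0.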
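Every SDE has a corresponding ordinary differential equation, the probability flow ODE, whose trajectories share the same marginal densities p_t(x) as the SDE. Solving this ODE backward in time therefore yields samples from the same distribution as solving the reverse-time SDE. For the SDE used here it reads dx/dt = −½ σ^{2t} ∇ₓ log p_t(x), and it can be integrated with one of the many black-box ODE solvers provided by packages such as scipy.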
from scipy import integrate
## The error tolerance for the black-box ODE solver
error_tolerance = 1e-5 #@param {'type': 'number'}
def ode_sampler(score_model,
                marginal_prob_std,
                diffusion_coeff,
                batch_size=64,
                atol=error_tolerance,
                rtol=error_tolerance,
                device='cuda',
                z=None,
                eps=1e-3):
    """Generate samples from score-based models with black-box ODE solvers.
    Args:
        score_model: A PyTorch model that represents the time-dependent score-based model.
        marginal_prob_std: A function that returns the standard deviation
            of the perturbation kernel.
        diffusion_coeff: A function that returns the diffusion coefficient of the SDE.
        batch_size: The number of samples to generate by calling this function once.
        atol: Tolerance of absolute errors.
        rtol: Tolerance of relative errors.
        device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
        z: The latent code that governs the final sample. If None, we start from p_1;
            otherwise, we start from the given z.
        eps: The smallest time step for numerical stability.
    Returns:
        Samples.
    """
    t = torch.ones(batch_size, device=device)
    # Create the latent code
    if z is None:
        init_x = torch.randn(batch_size, 1, 28, 28, device=device) \
            * marginal_prob_std(t)[:, None, None, None]
    else:
        init_x = z
    shape = init_x.shape
    def score_eval_wrapper(sample, time_steps):
        """A wrapper of the score-based model for use by the ODE solver."""
        sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
        time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((sample.shape[0], ))
        with torch.no_grad():
            score = score_model(sample, time_steps)
        return score.cpu().numpy().reshape((-1,)).astype(np.float64)
    def ode_func(t, x):
        """The ODE function for use by the ODE solver."""
        time_steps = np.ones((shape[0],)) * t
        g = diffusion_coeff(torch.tensor(t)).cpu().numpy()
        return -0.5 * (g**2) * score_eval_wrapper(x, time_steps)
    # Run the black-box ODE solver from t = 1 down to t = eps.
    res = integrate.solve_ivp(ode_func, (1., eps), init_x.reshape(-1).cpu().numpy(), rtol=rtol, atol=atol, method='RK45')
    print(f"Number of function evaluations: {res.nfev}")
    x = torch.tensor(res.y[:, -1], device=device, dtype=torch.float32).reshape(shape)
    return x
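Because the probability flow ODE is deterministic, each starting point x(1) maps to exactly one x(0). For the Gaussian toy problem (data N(mu, s0²)), the ODE dx/dt = ½ σ^{2t} (x − mu)/(s0² + σ_t²) even has a closed-form solution, x(0) − mu = (x(1) − mu) · s0 / √(s0² + σ_1²). The sketch below integrates it with a hand-written fixed-step RK4, standing in for scipy's adaptive RK45, and compares against that formula; names and settings are illustrative.

```python
import numpy as np

SIGMA = 25.0

def marginal_var(t):
    return (SIGMA ** (2 * t) - 1.0) / (2.0 * np.log(SIGMA))

def toy_score(x, t, mu, s0):
    return -(x - mu) / (s0 ** 2 + marginal_var(t))

def ode_drift(x, t, mu, s0):
    # Probability flow ODE for dx = SIGMA**t dw: dx/dt = -0.5 * g(t)^2 * score
    return -0.5 * SIGMA ** (2 * t) * toy_score(x, t, mu, s0)

def rk4_backward(x1, mu, s0, num_steps=2000):
    """Integrate the ODE from t = 1 to t = 0 with fixed-step classical RK4."""
    h = -1.0 / num_steps  # negative step: time runs backward
    x, t = np.array(x1, dtype=float), 1.0
    for _ in range(num_steps):
        k1 = ode_drift(x, t, mu, s0)
        k2 = ode_drift(x + 0.5 * h * k1, t + 0.5 * h, mu, s0)
        k3 = ode_drift(x + 0.5 * h * k2, t + 0.5 * h, mu, s0)
        k4 = ode_drift(x + h * k3, t + h, mu, s0)
        x = x + h / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
        t = t + h
    return x

mu, s0 = 2.0, 0.5
x1 = np.array([-10.0, 0.0, 5.0])
x0 = rk4_backward(x1, mu, s0)
# Closed form for this toy: x(0) - mu = (x(1) - mu) * s0 / sqrt(s0^2 + var(1))
exact = mu + (x1 - mu) * s0 / np.sqrt(s0 ** 2 + marginal_var(1.0))
print(np.max(np.abs(x0 - exact)))  # tiny
```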
from torchvision.utils import make_grid
## Load the pre-trained checkpoint from disk.
device = 'cuda' #@param ['cuda', 'cpu'] {'type':'string'}
ckpt = torch.load('ckpt.pth', map_location=device)
score_model.load_state_dict(ckpt)
sample_batch_size = 64 #@param {'type':'integer'}
# Sampler configuration
sampler = ode_sampler #@param ['Euler_Maruyama_sampler', 'pc_sampler', 'ode_sampler'] {'type': 'raw'}
## Generate samples using the specified sampler.
samples = sampler(score_model,
                  marginal_prob_std_fn,
                  diffusion_coeff_fn,
                  sample_batch_size,
                  device=device)
## Sample visualization.
samples = samples.clamp(0.0, 1.0)
%matplotlib inline
import matplotlib.pyplot as plt
sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))
plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
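A by-product of the probability flow ODE formulation is exact likelihood computation. By the instantaneous change-of-variables formula, log p_0(x(0)) = log p_1(x(1)) + ∫_0^1 ∇·f(x(t), t) dt, where f(x, t) = −½ σ^{2t} s_θ(x, t) is the ODE drift. The code below integrates x(t) and this divergence term together from t = ε to t = 1, approximates p_1 by N(0, σ_1² I), and estimates the divergence with the Skilling-Hutchinson trace estimator instead of computing it exactly.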
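divergence_eval in the code below does not compute the divergence of the score (the trace of its Jacobian J) exactly, which would need one backward pass per pixel. It uses the Skilling-Hutchinson estimator tr(J) = E[εᵀJε] with ε ~ N(0, I), obtaining εᵀJ from a single autograd call. The NumPy sketch shows the estimator on an explicit matrix; the helper name and the random test matrix, which stands in for the Jacobian, are illustrative assumptions.

```python
import numpy as np

def hutchinson_trace(matvec, dim, num_samples, seed=0):
    """Estimate tr(A) as the average of eps^T A eps over Gaussian probes eps ~ N(0, I)."""
    rng = np.random.default_rng(seed)
    total = 0.0
    for _ in range(num_samples):
        eps = rng.standard_normal(dim)
        total += eps @ matvec(eps)
    return total / num_samples

rng = np.random.default_rng(1)
B = rng.standard_normal((50, 50))
A = B @ B.T / 50  # symmetric test matrix standing in for the Jacobian of the score
estimate = hutchinson_trace(lambda v: A @ v, dim=50, num_samples=20000)
print(estimate, np.trace(A))  # close, but not identical
```

The estimator is unbiased but noisy; ode_likelihood draws a single probe per example and keeps it fixed for the whole ODE solve, trading variance for speed.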
def prior_likelihood(z, sigma):
    """The log-likelihood of a Gaussian distribution with mean zero and
    standard deviation sigma."""
    shape = z.shape
    N = np.prod(shape[1:])
    return -N / 2. * torch.log(2*np.pi*sigma**2) - torch.sum(z**2, dim=(1,2,3)) / (2 * sigma**2)
def ode_likelihood(x,
                   score_model,
                   marginal_prob_std,
                   diffusion_coeff,
                   batch_size=64,
                   device='cuda',
                   eps=1e-5):
    """Compute the likelihood with probability flow ODE.
    Args:
        x: Input data.
        score_model: A PyTorch model representing the score-based model.
        marginal_prob_std: A function that gives the standard deviation of the
            perturbation kernel.
        diffusion_coeff: A function that gives the diffusion coefficient of the
            forward SDE.
        batch_size: The batch size. Equals to the leading dimension of `x`.
        device: 'cuda' for evaluation on GPUs, and 'cpu' for evaluation on CPUs.
        eps: A `float` number. The smallest time step for numerical stability.
    Returns:
        z: The latent code for `x`.
        bpd: The log-likelihoods in bits/dim.
    """
    # Draw the random Gaussian sample for Skilling-Hutchinson's estimator.
    epsilon = torch.randn_like(x)
    def divergence_eval(sample, time_steps, epsilon):
        """Compute the divergence of the score-based model with Skilling-Hutchinson."""
        with torch.enable_grad():
            sample.requires_grad_(True)
            score_e = torch.sum(score_model(sample, time_steps) * epsilon)
            grad_score_e = torch.autograd.grad(score_e, sample)[0]
        return torch.sum(grad_score_e * epsilon, dim=(1, 2, 3))
    shape = x.shape
    def score_eval_wrapper(sample, time_steps):
        """A wrapper for evaluating the score-based model for the black-box ODE solver."""
        sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
        time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((sample.shape[0], ))
        with torch.no_grad():
            score = score_model(sample, time_steps)
        return score.cpu().numpy().reshape((-1,)).astype(np.float64)
    def divergence_eval_wrapper(sample, time_steps):
        """A wrapper for evaluating the divergence of score for the black-box ODE solver."""
        with torch.no_grad():
            sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
            time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((sample.shape[0], ))
            # Estimate the divergence (enable_grad inside divergence_eval overrides no_grad)
            div = divergence_eval(sample, time_steps, epsilon)
            return div.cpu().numpy().reshape((-1,)).astype(np.float64)
    def ode_func(t, x):
        """The ODE function for the black-box solver."""
        time_steps = np.ones((shape[0],)) * t
        sample = x[:-shape[0]]
        logp = x[-shape[0]:]
        g = diffusion_coeff(torch.tensor(t)).cpu().numpy()
        sample_grad = -0.5 * g**2 * score_eval_wrapper(sample, time_steps)
        logp_grad = -0.5 * g**2 * divergence_eval_wrapper(sample, time_steps)
        return np.concatenate([sample_grad, logp_grad], axis=0)
    # ODE state: [x(t) flattened, accumulated change of log-density per example]
    init = np.concatenate([x.cpu().numpy().reshape((-1,)), np.zeros((shape[0],))], axis=0)
    # Black-box ODE solver, integrating from t = eps to t = 1
    res = integrate.solve_ivp(ode_func, (eps, 1.), init, rtol=1e-5, atol=1e-5, method='RK45')
    zp = torch.tensor(res.y[:, -1], device=device)
    z = zp[:-shape[0]].reshape(shape)
    delta_logp = zp[-shape[0]:].reshape(shape[0])
    sigma_max = marginal_prob_std(1.)
    prior_logp = prior_likelihood(z, sigma_max)
    bpd = -(prior_logp + delta_logp) / np.log(2)
    N = np.prod(shape[1:])
    # Per-dimension bits, plus log2(256) = 8 bits for the rescaling of 8-bit pixels to [0, 1]
    bpd = bpd / N + 8.
    return z, bpd
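For a one-dimensional Gaussian toy problem the whole procedure can be checked against an exact answer: with data N(mu, s0²) the true score is known, the divergence of the ODE drift is available in closed form (so no Hutchinson estimate is needed), and the log-likelihood obtained by integrating [x(t), Δlog p] from t = 0 to t = 1 should match log N(x0; mu, s0²). This sketch uses fixed-step RK4 instead of scipy's adaptive RK45, and the exact p_1 instead of the N(0, σ_1²) approximation used above; all names are illustrative.

```python
import numpy as np

SIGMA = 25.0

def marginal_var(t):
    return (SIGMA ** (2 * t) - 1.0) / (2.0 * np.log(SIGMA))

def augmented_drift(state, t, mu, s0):
    """Drift of [x, delta_logp] for 1-D Gaussian data N(mu, s0^2).

    ODE drift f(x, t) = -0.5 * g^2 * score = 0.5 * g^2 * (x - mu) / v(t),
    its divergence is df/dx = 0.5 * g^2 / v(t), with v(t) = s0^2 + var(t).
    """
    x = state[0]
    g2 = SIGMA ** (2 * t)
    v = s0 ** 2 + marginal_var(t)
    return np.array([0.5 * g2 * (x - mu) / v, 0.5 * g2 / v])

def ode_log_likelihood(x0, mu, s0, num_steps=2000):
    # Integrate forward from t = 0 to t = 1 with fixed-step RK4
    h = 1.0 / num_steps
    state, t = np.array([x0, 0.0]), 0.0
    for _ in range(num_steps):
        k1 = augmented_drift(state, t, mu, s0)
        k2 = augmented_drift(state + 0.5 * h * k1, t + 0.5 * h, mu, s0)
        k3 = augmented_drift(state + 0.5 * h * k2, t + 0.5 * h, mu, s0)
        k4 = augmented_drift(state + h * k3, t + h, mu, s0)
        state = state + h / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    x1, delta_logp = state
    v1 = s0 ** 2 + marginal_var(1.0)  # exact p_1 for this toy
    prior_logp = -0.5 * np.log(2 * np.pi * v1) - (x1 - mu) ** 2 / (2 * v1)
    return prior_logp + delta_logp  # log p_0(x0) = log p_1(x1) + integral of div f

mu, s0, x0 = 2.0, 0.5, 2.3
exact = -0.5 * np.log(2 * np.pi * s0 ** 2) - (x0 - mu) ** 2 / (2 * s0 ** 2)
print(ode_log_likelihood(x0, mu, s0), exact)
```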
batch_size = 32 #@param {'type':'integer'}
dataset = MNIST('.', train=False, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
ckpt = torch.load('ckpt.pth', map_location=device)
score_model.load_state_dict(ckpt)
all_bpds = 0.
all_items = 0
try:
    tqdm_data = tqdm.notebook.tqdm(data_loader)
    for x, _ in tqdm_data:
        x = x.to(device)
        # Uniform dequantization: map 8-bit pixel values to continuous values in [0, 1)
        x = (x * 255. + torch.rand_like(x)) / 256.
        _, bpd = ode_likelihood(x, score_model, marginal_prob_std_fn,
                                diffusion_coeff_fn,
                                x.shape[0], device=device, eps=1e-5)
        all_bpds += bpd.sum().item()
        all_items += bpd.shape[0]
        tqdm_data.set_description("Average bits/dim: {:5f}".format(all_bpds / all_items))
except KeyboardInterrupt:
    # Remove the error message when interrupted by keyboard or GUI.
    pass
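Why does ode_likelihood add 8 at the end? Dequantized pixels are divided by 256, which compresses each dimension by a factor of 256 and raises the density by the same factor per dimension. Converting back to the 8-bit scale therefore adds log2(256) = 8 bits/dim. A small arithmetic sketch (bits_per_dim is a hypothetical helper name):

```python
import numpy as np

def bits_per_dim(log_likelihood_nats, num_dims):
    """Convert a continuous log-likelihood of data rescaled to [0, 1] into bits/dim
    of the original 8-bit data (256 levels per dimension)."""
    # nats -> bits, averaged per dimension, plus log2(256) = 8 bits for the 1/256 rescaling
    return -log_likelihood_nats / (num_dims * np.log(2)) + 8.0

N = 28 * 28  # MNIST dimensionality
print(bits_per_dim(0.0, N))              # uniform density on [0,1]^N -> 8.0 bits/dim
print(bits_per_dim(N * np.log(2.0), N))  # density 2^N times higher -> 7.0 bits/dim
```

A uniform density scores exactly 8 bits/dim, the cost of storing raw 8-bit pixels, so any useful model should land below 8.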
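(2) The forward function of the time-dependent score network takes the perturbed x and the time t and outputs the score ∇ₓ log p_t(x). This is where it differs from a conventional (DDPM-style) diffusion model, whose network takes the perturbed x (together with the time step) and predicts the clean x or the noise.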
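(3) The forward function of the time-dependent score network relies on several supporting functions: GaussianFourierProjection takes the time t and outputs a Gaussian random feature vector, so that t can be injected into the feature maps of x; marginal_prob_std computes the standard deviation of the perturbation kernel at time t, which is used to normalize the score output by the network. This means the model architecture of a score-based diffusion model has to be rewritten to accept t.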
(4) The loss function of a score-based diffusion model is very simple: loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1,2,3)))
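(5) There are several ways to sample using the score output by the network: the Euler-Maruyama sampler, the predictor-corrector sampler, and a numerical ODE solver applied to the probability flow ODE.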
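Point (2) is less of a gap than it might seem: with the Gaussian perturbation kernel x_t = x_0 + σ_t z used here, the conditional score equals −z/σ_t, so a noise prediction ẑ can be turned into a score estimate via s = −ẑ/σ_t (this is also why the loss trains σ_t · s_θ to match −z). A NumPy sketch of that identity; the function name is illustrative.

```python
import numpy as np

def noise_to_score(eps_hat, std):
    """Convert a noise prediction into a score estimate for x_t = x_0 + std * z."""
    return -eps_hat / std

rng = np.random.default_rng(0)
x0 = rng.standard_normal(5)
std = 3.0
z = rng.standard_normal(5)
xt = x0 + std * z
# Conditional score of N(x0, std^2 I) evaluated at x_t
cond_score = -(xt - x0) / std ** 2
print(np.allclose(noise_to_score(z, std), cond_score))  # True
```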