Score-Based Diffusion Models via Stochastic Differential Equations (Score-Based Generative Modeling through Stochastic Differential Equations)
Score-based diffusion models estimate the gradient of the data distribution (the score) and can generate images of GAN-level quality without adversarial training. Source: Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. "Score-Based Generative Modeling through Stochastic Differential Equations." International Conference on Learning Representations, 2021.
Score-based diffusion is another milestone following the rise of diffusion models: it unifies diffusion models and score-based generative models under one SDE framework. The original diffusion model has a drawback: sampling is slow, often requiring thousands of network evaluations to draw a single sample, whereas score-based diffusion offers samplers that can finish in considerably less time.
There is plenty of material online on score-based diffusion: theory introductions, application case studies, and paper walkthroughs that are all worth reading. Very little of it, however, walks through the code. This post therefore provides a simple, runnable code example of a score-based diffusion model.
Import the required modules:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import functools
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
import tqdm
import tqdm.notebook  # make tqdm.notebook.trange / tqdm.notebook.tqdm available outside Colab
"Projection layer" is not really standard terminology here. The module simply samples a set of random frequencies at initialization and then encodes a time step t as the Gaussian random feature vector [sin(2πωt); cos(2πωt)]. Note that the sampled weights are fixed and not trainable.
class GaussianFourierProjection(nn.Module):
"""Gaussian random features for encoding time steps."""
def __init__(self, embed_dim, scale=30.):
super().__init__()
    # Randomly sample weights during initialization. These weights are fixed
    # during optimization and are not trainable.
self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
def forward(self, x):
x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
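As a quick sanity check, the module maps a batch of scalar time steps to embed_dim-dimensional feature vectors (a minimal sketch; the batch size 8 and embed_dim 256 below are arbitrary):

embed = GaussianFourierProjection(embed_dim=256)
t = torch.rand(8)       # 8 random time steps
print(embed(t).shape)   # torch.Size([8, 256]) -- half sine features, half cosine features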
This time embedding exists because score-based diffusion trains differently from a vanilla diffusion model: during training, the network receives x perturbed with noise whose scale depends on a randomly drawn time t, takes both the perturbed x and t as inputs, and the loss is computed from its output. The time t therefore needs its own encoding.
A fully connected layer that reshapes its output into feature maps:
class Dense(nn.Module):
"""A fully connected layer that reshapes outputs to feature maps."""
def __init__(self, input_dim, output_dim):
super().__init__()
self.dense = nn.Linear(input_dim, output_dim)
def forward(self, x):
return self.dense(x)[..., None, None]
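The two trailing singleton dimensions are what make this layer useful: they let a [B, C] time embedding broadcast-add onto a [B, C, H, W] feature map. A minimal sketch with arbitrary shapes:

dense = Dense(256, 32)
embed = torch.randn(4, 256)          # a batch of 4 time embeddings
fmap = torch.randn(4, 32, 28, 28)    # the matching feature maps
out = fmap + dense(embed)            # [4, 32, 1, 1] broadcasts over [4, 32, 28, 28]
print(out.shape)                     # torch.Size([4, 32, 28, 28])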
The time-dependent score-based U-Net. Its forward function takes the time t in addition to x; t is embedded via GaussianFourierProjection and injected at every resolution level, and the final output is normalized by marginal_prob_std(t).
class ScoreNet(nn.Module):
"""初始化一个依赖时间的基于分数的Unet网络."""
def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256):
""".
Args:
marginal_prob_std: 输入时间 t 并给出扰动核的标准差的函数 p_{0t}(x(t) | x(0)).
channels: 各分辨率特征图的通道数.
embed_dim: 高斯随机特征嵌入的维数,与1.1中GaussianFourierProjection相同.
"""
super().__init__()
    # Gaussian random feature embedding layer for time t
self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim),
nn.Linear(embed_dim, embed_dim))
# Encoding layers where the resolution decreases
self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
self.dense1 = Dense(embed_dim, channels[0])
self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
self.dense2 = Dense(embed_dim, channels[1])
self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
self.dense3 = Dense(embed_dim, channels[2])
self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
self.dense4 = Dense(embed_dim, channels[3])
self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])
    # Decoding layers where the resolution increases
self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
self.dense5 = Dense(embed_dim, channels[2])
self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
self.tconv3 = nn.ConvTranspose2d(channels[2] + channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)
self.dense6 = Dense(embed_dim, channels[1])
self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
self.tconv2 = nn.ConvTranspose2d(channels[1] + channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)
self.dense7 = Dense(embed_dim, channels[0])
self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)
    # Swish activation function
self.act = lambda x: x * torch.sigmoid(x)
self.marginal_prob_std = marginal_prob_std
def forward(self, x, t):
    # Obtain the Gaussian random feature embedding for t
embed = self.act(self.embed(t))
# Encoding path
h1 = self.conv1(x)
    ## Incorporate information from t
h1 += self.dense1(embed)
    ## Group normalization
h1 = self.gnorm1(h1)
h1 = self.act(h1)
h2 = self.conv2(h1)
h2 += self.dense2(embed)
h2 = self.gnorm2(h2)
h2 = self.act(h2)
h3 = self.conv3(h2)
h3 += self.dense3(embed)
h3 = self.gnorm3(h3)
h3 = self.act(h3)
h4 = self.conv4(h3)
h4 += self.dense4(embed)
h4 = self.gnorm4(h4)
h4 = self.act(h4)
# Decoding path
h = self.tconv4(h4)
    ## Incorporate information from t
h += self.dense5(embed)
h = self.tgnorm4(h)
h = self.act(h)
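    ## Skip connection from the encoding path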
h = self.tconv3(torch.cat([h, h3], dim=1))
h += self.dense6(embed)
h = self.tgnorm3(h)
h = self.act(h)
h = self.tconv2(torch.cat([h, h2], dim=1))
h += self.dense7(embed)
h = self.tgnorm2(h)
h = self.act(h)
h = self.tconv1(torch.cat([h, h1], dim=1))
    # Normalize the output with the standard deviation of the perturbation kernel
h = h / self.marginal_prob_std(t)[:, None, None, None]
return h
The SDE perturbs p_0 toward p_T. Two functions characterize it: marginal_prob_std, which returns the standard deviation of the perturbation kernel p_{0t}(x(t) | x(0)) (its mean is simply x(0) for this SDE), and diffusion_coeff, which computes the diffusion coefficient of the SDE.
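Concretely, this example uses the variance-exploding SDE dx = \sigma^t dw. A standard computation (matching the code below) gives the perturbation kernel

p_{0t}(x(t) | x(0)) = N( x(t); x(0), \frac{\sigma^{2t} - 1}{2 \log \sigma} I ),

so the standard deviation is \sqrt{ (\sigma^{2t} - 1) / (2 \log \sigma) } and the diffusion coefficient is g(t) = \sigma^t, which is exactly what marginal_prob_std and diffusion_coeff implement.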
device = 'cuda' #@param ['cuda', 'cpu'] {'type':'string'}
def marginal_prob_std(t, sigma):
"""计算p_{0t}(x(t) | x(0))的平均值和标准差.
Args:
t: A vector of time steps.
sigma: The $\sigma$ in our SDE.
Returns:
标准差.
"""
t = torch.tensor(t, device=device)
return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma))
def diffusion_coeff(t, sigma):
"""计算SDE的扩散系数.
Args:
t: A vector of time steps.
sigma: The $\sigma$ in our SDE.
Returns:
    The vector of diffusion coefficients.
"""
return torch.tensor(sigma**t, device=device)
sigma = 25.0 #@param {'type':'number'}
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)
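With these pieces in place, a quick forward-pass check helps confirm that the shapes line up (a minimal sketch; the batch size 4 is arbitrary, and times are kept away from 0 as in the training loss below):

model = ScoreNet(marginal_prob_std=marginal_prob_std_fn).to(device)
x = torch.randn(4, 1, 28, 28, device=device)            # a dummy MNIST-shaped batch
t = torch.rand(4, device=device) * (1. - 1e-5) + 1e-5   # random times in (0, 1]
print(model(x, t).shape)                                 # torch.Size([4, 1, 28, 28])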
The loss function looks involved but has a fixed form. Code below:
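For reference, the implementation is the weighted denoising score matching objective. With the Gaussian perturbation kernel above, the target score is \nabla_{x(t)} \log p_{0t}(x(t) | x(0)) = -z / \sigma(t) for x(t) = x(0) + \sigma(t) z, and choosing the weighting \lambda(t) = \sigma(t)^2 gives

E_{t \sim U(\epsilon, 1)} E_{x(0) \sim p_0} E_{z \sim N(0, I)} [ \| \sigma(t) s_\theta(x(0) + \sigma(t) z, t) + z \|_2^2 ],

which is exactly the (score * std + z)**2 term in the code.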
def loss_fn(model, x, marginal_prob_std, eps=1e-5):
"""The loss function for training score-based generative models.
Args:
    model: A PyTorch model instance that represents a time-dependent score-based model.
x: A mini-batch of training data.
marginal_prob_std: A function that gives the standard deviation of
the perturbation kernel.
eps: A tolerance value for numerical stability.
"""
random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
z = torch.randn_like(x)
std = marginal_prob_std(random_t)
perturbed_x = x + z * std[:, None, None, None]
score = model(perturbed_x, random_t)
loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1,2,3)))
return loss
Training follows the usual PyTorch recipe: instantiate the model, build the optimizer, compute the loss, and backpropagate. Code below:
score_model = torch.nn.DataParallel(ScoreNet(marginal_prob_std=marginal_prob_std_fn))
score_model = score_model.to(device)
n_epochs = 50 #@param {'type':'integer'}
## size of a mini-batch
batch_size = 32 #@param {'type':'integer'}
## learning rate
lr=1e-4 #@param {'type':'number'}
dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
optimizer = Adam(score_model.parameters(), lr=lr)
tqdm_epoch = tqdm.notebook.trange(n_epochs)
for epoch in tqdm_epoch:
avg_loss = 0.
num_items = 0
for x, y in data_loader:
x = x.to(device)
loss = loss_fn(score_model, x, marginal_prob_std_fn)
optimizer.zero_grad()
loss.backward()
optimizer.step()
avg_loss += loss.item() * x.shape[0]
num_items += x.shape[0]
# Print the averaged training loss so far.
tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
# Update the checkpoint after each epoch of training.
torch.save(score_model.state_dict(), 'ckpt.pth')
(The training log, the running average loss per epoch shown in the tqdm progress bar, is omitted here.)
Score-based diffusion models support several kinds of samplers. The Euler-Maruyama sampler is a numerical SDE solver: given the scores predicted by the neural network, it draws samples by numerically integrating the reverse-time SDE.
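The reverse-time SDE from the paper reads

dx = [ f(x, t) - g(t)^2 \nabla_x \log p_t(x) ] dt + g(t) d\bar{w},

where the drift f is zero for our VE SDE. The sampler below applies its Euler-Maruyama discretization, with the true score replaced by the network prediction s_\theta:

x_{t - \Delta t} = x_t + g(t)^2 s_\theta(x_t, t) \Delta t + g(t) \sqrt{\Delta t} z_t,  z_t \sim N(0, I).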
## The number of sampling steps.
num_steps = 500 #@param {'type':'integer'}
def Euler_Maruyama_sampler(score_model,
marginal_prob_std,
diffusion_coeff,
batch_size=64,
num_steps=num_steps,
device='cuda',
eps=1e-3):
"""使用 Euler-Maruyama 求解器从基于分数的模型生成样本.
Args:
score_model: 时间依赖,基于分数的 PyTorch model.
marginal_prob_std: A function that gives the standard deviation of
the perturbation kernel.
diffusion_coeff: A function that gives the diffusion coefficient of the SDE.
batch_size: 批次大小.
num_steps: 采样步数, 等价于相当于离散时间步数.
device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
eps: 数值稳定性的最小时间步长.
Returns:
采样样本.
"""
t = torch.ones(batch_size, device=device)
init_x = torch.randn(batch_size, 1, 28, 28, device=device) \
* marginal_prob_std(t)[:, None, None, None]
time_steps = torch.linspace(1., eps, num_steps, device=device)
step_size = time_steps[0] - time_steps[1]
x = init_x
with torch.no_grad():
for time_step in tqdm.notebook.tqdm(time_steps):
batch_time_step = torch.ones(batch_size, device=device) * time_step
g = diffusion_coeff(batch_time_step)
mean_x = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step) * step_size
x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x)
# Do not include any noise in the last sampling step.
return mean_x
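Usage is a single call (a minimal sketch, assuming the score_model trained above):

em_samples = Euler_Maruyama_sampler(score_model,
                                    marginal_prob_std_fn,
                                    diffusion_coeff_fn,
                                    batch_size=16,
                                    device=device)
print(em_samples.shape)  # torch.Size([16, 1, 28, 28])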
The Predictor-Corrector sampler combines a numerical solver for the reverse-time SDE with Langevin MCMC. Concretely, we first apply one step of the numerical SDE solver to obtain x_{t-Δt} from x_t; this is the "predictor" step. Next, we apply several steps of Langevin MCMC to refine x_{t-Δt} so that it becomes a more accurate sample from p_{t-Δt}(x). This is the "corrector" step, since the MCMC helps reduce the error of the numerical SDE solver.
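Each corrector update is one step of Langevin MCMC,

x \leftarrow x + \epsilon s_\theta(x, t) + \sqrt{2 \epsilon} z,  z \sim N(0, I),

with the step size \epsilon = 2 (snr \cdot \|z\|_2 / \|s_\theta(x, t)\|_2)^2 chosen from the signal-to-noise ratio; this is langevin_step_size in the code below.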
signal_to_noise_ratio = 0.16 #@param {'type':'number'}
## The number of sampling steps.
num_steps = 500 #@param {'type':'integer'}
def pc_sampler(score_model,
marginal_prob_std,
diffusion_coeff,
batch_size=64,
num_steps=num_steps,
snr=signal_to_noise_ratio,
device='cuda',
eps=1e-3):
"""
使用预测-校正方法从基于分数的模型生成样本.
Args:
score_model: 时间依赖,基于分数的 PyTorch model.
marginal_prob_std: A function that gives the standard deviation
of the perturbation kernel.
diffusion_coeff: A function that gives the diffusion coefficient
of the SDE.
batch_size: 批次大小.
num_steps: 采样步数, 等价于相当于离散时间步数.
device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
eps: 数值稳定性的最小时间步长.
Returns:
采样样本.
"""
t = torch.ones(batch_size, device=device)
init_x = torch.randn(batch_size, 1, 28, 28, device=device) * marginal_prob_std(t)[:, None, None, None]
time_steps = np.linspace(1., eps, num_steps)
step_size = time_steps[0] - time_steps[1]
x = init_x
with torch.no_grad():
for time_step in tqdm.notebook.tqdm(time_steps):
batch_time_step = torch.ones(batch_size, device=device) * time_step
      # Corrector step (Langevin MCMC)
grad = score_model(x, batch_time_step)
grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
noise_norm = np.sqrt(np.prod(x.shape[1:]))
langevin_step_size = 2 * (snr * noise_norm / grad_norm)**2
x = x + langevin_step_size * grad + torch.sqrt(2 * langevin_step_size) * torch.randn_like(x)
      # Predictor step (Euler-Maruyama)
g = diffusion_coeff(batch_time_step)
x_mean = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step) * step_size
x = x_mean + torch.sqrt(g**2 * step_size)[:, None, None, None] * torch.randn_like(x)
# The last step does not include any noise
return x_mean
Every SDE has a corresponding ODE; by solving this ODE backward in time, we can sample from the same distribution as by solving the reverse-time SDE. We call this ODE the probability flow ODE. Solving it can be delegated to any of the many black-box ODE solvers provided by packages such as scipy.
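For our VE SDE the probability flow ODE takes the simple form implemented by ode_func below:

dx/dt = -\frac{1}{2} g(t)^2 s_\theta(x, t),  with g(t) = \sigma^t.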
from scipy import integrate
## The error tolerance for the black-box ODE solver
error_tolerance = 1e-5 #@param {'type': 'number'}
def ode_sampler(score_model,
marginal_prob_std,
diffusion_coeff,
batch_size=64,
atol=error_tolerance,
rtol=error_tolerance,
device='cuda',
z=None,
eps=1e-3):
"""Generate samples from score-based models with black-box ODE solvers.
Args:
score_model: A PyTorch model that represents the time-dependent score-based model.
marginal_prob_std: A function that returns the standard deviation
of the perturbation kernel.
diffusion_coeff: A function that returns the diffusion coefficient of the SDE.
    batch_size: The number of samples to generate by calling this function once.
atol: Tolerance of absolute errors.
rtol: Tolerance of relative errors.
device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
z: The latent code that governs the final sample. If None, we start from p_1;
otherwise, we start from the given z.
eps: The smallest time step for numerical stability.
"""
t = torch.ones(batch_size, device=device)
# Create the latent code
if z is None:
init_x = torch.randn(batch_size, 1, 28, 28, device=device) \
* marginal_prob_std(t)[:, None, None, None]
else:
init_x = z
shape = init_x.shape
def score_eval_wrapper(sample, time_steps):
"""A wrapper of the score-based model for use by the ODE solver."""
sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((sample.shape[0], ))
with torch.no_grad():
score = score_model(sample, time_steps)
return score.cpu().numpy().reshape((-1,)).astype(np.float64)
def ode_func(t, x):
"""The ODE function for use by the ODE solver."""
time_steps = np.ones((shape[0],)) * t
g = diffusion_coeff(torch.tensor(t)).cpu().numpy()
return -0.5 * (g**2) * score_eval_wrapper(x, time_steps)
# Run the black-box ODE solver.
res = integrate.solve_ivp(ode_func, (1., eps), init_x.reshape(-1).cpu().numpy(), rtol=rtol, atol=atol, method='RK45')
print(f"Number of function evaluations: {res.nfev}")
x = torch.tensor(res.y[:, -1], device=device).reshape(shape)
return x
from torchvision.utils import make_grid
## Load the pre-trained checkpoint from disk.
device = 'cuda' #@param ['cuda', 'cpu'] {'type':'string'}
ckpt = torch.load('ckpt.pth', map_location=device)
score_model.load_state_dict(ckpt)
sample_batch_size = 64 #@param {'type':'integer'}
# Sampler configuration: choose one of the three samplers defined above.
sampler = ode_sampler #@param ['Euler_Maruyama_sampler', 'pc_sampler', 'ode_sampler'] {'type': 'raw'}
## Generate samples using the specified sampler.
samples = sampler(score_model,
marginal_prob_std_fn,
diffusion_coeff_fn,
sample_batch_size,
device=device)
## Sample visualization.
samples = samples.clamp(0.0, 1.0)
%matplotlib inline
import matplotlib.pyplot as plt
sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))
plt.figure(figsize=(6,6))
plt.axis('off')
plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
plt.show()
(The grid of generated digits is omitted here.)
Try the other samplers as well and compare how their outputs differ.
A by-product of the probability flow ODE formulation is exact likelihood computation.
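Two standard results make this work. First, the instantaneous change-of-variables formula along the probability flow ODE dx/dt = \tilde{f}(x, t) = -\frac{1}{2} g(t)^2 s_\theta(x, t):

d \log p(x(t)) / dt = -\nabla \cdot \tilde{f}(x(t), t).

Second, the divergence (the trace of a Jacobian) is too expensive to compute exactly for image-sized inputs, so it is approximated with the Skilling-Hutchinson estimator

\nabla \cdot v(x) = E_{\epsilon \sim N(0, I)} [ \epsilon^\top \nabla_x ( v(x)^\top \epsilon ) ],

which divergence_eval below evaluates with a single fixed \epsilon per input.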
def prior_likelihood(z, sigma):
"""The likelihood of a Gaussian distribution with mean zero and
standard deviation sigma."""
shape = z.shape
N = np.prod(shape[1:])
return -N / 2. * torch.log(2*np.pi*sigma**2) - torch.sum(z**2, dim=(1,2,3)) / (2 * sigma**2)
def ode_likelihood(x,
score_model,
marginal_prob_std,
diffusion_coeff,
batch_size=64,
device='cuda',
eps=1e-5):
"""Compute the likelihood with probability flow ODE.
Args:
x: Input data.
score_model: A PyTorch model representing the score-based model.
marginal_prob_std: A function that gives the standard deviation of the
perturbation kernel.
diffusion_coeff: A function that gives the diffusion coefficient of the
forward SDE.
batch_size: The batch size. Equals to the leading dimension of `x`.
device: 'cuda' for evaluation on GPUs, and 'cpu' for evaluation on CPUs.
eps: A `float` number. The smallest time step for numerical stability.
Returns:
z: The latent code for `x`.
bpd: The log-likelihoods in bits/dim.
"""
# Draw the random Gaussian sample for Skilling-Hutchinson's estimator.
epsilon = torch.randn_like(x)
def divergence_eval(sample, time_steps, epsilon):
"""Compute the divergence of the score-based model with Skilling-Hutchinson."""
with torch.enable_grad():
sample.requires_grad_(True)
score_e = torch.sum(score_model(sample, time_steps) * epsilon)
grad_score_e = torch.autograd.grad(score_e, sample)[0]
return torch.sum(grad_score_e * epsilon, dim=(1, 2, 3))
shape = x.shape
def score_eval_wrapper(sample, time_steps):
"""A wrapper for evaluating the score-based model for the black-box ODE solver."""
sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((sample.shape[0], ))
with torch.no_grad():
score = score_model(sample, time_steps)
return score.cpu().numpy().reshape((-1,)).astype(np.float64)
def divergence_eval_wrapper(sample, time_steps):
"""A wrapper for evaluating the divergence of score for the black-box ODE solver."""
with torch.no_grad():
# Obtain x(t) by solving the probability flow ODE.
sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((sample.shape[0], ))
# Compute likelihood.
div = divergence_eval(sample, time_steps, epsilon)
return div.cpu().numpy().reshape((-1,)).astype(np.float64)
def ode_func(t, x):
"""The ODE function for the black-box solver."""
time_steps = np.ones((shape[0],)) * t
sample = x[:-shape[0]]
logp = x[-shape[0]:]
g = diffusion_coeff(torch.tensor(t)).cpu().numpy()
sample_grad = -0.5 * g**2 * score_eval_wrapper(sample, time_steps)
logp_grad = -0.5 * g**2 * divergence_eval_wrapper(sample, time_steps)
return np.concatenate([sample_grad, logp_grad], axis=0)
init = np.concatenate([x.cpu().numpy().reshape((-1,)), np.zeros((shape[0],))], axis=0)
# Black-box ODE solver
res = integrate.solve_ivp(ode_func, (eps, 1.), init, rtol=1e-5, atol=1e-5, method='RK45')
zp = torch.tensor(res.y[:, -1], device=device)
z = zp[:-shape[0]].reshape(shape)
delta_logp = zp[-shape[0]:].reshape(shape[0])
sigma_max = marginal_prob_std(1.)
prior_logp = prior_likelihood(z, sigma_max)
bpd = -(prior_logp + delta_logp) / np.log(2)
N = np.prod(shape[1:])
bpd = bpd / N + 8.
return z, bpd
Compute the likelihood over the test dataset:
batch_size = 32 #@param {'type':'integer'}
dataset = MNIST('.', train=False, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
ckpt = torch.load('ckpt.pth', map_location=device)
score_model.load_state_dict(ckpt)
all_bpds = 0.
all_items = 0
try:
tqdm_data = tqdm.notebook.tqdm(data_loader)
for x, _ in tqdm_data:
x = x.to(device)
# uniform dequantization
x = (x * 255. + torch.rand_like(x)) / 256.
_, bpd = ode_likelihood(x, score_model, marginal_prob_std_fn,
diffusion_coeff_fn,
x.shape[0], device=device, eps=1e-5)
all_bpds += bpd.sum()
all_items += bpd.shape[0]
tqdm_data.set_description("Average bits/dim: {:5f}".format(all_bpds / all_items))
except KeyboardInterrupt:
  # Remove the error message when interrupted by keyboard or GUI.
pass
To summarize:
(1) A score-based diffusion model built on SDEs requires a time-dependent score-based neural network;
(2) The forward function of that network takes the perturbed x together with the time t as input and outputs a score. This differs from a conventional diffusion model, whose network takes the perturbed x as input and outputs the clean x or the noise;
(3) The forward function relies on several important supporting functions: GaussianFourierProjection, which maps the time t to a Gaussian random feature vector so that t can be integrated alongside x; and marginal_prob_std, which computes the standard deviation of the perturbation kernel at time t, used to normalize the score output by the network. This means a score-based diffusion model requires rewriting the model architecture;
(4) The loss function of a score-based diffusion model is remarkably simple: loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1,2,3)));
(5) There are several ways to sample from the scores output by the network: the Euler-Maruyama sampler, the Predictor-Corrector sampler, and the numerical ODE solver;
(6) Every sampler first requires the SDE to be set up; a key function here is diffusion_coeff_fn, which computes the diffusion coefficient of the SDE;
(7) Every sampler has a fixed form and can be used as-is.
A final note: I have not covered the theory behind score-based diffusion here. Many blogs, public accounts, and videos already explain it in detail, including the full derivations of the formulas. Besides, I am not a mathematician and only half understand some of the math involved, so I will not hold you up; please consult those resources.
If you would like an accessible explanation of the theory, leave a comment; I may consider writing a dedicated post on it.