After studying the IDDPM paper, I went through the official GitHub repository to understand, through the concrete code, some details of how the algorithm is implemented. The official repository is implemented in PyTorch: https://github.com/openai/improved-diffusion. This note annotates and analyzes the sampling part of the project, mainly the image_sample.py and respace.py files.
Once training is complete, sampling essentially just runs the reverse diffusion process step by step with the UNet model, starting from randomly sampled noise: at each step the UNet predicts the mean and variance of a distribution, and the reparameterization trick is applied to obtain the prediction of the image. Note that during training the image pixel values are normalized so that every pixel lies in [-1, 1]; the images produced by sampling therefore also lie in this range, and the pixel values must be mapped back into [0, 255] by adding one and then multiplying by 127.5. Actual sampling also involves one of the improvements proposed in the IDDPM paper, namely faster sampling; see the notes on respace.py below.
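As a minimal sketch of this denormalization (my own illustration, not code from the repo; the random tensor stands in for a sampled batch):

import torch as th

x = th.rand(2, 3, 64, 64) * 2 - 1  # stand-in for a sampled batch with values in [-1, 1]
img = ((x + 1) * 127.5).clamp(0, 255).to(th.uint8)  # the same mapping image_sample.py applies
print(int(img.min()), int(img.max()))  # values now lie in [0, 255]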
"""
Generate a large batch of image samples from a model and save them as a large
numpy array. This can be used to produce samples for FID evaluation.
"""
import argparse
import os
import numpy as np
import torch as th
import torch.distributed as dist
from improved_diffusion import dist_util, logger
from improved_diffusion.script_util import (
NUM_CLASSES,
model_and_diffusion_defaults,
create_model_and_diffusion,
add_dict_to_argparser,
args_to_dict,
)
def main():
args = create_argparser().parse_args()
dist_util.setup_dist() # 分布式
logger.configure()
logger.log("creating model and diffusion...")
# 初始化Unet模型和扩散过程对象
model, diffusion = create_model_and_diffusion(
**args_to_dict(args, model_and_diffusion_defaults().keys()))
# 模型加载参数
model.load_state_dict(
dist_util.load_state_dict(args.model_path, map_location="cpu"))
model.to(dist_util.dev())
model.eval()
logger.log("sampling...")
all_images = []
all_labels = []
while len(all_images) * args.batch_size < args.num_samples:
model_kwargs = {}
if args.class_cond:
classes = th.randint(
low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()) # 随机生成图片label/类别信息
model_kwargs["y"] = classes # 如果存在,使用的地方是在Unet模型计算时会将提作为条件嵌入与时间嵌入叠加,作为条件信息指导模型生成
sample_fn = (
diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop) # 选择具体采样函数,是否使用ddim方法
sample = sample_fn(
model,
(args.batch_size, 3, args.image_size, args.image_size), # 此为采样时图像的尺寸[batch_size, 3, image_size, image_size]
clip_denoised=args.clip_denoised,
model_kwargs=model_kwargs,)
sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8) # 将图片每个位置的数值转为0~255区间内,即还原为图片
sample = sample.permute(0, 2, 3, 1) # [batch_size, image_size, image_size, 3]
sample = sample.contiguous()
# 将多卡中的采样的样本图片和推定的label集合
gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
dist.all_gather(gathered_samples, sample) # gather not supported with NCCL
all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
if args.class_cond:
gathered_labels = [
th.zeros_like(classes) for _ in range(dist.get_world_size())
]
dist.all_gather(gathered_labels, classes)
all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
logger.log(f"created {len(all_images) * args.batch_size} samples")
arr = np.concatenate(all_images, axis=0)
arr = arr[: args.num_samples]
if args.class_cond:
label_arr = np.concatenate(all_labels, axis=0)
label_arr = label_arr[: args.num_samples]
if dist.get_rank() == 0:
shape_str = "x".join([str(x) for x in arr.shape])
out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz")
logger.log(f"saving to {out_path}")
if args.class_cond:
np.savez(out_path, arr, label_arr)
else:
np.savez(out_path, arr)
dist.barrier()
logger.log("sampling complete")
# 初始化所需参数
def create_argparser():
defaults = dict(
clip_denoised=True,
num_samples=10000,
batch_size=16,
use_ddim=False,
model_path="",
)
defaults.update(model_and_diffusion_defaults())
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)
return parser
if __name__ == "__main__":
main()
This file mainly defines the SpacedDiffusion class, which inherits from GaussianDiffusion, together with a space_timesteps function that, given section_counts, extracts a sampling subsequence $S_t$ from the original timestep sequence. SpacedDiffusion can then perform the final image sampling over this shorter timestep subsequence $S_t$; the goal is to reduce the number of reverse-diffusion steps the model needs at inference time, i.e. the sampling-speed improvement, which makes the model more efficient to use. The SpacedDiffusion class recomputes the $\beta_{S_t}$ corresponding to the timestep subsequence $S_t$; these override the original $\beta_t$ needed when predicting the variance during sampling, and thus ultimately affect the sampled images. At training time, the diffusion object initialized in script_util.py is in fact also a SpacedDiffusion, but training still uses the normal timestep sequence, e.g. all 1000 steps; it suffices to have space_timesteps generate a timestep subsequence identical to the original sequence.
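To make the recomputation concrete, here is a small numpy sketch (my own illustration, not repo code; the linear schedule is a stand-in) of the rule $\beta_{S_t} = 1 - \bar{\alpha}_{S_t} / \bar{\alpha}_{S_{t-1}}$ used in SpacedDiffusion.__init__ below; it checks that the respaced schedule preserves the cumulative products $\bar{\alpha}$ at the retained steps:

import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)        # stand-in for the original beta schedule
alphas_cumprod = np.cumprod(1.0 - betas)  # $\bar{\alpha}_t$ of the full process

use_timesteps = list(range(0, T, 10))     # keep every 10th step -> 100 steps
last = 1.0
new_betas = []
for i in use_timesteps:
    new_betas.append(1 - alphas_cumprod[i] / last)  # $\beta_{S_t} = 1 - \bar{\alpha}_{S_t} / \bar{\alpha}_{S_{t-1}}$
    last = alphas_cumprod[i]

# the respaced schedule reproduces the original $\bar{\alpha}$ at every retained step
assert np.allclose(np.cumprod(1.0 - np.array(new_betas)), alphas_cumprod[use_timesteps])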
import numpy as np
import torch as th

from .gaussian_diffusion import GaussianDiffusion


def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.

    Note that with arguments 1000 and [1000] the timestep sequence is unchanged,
    since the 1000 steps are divided into a single section of 1000 steps.

    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):  # use the fixed stride from the DDIM paper
            desired_count = int(section_counts[len("ddim"):])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))  # return once striding num_timesteps by i yields exactly desired_count steps
            raise ValueError(
                f"cannot create exactly {num_timesteps} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]  # not "ddim"-prefixed: split the string on ',' to get section_counts
    size_per = num_timesteps // len(section_counts)  # first split num_timesteps into len(section_counts) sections of size size_per
    extra = num_timesteps % len(section_counts)  # the remainder that may be left over
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)  # spread the leftover extra steps evenly over the first `extra` sections, one each
        if size < section_count:  # size is split into section_count steps, so size must not be smaller than section_count
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        # set the stride between two consecutive kept timesteps
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []  # records the rescaled timesteps t kept from the current section
        for _ in range(section_count):  # section_count timesteps need to be added
            taken_steps.append(start_idx + round(cur_idx))  # append the newly computed timestep t; round() rounds cur_idx to an integer
            cur_idx += frac_stride
        all_steps += taken_steps  # append the timesteps recorded for this section to the overall list
        start_idx += size  # index at which the next section's timesteps start
    return set(all_steps)  # the final timestep sequence, rescaled section by section according to section_counts
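
# A quick sanity check of the docstring examples (my own illustration, not repo code):
# >>> sorted(space_timesteps(300, [10, 15, 20]))
# [0, 11, 22, ..., 99, 100, 107, ..., 199, 200, 205, ..., 299]   # 10 + 15 + 20 = 45 steps
# >>> sorted(space_timesteps(1000, "ddim50"))
# [0, 20, 40, ..., 980]   # an integer stride of 20 yields exactly 50 steps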
# The purpose of inheriting from GaussianDiffusion here is to generate new betas from the
# shorter timestep sequence adjusted out of the original one and assign them to the diffusion
# object, so that inference can generate images with fewer diffusion steps -- the
# "sampling speed improvement" proposed in the paper.
class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.

    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)  # the timestep sequence adjusted from the original timesteps
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])  # number of timesteps of the original diffusion process

        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):  # $\bar{\alpha}_t$
            # alpha_cumprod is the $\bar{\alpha}_i$ of original timestep i, where i plays the role of t
            if i in self.use_timesteps:  # the current timestep i is among the adjusted timesteps use_timesteps
                # compute $\beta_i$ from $\bar{\alpha}_i$ and $\bar{\alpha}_{i-1}$
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        kwargs["betas"] = np.array(new_betas)  # hand the new betas to the diffusion object
        # after the base diffusion object above is built with the regular betas, the new betas
        # corresponding to the adjusted timestep sequence are generated, and the object is then
        # re-initialized as the respaced diffusion; it is mainly used to generate images with fewer
        # diffusion steps, while training still uses the plain GaussianDiffusion behaviour
        super().__init__(**kwargs)

    # mean and variance predicted by the network
    def p_mean_variance(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)

    # compute the training losses
    def training_losses(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().training_losses(self._wrap_model(model), *args, **kwargs)

    # wrap the given model in the _WrappedModel class
    def _wrap_model(self, model):
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(
            model, self.timestep_map, self.rescale_timesteps, self.original_num_steps)

    def _scale_timesteps(self, t):
        # Scaling is done by the wrapped model.
        return t


# wrapping the model allows the timesteps to be remapped and rescaled before the forward pass
class _WrappedModel:
    def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
        self.model = model  # the wrapped model
        self.timestep_map = timestep_map  # the timestep sequence adjusted from the original one
        self.rescale_timesteps = rescale_timesteps  # whether timesteps should be rescaled
        self.original_num_steps = original_num_steps  # the original number of timesteps

    # effectively wraps the model's forward pass; calling the model goes through __call__, which dispatches to forward
    def __call__(self, x, ts, **kwargs):
        map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        new_ts = map_tensor[ts]  # use ts to index the adjusted timestep sequence and recover the corresponding original timestep new_ts
        if self.rescale_timesteps:  # if rescaling is enabled, scale the timesteps to the fixed range used in training
            new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
        return self.model(x, new_ts, **kwargs)  # call the model's forward pass; the model here is the UNet
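To see how these pieces fit together, here is a hedged sketch (my own example, roughly mirroring what create_gaussian_diffusion in script_util.py does; treat the exact enum choices as assumptions):

from improved_diffusion import gaussian_diffusion as gd
from improved_diffusion.respace import SpacedDiffusion, space_timesteps

T = 1000
diffusion = SpacedDiffusion(
    use_timesteps=space_timesteps(T, "250"),       # sample with 250 of the original 1000 steps
    betas=gd.get_named_beta_schedule("linear", T),
    model_mean_type=gd.ModelMeanType.EPSILON,      # the UNet predicts the noise
    model_var_type=gd.ModelVarType.LEARNED_RANGE,  # IDDPM's learned interpolated variance
    loss_type=gd.LossType.RESCALED_MSE,
    rescale_timesteps=True,
)
# len(diffusion.betas) == 250; during sampling, _WrappedModel maps each index in
# 0..249 back to an original timestep in 0..999 before it reaches the UNet.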
As a supplement, here is an annotated walkthrough of the dataset-construction code. It mainly defines a conventional image Dataset and DataLoader; the one point to note is that, if the images carry label information and conditional training and sampling are wanted, each image's label must be packed into the Dataset together with the image and passed to the model. A short usage sketch follows the listing below.
from PIL import Image
import blobfile as bf
from mpi4py import MPI
import numpy as np
from torch.utils.data import DataLoader, Dataset


def load_data(*, data_dir, batch_size, image_size, class_cond=False, deterministic=False):
    """
    For a dataset, create a generator over (images, kwargs) pairs.

    Each images is an NCHW float tensor, and the kwargs dict contains zero or
    more keys, each of which map to a batched Tensor of their own.
    The kwargs dict can be used for class labels, in which case the key is "y"
    and the values are integer tensors of class labels.

    :param data_dir: a dataset directory.
    :param batch_size: the batch size of each returned pair.
    :param image_size: the size to which images are resized.
    :param class_cond: if True, include a "y" key in returned dicts for class
                       label. If classes are not available and this is true, an
                       exception will be raised.
    :param deterministic: if True, yield results in a deterministic order.
    """
    if not data_dir:
        raise ValueError("unspecified data directory")
    all_files = _list_image_files_recursively(data_dir)  # collect the paths of all images under the directory
    classes = None
    if class_cond:  # when True, the class information of each image is needed
        # Assume classes are the first part of the filename, before an underscore.
        # i.e. each image's label is assumed to be encoded in its file name
        class_names = [bf.basename(path).split("_")[0] for path in all_files]
        sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}  # assign an integer label to each string label
        classes = [sorted_classes[x] for x in class_names]  # attach the integer label to every image
    dataset = ImageDataset(
        image_size,
        all_files,
        classes=classes,
        shard=MPI.COMM_WORLD.Get_rank(),
        num_shards=MPI.COMM_WORLD.Get_size(),
    )  # build the dataset
    if deterministic:  # keep the image order fixed
        loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
        )
    else:  # shuffle the image order
        loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
        )
    while True:
        yield from loader


# recursively collect all images under the given directory
def _list_image_files_recursively(data_dir):
    results = []
    for entry in sorted(bf.listdir(data_dir)):
        full_path = bf.join(data_dir, entry)
        ext = entry.split(".")[-1]
        if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
            results.append(full_path)
        elif bf.isdir(full_path):
            results.extend(_list_image_files_recursively(full_path))
    return results


# image dataset class
class ImageDataset(Dataset):
    def __init__(self, resolution, image_paths, classes=None, shard=0, num_shards=1):
        super().__init__()
        self.resolution = resolution  # target resolution
        self.local_images = image_paths[shard:][::num_shards]  # this rank's share of the image paths
        self.local_classes = None if classes is None else classes[shard:][::num_shards]  # the corresponding labels

    def __len__(self):
        return len(self.local_images)

    def __getitem__(self, idx):
        path = self.local_images[idx]
        with bf.BlobFile(path, "rb") as f:
            pil_image = Image.open(f)
            pil_image.load()  # load the image
        # We are not on a new enough PIL to support the `reducing_gap`
        # argument, which uses BOX downsampling at powers of two first.
        # Thus, we do it by hand to improve downsample quality.
        while min(*pil_image.size) >= 2 * self.resolution:
            pil_image = pil_image.resize(
                tuple(x // 2 for x in pil_image.size), resample=Image.BOX
            )  # halve the image repeatedly until it is close to the target resolution

        scale = self.resolution / min(*pil_image.size)
        pil_image = pil_image.resize(
            tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
        )

        arr = np.array(pil_image.convert("RGB"))  # convert to RGB
        crop_y = (arr.shape[0] - self.resolution) // 2
        crop_x = (arr.shape[1] - self.resolution) // 2
        arr = arr[crop_y: crop_y + self.resolution, crop_x: crop_x + self.resolution]  # center crop
        arr = arr.astype(np.float32) / 127.5 - 1  # normalize to the range [-1, 1]

        out_dict = {}
        if self.local_classes is not None:
            out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
        return np.transpose(arr, [2, 0, 1]), out_dict  # transpose the image to CHW, channels first
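Finally, the usage sketch mentioned above (the directory name is hypothetical; load_data shards the file list via mpi4py, so an MPI runtime is assumed):

from improved_diffusion.image_datasets import load_data

data = load_data(
    data_dir="path/to/images",  # hypothetical; a file named "cat_0001.png" gets class "cat"
    batch_size=16,
    image_size=64,
    class_cond=True,
)
batch, cond = next(data)  # batch: [16, 3, 64, 64] tensor in [-1, 1]; cond["y"]: [16] integer labels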
This note has covered the sampling-related code in the official IDDPM repository, including one of IDDPM's improvements: faster sampling. With this, the annotation and analysis of all the main parts of the official IDDPM codebase is complete. The two earlier notes, on model construction and on training ("IDDPM复现github项目–模型构建" and "IDDPM复现github项目–训练"), have already been published; readers can use these notes alongside the IDDPM paper to aid understanding.
Besides the original DDPM and IDDPM, the official repository also implements the DDIM method, which makes the codebase fairly large and potentially hard to follow. The three annotation notes I have published surely contain mistakes; if readers find problems or errors, please point them out in the comments so we can learn from each other. For lack of time, the DDIM-related code has not yet been annotated; I will supplement it later.