在完成IDDPM论文学习后,对github上的官方仓库进行学习,通过具体的代码理解算法实现过程中的一些细节;官方仓库代码基于pytorch实现,链接为https://github.com/openai/improved-diffusion。本笔记主要针对项目中训练部分代码进行注释解析,主要涉及仓库项目中的image_train.py、script_util.py、train_util.py、resample.py、dist_util.py文件。
本文件是进行图像训练的主要接口,先为训练过程中模型和扩散过程定义所需的参数,然后调用script_util.py文件中定义的函数初始化Unet模型和扩散过程对象,完成模型参数加载和训练数据导入后调用train_util.py文件中定义的TrainLoop类的run_loop()函数开始训练。
"""
Train a diffusion model on images.
"""
import argparse
from improved_diffusion import dist_util, logger
from improved_diffusion.image_datasets import load_data
from improved_diffusion.resample import create_named_schedule_sampler
from improved_diffusion.script_util import (
model_and_diffusion_defaults,
create_model_and_diffusion,
args_to_dict,
add_dict_to_argparser,
)
from improved_diffusion.train_util import TrainLoop
def main():
    """Entry point: build the model/diffusion pair and the data loader, then train.

    Flow: parse CLI flags -> initialise distributed training and logging ->
    create UNet + diffusion -> create the timestep sampler -> create the data
    iterator -> hand everything to TrainLoop.run_loop().
    """
    args = create_argparser().parse_args()

    # Process-group setup comes first; logging is configured afterwards.
    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    # Only the model/diffusion hyper-parameters are forwarded here.
    model_diffusion_kwargs = args_to_dict(args, model_and_diffusion_defaults().keys())
    model, diffusion = create_model_and_diffusion(**model_diffusion_kwargs)
    model.to(dist_util.dev())

    # "uniform" samples timesteps uniformly. "loss-second-moment" does
    # importance sampling driven by recent per-timestep losses; the IDDPM
    # paper uses it to reduce gradient noise when optimising L_vlb alone.
    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)

    logger.log("creating data loader...")
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=args.class_cond,
    )

    logger.log("training...")
    # These CLI flags share their names with TrainLoop's keyword arguments,
    # so they can be forwarded verbatim.
    train_kwargs = args_to_dict(
        args,
        (
            "batch_size",
            "microbatch",
            "lr",
            "ema_rate",
            "log_interval",
            "save_interval",
            "resume_checkpoint",
            "use_fp16",
            "fp16_scale_growth",
            "weight_decay",
            "lr_anneal_steps",
        ),
    )
    loop = TrainLoop(
        model=model,  # network fitting the reverse process p_theta (a UNet here)
        diffusion=diffusion,  # diffusion process object
        data=data,
        schedule_sampler=schedule_sampler,  # per-batch timestep sampler
        **train_kwargs,
    )
    loop.run_loop()
def create_argparser():
    """Build the command-line parser for training.

    Training hyper-parameters are merged with model_and_diffusion_defaults().
    Every resulting key becomes a ``--<key>`` flag whose type is inferred from
    its default value by add_dict_to_argparser().
    """
    training_defaults = {
        "data_dir": "",
        "schedule_sampler": "uniform",
        "lr": 1e-4,
        "weight_decay": 0.0,
        "lr_anneal_steps": 0,
        "batch_size": 1,
        "microbatch": -1,  # -1 disables microbatches
        "ema_rate": "0.9999",  # comma-separated list of EMA values
        "log_interval": 10,
        "save_interval": 10000,
        "resume_checkpoint": "",
        "use_fp16": False,
        "fp16_scale_growth": 1e-3,
    }
    # Model/diffusion defaults are layered on top, matching dict.update() semantics.
    all_defaults = {**training_defaults, **model_and_diffusion_defaults()}
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, all_defaults)
    return parser
# Script entry point: start training only when run directly, not on import.
if __name__ == "__main__":
    main()
本文件主要定义了Unet模型和扩散过程对象生成的代码,也包括超分辨率的Unet模型。
import argparse
import inspect
from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps
from .unet import SuperResModel, UNetModel
# Number of class labels passed as num_classes when class_cond=True.
# NOTE(review): 1000 presumably corresponds to ImageNet — confirm for other datasets.
NUM_CLASSES = 1000
def model_and_diffusion_defaults():
    """
    Defaults for image training.

    Returns a new dict each call. Its keys are exactly the keyword arguments
    of create_model_and_diffusion(), and they double as command-line flag names.
    """
    defaults = {
        "image_size": 64,
        "num_channels": 128,
        "num_res_blocks": 2,
        "num_heads": 4,
        "num_heads_upsample": -1,
        "attention_resolutions": "16,8",  # comma-separated feature-map sizes
        "dropout": 0.0,
        "learn_sigma": False,
        "sigma_small": False,
        "class_cond": False,
        "diffusion_steps": 1000,
        "noise_schedule": "linear",
        "timestep_respacing": "",
        "use_kl": False,
        "predict_xstart": False,
        "rescale_timesteps": True,
        "rescale_learned_sigmas": True,
        "use_checkpoint": False,
        "use_scale_shift_norm": True,
    }
    return defaults
def create_model_and_diffusion(
    image_size,
    class_cond,
    learn_sigma,
    sigma_small,
    num_channels,
    num_res_blocks,
    num_heads,
    num_heads_upsample,
    attention_resolutions,
    dropout,
    diffusion_steps,
    noise_schedule,
    timestep_respacing,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,
    use_checkpoint,
    use_scale_shift_norm,
):
    """Build the UNet and the Gaussian diffusion process from flat hyper-parameters.

    :param image_size: input resolution; selects the UNet channel multipliers.
    :param class_cond: if True the UNet gets NUM_CLASSES label classes.
    :param learn_sigma: if True the UNet outputs 6 channels (mean + variance
        terms) and the diffusion uses a learned variance range.
    :param sigma_small: with fixed variance, choose the "small" variant.
    :param attention_resolutions: comma-separated feature-map sizes (e.g.
        "16,8") at which self-attention is applied.
    :return: ``(model, diffusion)``.
    """
    unet_options = dict(
        learn_sigma=learn_sigma,
        class_cond=class_cond,
        use_checkpoint=use_checkpoint,
        attention_resolutions=attention_resolutions,
        num_heads=num_heads,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        dropout=dropout,
    )
    model = create_model(image_size, num_channels, num_res_blocks, **unet_options)

    diffusion_options = dict(
        steps=diffusion_steps,
        learn_sigma=learn_sigma,
        sigma_small=sigma_small,
        noise_schedule=noise_schedule,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,
    )
    diffusion = create_gaussian_diffusion(**diffusion_options)
    return model, diffusion
def create_model(
    image_size,
    num_channels,
    num_res_blocks,
    learn_sigma,
    class_cond,
    use_checkpoint,
    attention_resolutions,
    num_heads,
    num_heads_upsample,
    use_scale_shift_norm,
    dropout,
):
    """Construct the UNet denoiser for ``image_size`` x ``image_size`` RGB images.

    :param attention_resolutions: comma-separated feature-map sizes such as
        "16,8". Each size is converted to a downsampling factor
        ``image_size // size`` (e.g. 64 -> (4, 8)) before being handed to
        UNetModel (which presumably expects factors — confirm in unet.py).
    :raises ValueError: if ``image_size`` is not 256, 64 or 32.
    """
    # Channel multipliers of num_channels per UNet level: spatial size shrinks
    # while the channel count grows deeper in the network.
    supported = (
        (256, (1, 1, 2, 2, 4, 4)),
        (64, (1, 2, 3, 4)),
        (32, (1, 2, 2, 2)),
    )
    for size, mult in supported:
        if image_size == size:
            channel_mult = mult
            break
    else:
        raise ValueError(f"unsupported image size: {image_size}")

    attention_ds = tuple(
        image_size // int(res) for res in attention_resolutions.split(",")
    )
    # learn_sigma doubles the output channels: 3 for the mean/eps, 3 for variance.
    out_channels = 6 if learn_sigma else 3
    num_classes = NUM_CLASSES if class_cond else None
    return UNetModel(
        in_channels=3,
        model_channels=num_channels,
        out_channels=out_channels,
        num_res_blocks=num_res_blocks,
        attention_resolutions=attention_ds,
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=num_classes,
        use_checkpoint=use_checkpoint,
        num_heads=num_heads,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
    )
def sr_model_and_diffusion_defaults():
    """Defaults for the super-resolution model + diffusion.

    Starts from model_and_diffusion_defaults(), adds ``large_size``/``small_size``,
    and keeps only the keys that sr_create_model_and_diffusion() accepts
    (e.g. ``image_size`` and ``sigma_small`` are dropped).
    """
    candidates = model_and_diffusion_defaults()
    candidates["large_size"] = 256
    candidates["small_size"] = 64
    accepted = set(inspect.getfullargspec(sr_create_model_and_diffusion).args)
    return {key: value for key, value in candidates.items() if key in accepted}
def sr_create_model_and_diffusion(
    large_size,
    small_size,
    class_cond,
    learn_sigma,
    num_channels,
    num_res_blocks,
    num_heads,
    num_heads_upsample,
    attention_resolutions,
    dropout,
    diffusion_steps,
    noise_schedule,
    timestep_respacing,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,
    use_checkpoint,
    use_scale_shift_norm,
):
    """Build the super-resolution UNet and its Gaussian diffusion process.

    Unlike create_model_and_diffusion() there is no ``sigma_small`` option, so
    the diffusion uses create_gaussian_diffusion()'s default for it.

    :return: ``(model, diffusion)``.
    """
    unet_options = dict(
        learn_sigma=learn_sigma,
        class_cond=class_cond,
        use_checkpoint=use_checkpoint,
        attention_resolutions=attention_resolutions,
        num_heads=num_heads,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        dropout=dropout,
    )
    model = sr_create_model(
        large_size, small_size, num_channels, num_res_blocks, **unet_options
    )

    diffusion_options = dict(
        steps=diffusion_steps,
        learn_sigma=learn_sigma,
        noise_schedule=noise_schedule,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,
    )
    diffusion = create_gaussian_diffusion(**diffusion_options)
    return model, diffusion
def sr_create_model(
    large_size,
    small_size,
    num_channels,
    num_res_blocks,
    learn_sigma,
    class_cond,
    use_checkpoint,
    attention_resolutions,
    num_heads,
    num_heads_upsample,
    use_scale_shift_norm,
    dropout,
):
    """Construct the super-resolution UNet producing ``large_size`` outputs.

    ``small_size`` is part of the shared signature but is not needed here.
    :raises ValueError: if ``large_size`` is not 256 or 64.
    """
    del small_size  # accepted for interface symmetry; unused
    supported = (
        (256, (1, 1, 2, 2, 4, 4)),
        (64, (1, 2, 3, 4)),
    )
    for size, mult in supported:
        if large_size == size:
            channel_mult = mult
            break
    else:
        raise ValueError(f"unsupported large size: {large_size}")

    # Feature-map sizes -> downsampling factors relative to the output size.
    attention_ds = tuple(
        large_size // int(res) for res in attention_resolutions.split(",")
    )
    out_channels = 6 if learn_sigma else 3
    num_classes = NUM_CLASSES if class_cond else None
    return SuperResModel(
        in_channels=3,
        model_channels=num_channels,
        out_channels=out_channels,
        num_res_blocks=num_res_blocks,
        attention_resolutions=attention_ds,
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=num_classes,
        use_checkpoint=use_checkpoint,
        num_heads=num_heads,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
    )
def create_gaussian_diffusion(
    *,
    steps=1000,
    learn_sigma=False,
    sigma_small=False,
    noise_schedule="linear",
    use_kl=False,
    predict_xstart=False,
    rescale_timesteps=False,
    rescale_learned_sigmas=False,
    timestep_respacing="",
):
    """Create the (possibly respaced) Gaussian diffusion process.

    A SpacedDiffusion is always returned. With an empty ``timestep_respacing``
    all ``steps`` timesteps are kept, so it presumably behaves like a plain
    GaussianDiffusion.

    :param noise_schedule: beta schedule name passed to get_named_beta_schedule
        (e.g. "linear"; IDDPM also proposes a cosine schedule).
    :param use_kl: train on the (rescaled) variational bound only.
    :param rescale_learned_sigmas: otherwise, use RESCALED_MSE rather than MSE.
    :param predict_xstart: model predicts x_0 instead of the noise epsilon.
    :param learn_sigma: model outputs a learned variance range instead of a
        fixed variance; ``sigma_small`` picks between the two fixed variants.
    """
    betas = gd.get_named_beta_schedule(noise_schedule, steps)

    loss_type = gd.LossType.MSE
    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE

    # What the network outputs: the added noise (EPSILON) or x_0 (START_X).
    mean_type = gd.ModelMeanType.START_X if predict_xstart else gd.ModelMeanType.EPSILON

    # Learned variance wins; otherwise choose a fixed variance. Per IDDPM these
    # are presumably beta_t (large) vs. the posterior variance (small).
    if learn_sigma:
        var_type = gd.ModelVarType.LEARNED_RANGE
    elif sigma_small:
        var_type = gd.ModelVarType.FIXED_SMALL
    else:
        var_type = gd.ModelVarType.FIXED_LARGE

    # An empty respacing means "one section containing every step".
    section_counts = timestep_respacing or [steps]
    return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, section_counts),
        betas=betas,
        model_mean_type=mean_type,
        model_var_type=var_type,
        loss_type=loss_type,
        rescale_timesteps=rescale_timesteps,
    )
def add_dict_to_argparser(parser, default_dict):
    """Register a ``--<key>`` option on *parser* for every entry of *default_dict*.

    The option's type is derived from its default:
    - ``None`` is parsed as ``str``;
    - ``bool`` goes through ``str2bool`` so "true"/"false" etc. work;
    - anything else is parsed with ``type(default)``.
    """
    for name, default in default_dict.items():
        if default is None:
            converter = str
        elif isinstance(default, bool):
            converter = str2bool
        else:
            converter = type(default)
        parser.add_argument(f"--{name}", default=default, type=converter)
def args_to_dict(args, keys):
    """Return ``{key: getattr(args, key)}`` for each name in *keys*, in order."""
    picked = {}
    for name in keys:
        picked[name] = getattr(args, name)
    return picked
def str2bool(v):
    """
    Parse a command-line boolean (argparse ``type=`` helper).

    Accepts a bool unchanged, or (case-insensitively) one of
    yes/true/t/y/1 -> True and no/false/f/n/0 -> False.
    See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse

    :raises argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected")
本文件主要定义了用于训练的TrainLoop类,它将单个step、单个batch与整体训练流程解耦,实现方式与PyTorch Lightning类似。需要注意的一点是,为了进行合理的混合精度训练,TrainLoop类中维护了一个self.master_params变量,可以将其理解为训练过程中UNet模型参数的一份全精度备份。混合精度训练时,前向与反向计算使用半精度数据;反向传播得到的梯度会被拷贝到self.master_params所存储的全精度参数上,由优化器在全精度下完成参数更新,再把更新后的参数拷回模型,从而完成一次参数更新。(不使用fp16时,self.master_params就是模型参数列表本身。)训练过程中还会用到IDDPM提出的一个改进点,即基于损失对时间步进行重要性重采样。
import copy
import functools
import os
import blobfile as bf
import numpy as np
import torch as th
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW
from . import dist_util, logger
from .fp16_util import (
make_master_params,
master_params_to_model_params,
model_grads_to_master_grads,
unflatten_master_params,
zero_grad,
)
from .nn import update_ema
from .resample import LossAwareSampler, UniformSampler
# Starting value of TrainLoop.lg_loss_scale, the log2 of the fp16 loss scale:
# the loss is multiplied by 2 ** lg_loss_scale before backward() in fp16 mode.
# For ImageNet experiments, this was a good default value.
# We found that the lg_loss_scale quickly climbed to
# 20-21 within the first ~1K steps of training.
INITIAL_LOG_LOSS_SCALE = 20.0
# Training driver: bundles per-step, per-batch and whole-run logic into one
# object (similar in spirit to PyTorch Lightning).
class TrainLoop:
    """
    Run the diffusion training loop: forward/backward over (micro)batches,
    optimizer and EMA updates, logging and checkpointing.

    Usage: ``TrainLoop(...).run_loop()``.

    ``self.master_params`` holds the parameters the optimizer actually updates.
    - Without fp16 it is simply the model's own parameter list.
    - With ``use_fp16`` it is a separate copy built by ``make_master_params``,
      presumably fp32 and flattened (see ``unflatten_master_params`` in
      ``_master_params_to_state_dict``). Gradients are copied into it, the
      optimizer steps on it, and the values are copied back into the
      half-precision model.
    """

    def __init__(
        self,
        *,
        model,
        diffusion,
        data,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
    ):
        """
        :param model: the denoising network (e.g. the UNet); must provide
            ``convert_to_fp16()`` when ``use_fp16`` is set.
        :param diffusion: diffusion object exposing ``training_losses``.
        :param data: iterator yielding ``(batch, cond)``; ``cond`` is a dict of
            tensors that are sliced along the batch dimension.
        :param batch_size: per-rank batch size.
        :param microbatch: chunk size for gradient accumulation; ``<= 0`` means
            the whole batch at once.
        :param lr: base AdamW learning rate.
        :param ema_rate: a float, or a comma-separated string of EMA rates.
        :param log_interval: call ``logger.dumpkvs()`` every this many steps.
        :param save_interval: checkpoint every this many steps.
        :param resume_checkpoint: path of a ``modelNNNNNN.pt`` file, or ``""``.
        :param use_fp16: half-precision model weights with loss scaling.
        :param fp16_scale_growth: added to ``lg_loss_scale`` after every
            successful fp16 step.
        :param schedule_sampler: timestep sampler; defaults to UniformSampler.
        :param weight_decay: AdamW weight decay.
        :param lr_anneal_steps: total steps for a linear LR decay to zero; 0
            disables both the annealing and the step limit.
        """
        self.model = model  # denoising network (UNet)
        self.diffusion = diffusion  # diffusion process object
        self.data = data  # iterator of (batch, cond) pairs
        self.batch_size = batch_size
        # Batches are split into chunks of this size and gradients accumulated;
        # a non-positive value falls back to the full batch.
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.lr = lr
        # Accept one float or a comma-separated string such as "0.9999,0.999".
        self.ema_rate = (
            [ema_rate]
            if isinstance(ema_rate, float)
            else [float(x) for x in ema_rate.split(",")])
        self.log_interval = log_interval  # steps between log dumps
        self.save_interval = save_interval  # steps between checkpoints
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16  # half-precision training
        self.fp16_scale_growth = fp16_scale_growth
        self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)  # timestep sampler
        self.weight_decay = weight_decay
        self.lr_anneal_steps = lr_anneal_steps  # linear LR-decay horizon (0 = off)
        self.step = 0  # steps taken in this run
        self.resume_step = 0  # step parsed from the resumed checkpoint's filename
        # Samples consumed per optimisation step across all ranks.
        self.global_batch = self.batch_size * dist.get_world_size()
        self.model_params = list(self.model.parameters())
        # fp32 mode: the optimizer updates the model's own parameters.
        # _setup_fp16() below replaces this with a separate master copy.
        self.master_params = self.model_params
        self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE  # log2 of the fp16 loss scale
        # NOTE(review): sync_cuda is not read anywhere in this class.
        self.sync_cuda = th.cuda.is_available()
        # Load checkpoint weights (if any) and broadcast them from rank 0.
        self._load_and_sync_parameters()
        if self.use_fp16:
            # Snapshot master params first, then convert the model to fp16.
            self._setup_fp16()
        # The optimizer always steps on master_params.
        self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
        if self.resume_step:
            # Resuming: restore optimizer state and the EMA copies.
            self._load_optimizer_state()
            # Model was resumed, either due to a restart or a checkpoint being specified at the command line.
            self.ema_params = [self._load_ema_parameters(rate) for rate in self.ema_rate]
        else:
            # Fresh run: every EMA copy starts equal to the current weights.
            self.ema_params = [copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))]
        if th.cuda.is_available():
            # Wrap the model in DDP so gradients are synchronized across ranks.
            self.use_ddp = True
            self.ddp_model = DDP(
                self.model,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            if dist.get_world_size() > 1:
                logger.warn(
                    "Distributed training requires CUDA. "
                    "Gradients will not be synchronized properly!")
            self.use_ddp = False
            self.ddp_model = self.model

    def _load_and_sync_parameters(self):
        """Load model weights from the resume checkpoint (if any) on rank 0,
        record the resume step, then broadcast parameters to all ranks."""
        # Path of an earlier checkpoint; find_resume_checkpoint() takes priority.
        resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        if resume_checkpoint:
            # The step count is encoded in the filename (modelNNNNNN.pt).
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
            if dist.get_rank() == 0:
                logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
                # NOTE(review): dist_util.load_state_dict performs an MPI bcast,
                # but here only rank 0 calls it; with >1 rank that collective is
                # unmatched — confirm it cannot hang for large checkpoints.
                self.model.load_state_dict(
                    dist_util.load_state_dict(
                        resume_checkpoint, map_location=dist_util.dev()
                    )
                )
        # Every rank adopts rank 0's parameters.
        dist_util.sync_params(self.model.parameters())

    def _load_ema_parameters(self, rate):
        """Return EMA parameters for ``rate``, loaded from the matching
        ``ema_{rate}_{step}.pt`` file if present, else a copy of master_params."""
        ema_params = copy.deepcopy(self.master_params)
        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
        if ema_checkpoint:
            if dist.get_rank() == 0:
                logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
                # NOTE(review): same rank-0-only call to an MPI-broadcasting
                # loader as in _load_and_sync_parameters.
                state_dict = dist_util.load_state_dict(
                    ema_checkpoint, map_location=dist_util.dev())
                ema_params = self._state_dict_to_master_params(state_dict)
        # Broadcast rank 0's EMA values to all ranks.
        dist_util.sync_params(ema_params)
        return ema_params

    def _load_optimizer_state(self):
        """Restore optimizer state from ``optNNNNNN.pt`` next to the checkpoint, if it exists."""
        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt")  # optimizer-state file
        if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = dist_util.load_state_dict(
                opt_checkpoint, map_location=dist_util.dev())
            self.opt.load_state_dict(state_dict)

    def _setup_fp16(self):
        """Create the master copy of the parameters, then convert the model to fp16."""
        # Order matters: the master copy is taken before conversion.
        self.master_params = make_master_params(self.model_params)
        self.model.convert_to_fp16()

    def run_loop(self):
        """
        Train until ``step + resume_step`` reaches ``lr_anneal_steps``, or
        forever when ``lr_anneal_steps`` is 0.

        Logs every ``log_interval`` steps and checkpoints every
        ``save_interval`` steps. On exit, saves a final checkpoint if the last
        step was not already saved.
        """
        while (
            not self.lr_anneal_steps
            or self.step + self.resume_step < self.lr_anneal_steps
        ):
            # cond: dict of conditioning tensors (presumably class labels when
            # class_cond is set; possibly empty otherwise).
            batch, cond = next(self.data)
            self.run_step(batch, cond)  # one optimisation step on this batch
            if self.step % self.log_interval == 0:
                logger.dumpkvs()
            if self.step % self.save_interval == 0:
                self.save()  # model, EMA and optimizer checkpoints
                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
                    return
            self.step += 1
        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()

    def run_step(self, batch, cond):
        """One training step: accumulate gradients, update weights, log."""
        self.forward_backward(batch, cond)
        if self.use_fp16:
            self.optimize_fp16()  # loss-scaled update via master params
        else:
            self.optimize_normal()  # plain optimizer update
        self.log_step()

    def forward_backward(self, batch, cond):
        """Compute losses and accumulate gradients over microbatches of ``batch``.

        NOTE(review): each microbatch back-propagates its own mean loss, so the
        accumulated gradient is a sum of per-microbatch means (it is not divided
        by the number of microbatches).
        """
        zero_grad(self.model_params)  # clear gradients on the model parameters
        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i: i + self.microbatch].to(dist_util.dev())
            micro_cond = {
                k: v[i: i + self.microbatch].to(dist_util.dev())
                for k, v in cond.items()}
            # True for the final microbatch of this batch.
            last_batch = (i + self.microbatch) >= batch.shape[0]
            # Sample timesteps and their importance weights.
            t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
            # Bind the diffusion loss function to this microbatch's inputs.
            compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                micro,
                t,
                model_kwargs=micro_cond,)
            # Only the last microbatch lets DDP all-reduce gradients; earlier
            # ones run under no_sync() to skip redundant communication.
            if last_batch or not self.use_ddp:
                losses = compute_losses()
            else:
                with self.ddp_model.no_sync():
                    losses = compute_losses()
            if isinstance(self.schedule_sampler, LossAwareSampler):
                # IDDPM loss-aware importance sampling: feed the observed
                # per-timestep losses back into the sampler.
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach())
            loss = (losses["loss"] * weights).mean()
            log_loss_dict(self.diffusion, t, {k: v * weights for k, v in losses.items()})
            # Backward; in fp16 mode the loss is scaled by 2 ** lg_loss_scale.
            if self.use_fp16:
                loss_scale = 2 ** self.lg_loss_scale
                (loss * loss_scale).backward()
            else:
                loss.backward()

    def optimize_fp16(self):
        """fp16 update: skip the step on overflow, otherwise unscale, step on
        master params, update EMAs, copy back into the model and grow the scale."""
        # Non-finite gradients: lower the loss scale by one power of two and skip.
        if any(not th.isfinite(p.grad).all() for p in self.model_params):
            self.lg_loss_scale -= 1
            logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
            return
        # Copy model gradients onto the master parameters.
        model_grads_to_master_grads(self.model_params, self.master_params)
        # Undo the loss scaling. Only index 0 is rescaled, which presumably
        # means the master params are a single flattened tensor — confirm.
        self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
        self._log_grad_norm()  # log the global gradient L2 norm
        self._anneal_lr()  # linear learning-rate decay
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            # Move each EMA copy toward the freshly updated master params
            # (the model itself is not changed by this call).
            update_ema(params, self.master_params, rate=rate)
        # Write the updated values back into the half-precision model.
        master_params_to_model_params(self.model_params, self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth

    def optimize_normal(self):
        """fp32 update: log grad norm, anneal LR, optimizer step, update EMAs."""
        self._log_grad_norm()  # log the global gradient L2 norm
        self._anneal_lr()  # linear learning-rate decay
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            # Move each EMA copy toward the current master params.
            update_ema(params, self.master_params, rate=rate)

    def _log_grad_norm(self):
        """Log the global L2 norm of all master-parameter gradients as grad_norm."""
        sqsum = 0.0
        for p in self.master_params:
            sqsum += (p.grad ** 2).sum().item()
        logger.logkv_mean("grad_norm", np.sqrt(sqsum))

    def _anneal_lr(self):
        """Linearly decay the LR from ``lr`` to 0 over ``lr_anneal_steps`` total steps."""
        if not self.lr_anneal_steps:
            return
        frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def log_step(self):
        """Log the global step, samples seen so far, and (fp16) the loss scale."""
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
        if self.use_fp16:
            logger.logkv("lg_loss_scale", self.lg_loss_scale)

    def save(self):
        """Save model, EMA and optimizer checkpoints (rank 0 writes), then barrier."""
        def save_checkpoint(rate, params):
            # rate == 0 denotes the raw model weights; otherwise an EMA copy.
            state_dict = self._master_params_to_state_dict(params)
            if dist.get_rank() == 0:
                logger.log(f"saving model {rate}...")
                if not rate:
                    filename = f"model{(self.step + self.resume_step):06d}.pt"
                else:
                    filename = f"ema_{rate}_{(self.step + self.resume_step):06d}.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
                    th.save(state_dict, f)
        save_checkpoint(0, self.master_params)  # model weights
        for rate, params in zip(self.ema_rate, self.ema_params):
            save_checkpoint(rate, params)  # one file per EMA rate
        if dist.get_rank() == 0:
            with bf.BlobFile(
                bf.join(get_blob_logdir(), f"opt{(self.step + self.resume_step):06d}.pt"),
                "wb",
            ) as f:
                th.save(self.opt.state_dict(), f)  # optimizer state
        dist.barrier()  # wait until rank 0 has finished writing

    def _master_params_to_state_dict(self, master_params):
        """Build a model state dict whose parameter entries come from ``master_params``."""
        if self.use_fp16:
            # Split the master copy back into per-parameter tensors.
            master_params = unflatten_master_params(self.model.parameters(), master_params)
        state_dict = self.model.state_dict()  # includes buffers as well as parameters
        for i, (name, _value) in enumerate(self.model.named_parameters()):
            assert name in state_dict
            # master_params is ordered like model.named_parameters().
            state_dict[name] = master_params[i]
        return state_dict

    def _state_dict_to_master_params(self, state_dict):
        """Extract parameters from ``state_dict`` in master-parameter layout."""
        params = [state_dict[name] for name, _ in self.model.named_parameters()]
        if self.use_fp16:
            return make_master_params(params)
        else:
            return params
def parse_resume_step_from_filename(filename):
    """
    Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
    checkpoint's number of steps.

    The text after the last occurrence of "model", up to the first ".", is
    parsed as an integer. Returns 0 when "model" is absent or that text is
    not an integer.
    """
    _, marker, tail = filename.rpartition("model")
    if not marker:
        return 0
    step_text = tail.partition(".")[0]
    try:
        return int(step_text)
    except ValueError:
        return 0
def get_blob_logdir():
    """Checkpoint directory: ``$DIFFUSION_BLOB_LOGDIR`` if set, else the logger's directory."""
    fallback = logger.get_dir()
    return os.environ.get("DIFFUSION_BLOB_LOGDIR", fallback)
def find_resume_checkpoint():
    """Hook for auto-discovering the latest checkpoint; always returns None here.

    Callers fall back to the explicit ``resume_checkpoint`` setting. On your
    infrastructure you may override this to locate the newest checkpoint on
    blob storage, etc.
    """
    return None
def find_ema_checkpoint(main_checkpoint, step, rate):
    """Return the path of the EMA checkpoint saved next to *main_checkpoint*.

    EMA weights are written as ``ema_{rate}_{step:06d}.pt`` in the same
    directory as the model checkpoint.

    :param main_checkpoint: path of the main ``modelNNNNNN.pt`` checkpoint.
        ``None`` or ``""`` (the CLI default) means no checkpoint is known.
    :param step: the checkpoint's step number.
    :param rate: the EMA rate whose file is wanted.
    :return: the path if that file exists, otherwise None.
    """
    # An empty path would otherwise resolve to a bare filename and probe the
    # current working directory for an unrelated ema_*.pt file.
    if not main_checkpoint:
        return None
    filename = f"ema_{rate}_{(step):06d}.pt"
    path = bf.join(bf.dirname(main_checkpoint), filename)
    if bf.exists(path):
        return path
    return None
def log_loss_dict(diffusion, ts, losses):
    """Log each loss term's mean, plus its mean within four timestep quartiles.

    :param diffusion: object providing ``num_timesteps``.
    :param ts: tensor of sampled timesteps, one per batch element.
    :param losses: dict mapping a term name to a per-element loss tensor.
    """
    for key in losses:
        values = losses[key]
        logger.logkv_mean(key, values.mean().item())
        # Log the quantiles (four quartiles, in particular): bucket each
        # sample by which quarter of [0, num_timesteps) its timestep is in.
        per_sample = zip(ts.cpu().numpy(), values.detach().cpu().numpy())
        for step_value, loss_value in per_sample:
            bucket = int(4 * step_value / diffusion.num_timesteps)
            logger.logkv_mean(f"{key}_q{bucket}", loss_value)
IDDPM论文中提到,训练时对时间步进行均匀采样会在 $L_{vlb}$ 中引入不必要的噪声。解决方法是进行重要性重采样:随机采样时间步时,根据一个动态更新的历史损失为各时间步计算采样权重,而不是让各时间步的采样概率相同。该历史损失是一个尺寸为[T, 10]的矩阵,即为扩散过程的T个时间步各存储10个最新的损失值。在该矩阵的所有值填满之前,仍然进行均匀采样;填满之后,每个时间步的权重取其历史损失平方均值的平方根(即 $\sqrt{\mathbb{E}[L_t^2]}$),归一化后再混入少量均匀分布(uniform_prob),作为下一个step的时间步采样概率。使用这类sampler时,每个step计算完损失后还需调用update_with_local_losses函数,将新损失填入历史损失矩阵的最后端,并弹出最前端最旧的历史损失。
from abc import ABC, abstractmethod
import numpy as np
import torch as th
import torch.distributed as dist
def create_named_schedule_sampler(name, diffusion):
    """
    Create a ScheduleSampler from a library of pre-defined samplers.

    :param name: "uniform" (every timestep equally likely) or
        "loss-second-moment" (importance sampling from the RMS of recent
        per-timestep losses).
    :param diffusion: the diffusion object to sample for.
    :raises NotImplementedError: for any other name.
    """
    if name == "loss-second-moment":
        return LossSecondMomentResampler(diffusion)
    if name == "uniform":
        return UniformSampler(diffusion)
    raise NotImplementedError(f"unknown schedule sampler: {name}")
class ScheduleSampler(ABC):
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.

    By default, samplers perform unbiased importance sampling, in which the
    objective's mean is unchanged. However, subclasses may override sample() to
    change how the resampled terms are reweighted, allowing for actual changes
    in the objective.
    """

    @abstractmethod
    def weights(self):
        """
        Get a numpy array of weights, one per diffusion step.

        The weights needn't be normalized, but must be positive.
        """

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.

        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        raw = self.weights()
        probs = raw / np.sum(raw)  # normalise into a probability vector
        num_steps = len(probs)
        picked = np.random.choice(num_steps, size=(batch_size,), p=probs)
        # w_t = 1 / (T * p_t) keeps the estimate unbiased:
        # E_{t~p}[w_t * L_t] = (1/T) * sum_t L_t.
        scale = 1 / (num_steps * probs[picked])
        timesteps = th.from_numpy(picked).long().to(device)
        loss_weights = th.from_numpy(scale).float().to(device)
        return timesteps, loss_weights
class UniformSampler(ScheduleSampler):
    """Sample every timestep with equal probability.

    All weights are 1, so ScheduleSampler.sample() draws uniformly and every
    returned loss weight is exactly 1.
    """

    def __init__(self, diffusion):
        self.diffusion = diffusion
        num_steps = diffusion.num_timesteps
        self._weights = np.ones(num_steps)

    def weights(self):
        return self._weights
class LossAwareSampler(ScheduleSampler):
    """A ScheduleSampler whose weights are updated from observed training losses."""

    def update_with_local_losses(self, local_ts, local_losses):
        """
        Update the reweighting using losses from a model.

        Call this method from each rank with a batch of timesteps and the
        corresponding losses for each of those timesteps.
        This method will perform synchronization to make sure all of the ranks
        maintain the exact same reweighting.

        :param local_ts: an integer Tensor of timesteps.
        :param local_losses: a 1D Tensor of losses.
        """
        # Exchange per-rank batch sizes so every rank knows how much of each
        # gathered buffer is real data.
        batch_sizes = [
            th.tensor([0], dtype=th.int32, device=local_ts.device)
            for _ in range(dist.get_world_size())
        ]
        dist.all_gather(
            batch_sizes,
            th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
        )

        # Pad all_gather batches to be the maximum batch size.
        batch_sizes = [x.item() for x in batch_sizes]
        max_bs = max(batch_sizes)
        local_bs = len(local_ts)
        # all_gather requires the input tensor to match each output buffer's
        # shape, so the local data is padded to max_bs too. The padding is
        # discarded below via batch_sizes.
        padded_ts = th.zeros(max_bs).to(local_ts)
        padded_ts[:local_bs] = local_ts
        padded_losses = th.zeros(max_bs).to(local_losses)
        padded_losses[:local_bs] = local_losses

        timestep_batches = [th.zeros(max_bs).to(local_ts) for _ in batch_sizes]
        loss_batches = [th.zeros(max_bs).to(local_losses) for _ in batch_sizes]
        dist.all_gather(timestep_batches, padded_ts)
        dist.all_gather(loss_batches, padded_losses)
        timesteps = [
            x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
        ]
        losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
        self.update_with_all_losses(timesteps, losses)

    @abstractmethod
    def update_with_all_losses(self, ts, losses):
        """
        Update the reweighting using losses from a model.

        Sub-classes should override this method to update the reweighting
        using losses from the model.

        This method directly updates the reweighting without synchronizing
        between workers. It is called by update_with_local_losses from all
        ranks with identical arguments. Thus, it should have deterministic
        behavior to maintain state across workers.

        :param ts: a list of int timesteps.
        :param losses: a list of float losses, one per timestep.
        """
class LossSecondMomentResampler(LossAwareSampler):
    """Importance sampler weighting each timestep by the RMS of its recent losses.

    Until every timestep has ``history_per_term`` recorded losses, sampling is
    uniform. After that, p_t is proportional to sqrt(mean(L_t^2)), mixed with
    a ``uniform_prob`` share of the uniform distribution.
    """

    def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
        self.diffusion = diffusion
        # The paper keeps the 10 most recent values of each loss term, i.e.
        # per timestep t.
        self.history_per_term = history_per_term
        self.uniform_prob = uniform_prob
        # Row t holds the latest history_per_term losses for timestep t.
        self._loss_history = np.zeros(
            [diffusion.num_timesteps, history_per_term], dtype=np.float64
        )
        # Number of filled slots per row. `np.int` was deprecated in NumPy 1.20
        # and removed in 1.24; the builtin `int` is its exact replacement.
        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=int)

    def weights(self):
        if not self._warmed_up():
            # History not yet full: sample uniformly.
            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
        # RMS of each timestep's loss history, normalised to sum to one.
        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
        weights /= np.sum(weights)
        # Blend in uniform mass so each timestep keeps at least uniform_prob / T.
        weights *= 1 - self.uniform_prob
        weights += self.uniform_prob / len(weights)
        return weights

    def update_with_all_losses(self, ts, losses):
        # Mutates _loss_history, which changes what weights() returns and
        # therefore what sample() draws.
        for t, loss in zip(ts, losses):
            if self._loss_counts[t] == self.history_per_term:
                # Row full: shift out the oldest loss term, append the new one.
                self._loss_history[t, :-1] = self._loss_history[t, 1:]
                self._loss_history[t, -1] = loss
            else:
                # Row still filling: write into the next free slot.
                self._loss_history[t, self._loss_counts[t]] = loss
                self._loss_counts[t] += 1

    def _warmed_up(self):
        """True once every timestep's loss history is full."""
        return (self._loss_counts == self.history_per_term).all()
本文件主要为多卡分布式训练定义一些辅助函数
"""
Helpers for distributed training.
"""
import io
import os
import socket
import blobfile as bf
from mpi4py import MPI
import torch as th
import torch.distributed as dist
# Change this to reflect your cluster layout.
# The GPU for a given rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE = 8
# Retry budget for distributed setup.
# NOTE(review): not referenced by setup_dist() as shown — it makes one attempt.
SETUP_RETRY_COUNT = 3
def setup_dist():
    """
    Setup a distributed process group.

    MPI supplies the rank, the world size and a shared rendezvous address/port
    chosen by rank 0. torch.distributed is then initialised through the
    ``env://`` method. Does nothing if a process group already exists.
    """
    if dist.is_initialized():
        return

    comm = MPI.COMM_WORLD
    # NCCL when CUDA is available, gloo on CPU.
    backend = "gloo" if not th.cuda.is_available() else "nccl"

    # gloo runs use localhost (presumably single-host CPU runs); NCCL runs
    # resolve this host's address so other ranks can reach rank 0.
    if backend == "gloo":
        hostname = "localhost"
    else:
        hostname = socket.gethostbyname(socket.getfqdn())
    # comm.bcast is a collective: every rank adopts rank 0's value, and the
    # two bcast calls must run in the same order on all ranks.
    os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
    os.environ["RANK"] = str(comm.rank)
    os.environ["WORLD_SIZE"] = str(comm.size)

    # Each rank probes a free port, but only rank 0's choice is used.
    port = comm.bcast(_find_free_port(), root=0)
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group(backend=backend, init_method="env://")
def dev():
    """
    Get the device to use for torch.distributed.

    CPU when CUDA is unavailable; otherwise GPU ``rank % GPUS_PER_NODE``,
    where rank is this process's MPI rank.
    """
    if not th.cuda.is_available():
        return th.device("cpu")
    local_gpu = MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE
    return th.device(f"cuda:{local_gpu}")
def load_state_dict(path, **kwargs):
    """
    Load a PyTorch file without redundant fetches across MPI ranks.

    Only rank 0 reads the raw bytes from *path*. They are broadcast over MPI,
    and every rank deserialises them with ``th.load(..., **kwargs)``.
    The broadcast is an MPI collective, so every rank is expected to call
    this function. ``th.load`` unpickles its input, so only load trusted files.
    """
    comm = MPI.COMM_WORLD
    payload = None
    if comm.Get_rank() == 0:
        with bf.BlobFile(path, "rb") as f:
            payload = f.read()
    payload = comm.bcast(payload)
    return th.load(io.BytesIO(payload), **kwargs)
def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.

    Each tensor is overwritten in place with rank 0's value via
    ``dist.broadcast``; autograd is disabled while doing so.
    """
    with th.no_grad():
        for tensor in params:
            dist.broadcast(tensor, 0)
def _find_free_port():
try:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
finally:
s.close()
本笔记主要记录IDDPM官方仓库中训练部分的相关代码,其中包含IDDPM的一个改进点,即基于损失的时间步重要性重采样。这部分代码虽然不像模型构造部分那样涉及大量公式,但最好仍与论文对照学习,读者可参考笔记《IDDPM论文阅读》辅助理解。读者若发现问题或错误,请评论指出,互相学习。