import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from datasets import load_dataset
from diffusers import DDIMScheduler, DDPMPipeline
from matplotlib import pyplot as plt
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
# Pick the best available accelerator: Apple MPS first, then CUDA, else CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
加载预训练管线
# Load the pretrained unconditional DDPM pipeline (its UNet and scheduler).
image_pipe = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256")
image_pipe.to(device);

# The guided-sampling loops later in this notebook refer to a `scheduler`
# object; define it here. Building the DDIM sampler from the pipeline's own
# scheduler config keeps its noise schedule consistent with the pretrained UNet.
scheduler = DDIMScheduler.from_config(image_pipe.scheduler.config)
scheduler.set_timesteps(num_inference_steps=40)
加载数据集
# @markdown load and prepare a dataset:
# Not on Colab? Comments with #@ enable UI tweaks like headings or user inputs
# but can safely be ignored if you're working on a different platform.
dataset_name = "huggan/smithsonian_butterflies_subset"  # @param
image_size = 256  # @param
batch_size = 4  # @param

dataset = load_dataset(dataset_name, split="train")

# Per-image training transforms: square resize, random mirror, convert to a
# float tensor, then Normalize(0.5, 0.5) maps [0, 1] pixel values to [-1, 1].
_train_tfm_steps = [
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
]
preprocess = transforms.Compose(_train_tfm_steps)
def transform(examples):
    """Turn a batch of dataset rows into training tensors.

    Each entry of ``examples["image"]`` is converted to RGB and then passed
    through the module-level ``preprocess`` pipeline.

    Returns:
        A dict with a single key, "images", holding the list of tensors.
    """
    rgb_images = (img.convert("RGB") for img in examples["image"])
    return {"images": [preprocess(img) for img in rgb_images]}
# Register `transform` as the dataset's transform, then batch and shuffle it.
dataset.set_transform(transform)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Sanity check: show one batch, mapping pixels from [-1, 1] back to [0, 1].
print("Previewing batch:")
batch = next(iter(train_dataloader))
grid = torchvision.utils.make_grid(batch["images"], nrow=4)
_preview = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
plt.imshow(_preview);
开始训练
num_epochs = 2  # @param
lr = 1e-5  # @param
grad_accumulation_steps = 2  # @param

optimizer = torch.optim.AdamW(image_pipe.unet.parameters(), lr=lr)
# Scheduler hyperparameters live on `.config` in diffusers.
num_train_timesteps = image_pipe.scheduler.config.num_train_timesteps

losses = []

for epoch in range(num_epochs):
    for step, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
        clean_images = batch["images"].to(device)
        # Sample noise to add to the images
        noise = torch.randn_like(clean_images)
        bs = clean_images.shape[0]

        # Sample a random timestep for each image
        timesteps = torch.randint(
            0,
            num_train_timesteps,
            (bs,),
            device=clean_images.device,
        ).long()

        # Add noise to the clean images according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_images = image_pipe.scheduler.add_noise(clean_images, noise, timesteps)

        # Get the model prediction for the noise
        noise_pred = image_pipe.unet(noisy_images, timesteps, return_dict=False)[0]

        # Compare the prediction with the actual noise:
        # NB - trying to predict noise (eps) not (noisy_ims-clean_ims) or just (clean_ims)
        loss = F.mse_loss(noise_pred, noise)

        # Store the unscaled loss for later plotting
        losses.append(loss.item())

        # Backprop. Call backward() without a `gradient` argument: passing the
        # loss there would multiply every gradient by the loss value. Dividing
        # by grad_accumulation_steps makes the accumulated gradient an average.
        (loss / grad_accumulation_steps).backward()

        # Gradient accumulation: only update every grad_accumulation_steps batches
        if (step + 1) % grad_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

    print(
        f"Epoch {epoch} average loss: {sum(losses[-len(train_dataloader):])/len(train_dataloader)}"
    )

# Plot the loss curve:
plt.plot(losses)
如果我们想对生成的样本施加一些控制,该怎么做呢?例如,我们希望生成的图片偏向某种颜色。为此,这里要介绍引导(guidance):它可以在采样过程中施加额外的控制。
第一步,我们先创建一个函数,定义我们希望优化的一个指标(损失值)。这里是一个让生成的图片趋向于某种颜色的例子,它将图片像素值和目标颜色(这里用的是一种浅蓝绿色)对比,返回平均的误差:
def color_loss(images, target_color=(0.1, 0.9, 0.5)):
    """Return how far, on average, the pixels of ``images`` are from a colour.

    Args:
        images: batch laid out as (b, c, h, w) with pixel values on the
            model's [-1, 1] scale.
        target_color: (R, G, B) with components in [0, 1]. Defaults to a
            light teal, (0.1, 0.9, 0.5).

    Returns:
        A scalar tensor: the mean absolute difference between every pixel
        value and the matching channel of the target colour.
    """
    # [0, 1] colour -> [-1, 1], placed on the same device as the images.
    rgb = torch.tensor(target_color).to(images.device) * 2 - 1
    # Reshape to (1, c, 1, 1) so the colour broadcasts over batch and space.
    rgb = rgb.view(1, -1, 1, 1)
    diff = images - rgb
    return diff.abs().mean()
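接下来,我们要修改采样循环。在每一步中,我们要做以下几件事:
- 创建一个新的、设置了 requires_grad = True 的 x;
- 计算去噪后的预测图像 x0;
- 把预测的 x0 送入损失函数;
- 计算该损失相对于 x 的梯度;
- 在调用调度器进行下一步之前,先用这个梯度修改 x,使 x 朝着让引导损失更低的方向移动。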
实现方法有两种。第一种方法是在从 UNet 得到噪声预测之后才给 x 设置 requires_grad。这样更节省内存(因为梯度不需要穿过扩散模型反向传播),但得到的梯度精度会低一些。
# Variant 1: shortcut method
# The guidance scale determines the strength of the effect
guidance_loss_scale = 40  # Explore changing this to 5, or 100
x = torch.randn(8, 3, 256, 256).to(device)

for i, t in tqdm(enumerate(scheduler.timesteps)):
    model_input = scheduler.scale_model_input(x, t)
    # The UNet runs without autograd -- this is what makes the variant cheap.
    with torch.no_grad():
        noise_pred = image_pipe.unet(model_input, t)["sample"]

    # Track gradients only from here on: x -> predicted x0 -> colour loss.
    x = x.detach().requires_grad_()
    x0 = scheduler.step(noise_pred, t, x).pred_original_sample
    loss = color_loss(x0) * guidance_loss_scale
    if i % 10 == 0:
        print(i, "loss:", loss.item())

    # Nudge x down the loss gradient, then take the ordinary scheduler step.
    (grad_x,) = torch.autograd.grad(loss, x)
    cond_grad = -grad_x
    x = x.detach() + cond_grad
    x = scheduler.step(noise_pred, t, x).prev_sample

# View the output
grid = torchvision.utils.make_grid(x, nrow=4)
im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
Image.fromarray(np.array(im * 255).astype(np.uint8))
第二种方法是先给 x 设置 requires_grad,再把它送入 UNet 并计算预测的 x0。这样梯度会穿过 UNet 回传,结果更准确,但占用的内存也更多。
# Variant 2: setting x.requires_grad before calculating the model predictions
guidance_loss_scale = 40
x = torch.randn(4, 3, 256, 256).to(device)

for i, t in tqdm(enumerate(scheduler.timesteps)):
    # Set requires_grad before the model forward pass so the guidance gradient
    # is propagated back through the UNet as well.
    x = x.detach().requires_grad_()
    model_input = scheduler.scale_model_input(x, t)

    # predict (with grad this time)
    noise_pred = image_pipe.unet(model_input, t)["sample"]

    # Get the predicted x0:
    x0 = scheduler.step(noise_pred, t, x).pred_original_sample

    # Calculate loss
    loss = color_loss(x0) * guidance_loss_scale
    if i % 10 == 0:
        print(i, "loss:", loss.item())

    # Get gradient
    cond_grad = -torch.autograd.grad(loss, x)[0]

    # Modify x based on this gradient
    x = x.detach() + cond_grad

    # Step with a detached prediction. Otherwise the new x carries this step's
    # autograd history (requires_grad=True), and np.array() below raises on it.
    x = scheduler.step(noise_pred.detach(), t, x).prev_sample

grid = torchvision.utils.make_grid(x.detach(), nrow=4)
im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
Image.fromarray(np.array(im * 255).astype(np.uint8))
CLIP 是一个由 OpenAI 开发的模型,它可以让我们拿图片和文字说明去作比较。这是个非常强大的功能,因为它让我们能量化一张图和一句提示语有多匹配。另外,由于这个过程是可微分的,我们可以使用它作为损失函数去引导我们的扩散模型。
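基本的方法是:
- 用 CLIP 对文字提示(prompt)进行编码,得到文本的嵌入向量;
- 在扩散模型采样的每一步中:
  - 对预测出的去噪图像做若干次随机变换(随机裁剪、仿射、翻转),得到多个版本,并对它们的梯度取平均,以获得更稳定的引导信号;
  - 对每个版本,用 CLIP 编码图像,并用“大圆距离的平方”(Squared Great Circle Distance)与文本嵌入进行比较;
  - 计算该损失相对于当前带噪图像 x 的梯度,在用调度器更新 x 之前,先用这个梯度修改 x。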
# @markdown load a CLIP model and define the loss function
import open_clip

# The third return value is CLIP's own image preprocessing pipeline. Bind it to
# `clip_preprocess`, not `preprocess`: the dataset's `transform` function looks
# up the global name `preprocess` every time a batch is loaded, so overwriting
# it would silently switch the training data to CLIP's transforms.
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="openai"
)
clip_model.to(device)

# Transforms to resize and augment an image + normalize to match CLIP's training data
tfms = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomResizedCrop(224),  # Random CROP each time
        torchvision.transforms.RandomAffine(5),  # One possible random augmentation: skews the image
        torchvision.transforms.RandomHorizontalFlip(),  # You can add additional augmentations if you like
        torchvision.transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)
# Loss function: embed an image with CLIP and compare it with the text
# features of the prompt.
def clip_loss(image, text_features):
    """CLIP guidance loss between ``image`` and precomputed ``text_features``.

    ``image`` is first passed through the random ``tfms`` augmentations, so each
    call sees a different crop, and is then encoded with ``clip_model``.

    Returns:
        Scalar tensor: mean squared great-circle distance over all
        (image embedding, text embedding) pairs.
    """
    image_features = clip_model.encode_image(tfms(image))

    # Unit-normalise; the inserted singleton dims make the subtraction below
    # broadcast across every (image, text) pair.
    img_dirs = torch.nn.functional.normalize(image_features.unsqueeze(1), dim=2)
    txt_dirs = torch.nn.functional.normalize(text_features.unsqueeze(0), dim=2)

    # Chord length between unit vectors -> Squared Great Circle Distance.
    chord = (img_dirs - txt_dirs).norm(dim=2)
    dists = 2 * torch.arcsin(chord / 2) ** 2
    return dists.mean()
# @markdown applying guidance using CLIP
prompt = "Red Rose (still life), red flower painting"  # @param

# Explore changing this
guidance_scale = 8  # @param
n_cuts = 4  # @param

# More steps -> more time for the guidance to have an effect
scheduler.set_timesteps(50)

# We embed a prompt with CLIP as our target. Mixed precision is requested only
# on CUDA; on MPS/CPU a CUDA autocast context would be disabled anyway.
text = open_clip.tokenize([prompt]).to(device)
with torch.no_grad(), torch.autocast(device_type="cuda", enabled=(device == "cuda")):
    text_features = clip_model.encode_text(text)

x = torch.randn(4, 3, 256, 256).to(device)  # RAM usage is high, you may want only 1 image at a time

for i, t in tqdm(enumerate(scheduler.timesteps)):
    model_input = scheduler.scale_model_input(x, t)

    # predict the noise residual
    with torch.no_grad():
        noise_pred = image_pipe.unet(model_input, t)["sample"]

    cond_grad = 0
    for cut in range(n_cuts):
        # Set requires grad on x
        x = x.detach().requires_grad_()
        # Get the predicted x0:
        x0 = scheduler.step(noise_pred, t, x).pred_original_sample
        # Calculate loss (clip_loss applies a fresh random augmentation each call)
        loss = clip_loss(x0, text_features) * guidance_scale
        # Get gradient (scale by n_cuts since we want the average)
        cond_grad -= torch.autograd.grad(loss, x)[0] / n_cuts

    if i % 25 == 0:
        print("Step:", i, ", Guidance loss:", loss.item())

    # Modify x based on this gradient, scaled by sqrt(alpha_bar_t).
    # alphas_cumprod is indexed by the training timestep t, not by the loop
    # counter i. Since alphas_cumprod shrinks as t grows, guidance is damped on
    # the noisiest steps; guidance_scale may need retuning accordingly.
    alpha_bar = scheduler.alphas_cumprod[t]
    x = x.detach() + cond_grad * alpha_bar.sqrt()  # Note the additional scaling factor here!

    # Now step with scheduler
    x = scheduler.step(noise_pred, t, x).prev_sample

grid = torchvision.utils.make_grid(x.detach(), nrow=4)
im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
Image.fromarray(np.array(im * 255).astype(np.uint8))
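我们输入类别这一条件的方法是:
- 创建一个标准的 UNet2DModel,并给它增加额外的输入通道;
- 通过一个嵌入层(nn.Embedding)把类别标签映射为长度为 class_emb_size 的可学习向量;
- 把这个向量扩展到与图像相同的空间尺寸,作为额外通道与输入拼接:net_input = torch.cat((x, class_cond), 1);
- 把这个共有 (class_emb_size + 1) 个通道的 net_input 送入 UNet,得到最终的预测。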
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
# Choose the compute device, preferring Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')
加载数据集
# Load the MNIST training split as tensors in [0, 1]
dataset = torchvision.datasets.MNIST(
    root="mnist/",
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

# Feed it into a dataloader (batch size 8 here just for demo)
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# View some examples
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
preview_grid = torchvision.utils.make_grid(x)
plt.imshow(preview_grid[0], cmap='Greys');
修改 UNet2DModel 模型
class ClassConditionedUnet(nn.Module):
    """UNet2DModel conditioned on a class label through extra input channels.

    The label is embedded into a learned vector of size ``class_emb_size``,
    broadcast over every spatial position, and concatenated to the image along
    the channel axis before being fed to an otherwise unconditional UNet.
    """

    def __init__(self, num_classes=10, class_emb_size=4):
        super().__init__()

        # Learned lookup table: class label -> vector of size class_emb_size
        self.class_emb = nn.Embedding(num_classes, class_emb_size)

        # Unconditional UNet whose extra input channels receive the class embedding
        self.model = UNet2DModel(
            sample_size=28,  # the target image resolution
            in_channels=1 + class_emb_size,  # Additional input channels for class cond.
            out_channels=1,  # the number of output channels
            layers_per_block=2,  # how many ResNet layers to use per UNet block
            block_out_channels=(32, 64, 64),
            down_block_types=(
                "DownBlock2D",  # a regular ResNet downsampling block
                "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
                "AttnDownBlock2D",
            ),
            up_block_types=(
                "AttnUpBlock2D",
                "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
                "UpBlock2D",  # a regular ResNet upsampling block
            ),
        )

    def forward(self, x, t, class_labels):
        """Predict the noise in ``x`` at timestep(s) ``t`` for ``class_labels``.

        Args:
            x: image batch (bs, 1, height, width); 28x28 for this model.
            t: diffusion timestep(s), passed straight to the UNet.
            class_labels: integer labels, one per image in the batch.

        Returns:
            The UNet's ``.sample`` output, shape (bs, 1, height, width).
        """
        # torch image layout is (batch, channels, height, width)
        batch, _channels, height, width = x.shape

        # (batch, emb) -> (batch, emb, 1, 1) -> broadcast to (batch, emb, height, width)
        label_vec = self.class_emb(class_labels)
        emb_dim = label_vec.shape[1]
        label_map = label_vec.view(batch, emb_dim, 1, 1).expand(batch, emb_dim, height, width)

        # Image channel followed by the class-embedding channels: (batch, 1 + emb, height, width)
        net_input = torch.cat((x, label_map), 1)
        return self.model(net_input, t).sample
模型训练
# Create a scheduler
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')

#@markdown Training loop (10 Epochs):

# Redefining the dataloader to set the batch size higher than the demo of 8
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

# How many runs through the data should we do?
n_epochs = 10

# Our network
net = ClassConditionedUnet().to(device)

# Our loss function
loss_fn = nn.MSELoss()

# The optimizer
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

# Keeping a record of the losses for later viewing
losses = []

# The training loop
for epoch in range(n_epochs):
    for x, y in tqdm(train_dataloader):
        # Get some data and prepare the corrupted version
        x = x.to(device) * 2 - 1  # Data on the GPU (mapped to (-1, 1))
        y = y.to(device)
        noise = torch.randn_like(x)
        # torch.randint's upper bound is exclusive; using num_train_timesteps
        # lets every timestep 0..num_train_timesteps-1 be sampled.
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (x.shape[0],), device=device
        ).long()
        noisy_x = noise_scheduler.add_noise(x, noise, timesteps)

        # Get the model prediction
        pred = net(noisy_x, timesteps, y)  # Note that we pass in the labels y

        # Calculate the loss
        loss = loss_fn(pred, noise)  # How close is the output to the noise

        # Backprop and update the params:
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Store the loss for later
        losses.append(loss.item())

    # Print out the average of the last (up to) 100 loss values to get an idea of progress:
    recent_losses = losses[-100:]
    avg_loss = sum(recent_losses) / len(recent_losses)
    print(f'Finished epoch {epoch}. Average of the last 100 loss values: {avg_loss:05f}')

# View the loss curve
plt.plot(losses)