class ContextUnet(nn.Module):
def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28): # cfeat - context features
super(ContextUnet, self).__init__()
# number of input channels, number of intermediate feature maps and number of classes
# 输入通道数
self.in_channels = in_channels
# 映射特征数量
self.n_feat = n_feat
# 生成类别数
self.n_cfeat = n_cfeat
# 生成的是方形图,并且输入必须能够被4整除
self.h = height #assume h == w. must be divisible by 4, so 28,24,20,16...
# Initialize the initial convolutional layer
self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)
# 初始化下采样层
self.down1 = UnetDown(n_feat, n_feat) # down1 #[10, 256, 8, 8]
self.down2 = UnetDown(n_feat, 2 * n_feat) # down2 #[10, 256, 4, 4]
# original: self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())
# 仅仅进行平均池化,并没有改变他的通道数
self.to_vec = nn.Sequential(nn.AvgPool2d((4)), nn.GELU())
# Embed the timestep and context labels with a one-layer fully connected neural network
# 定义两个嵌入层,将时间戳信息和上下文消息都转为对应的embedding向量
# 这里仅仅是改变通道数,并没有改变上下文信息的特征
self.timeembed1 = EmbedFC(1, 2*n_feat)
self.timeembed2 = EmbedFC(1, 1*n_feat)
self.contextembed1 = EmbedFC(n_cfeat, 2*n_feat)
self.contextembed2 = EmbedFC(n_cfeat, 1*n_feat)
# Initialize the up-sampling path of the U-Net with three levels
# 并不改变通道数,仅仅是进行上采样
self.up0 = nn.Sequential(
nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h//4, self.h//4), # up-sample
nn.GroupNorm(8, 2 * n_feat), # normalize
nn.ReLU(),
)
# 降低通道数,并进行上采样,同下
self.up1 = UnetUp(4 * n_feat, n_feat)
# 降低通道数,并进行上采样,这里输入通道和up1的输出通道不同,是因为还有上下文信息和之前下采样的输出
self.up2 = UnetUp(2 * n_feat, n_feat)
# 初始化最终的卷积层,将最终的输出映射为和输入相同大小
self.out = nn.Sequential(
nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1), # reduce number of feature maps #in_channels, out_channels, kernel_size, stride=1, padding=0
nn.GroupNorm(8, n_feat), # normalize
nn.ReLU(),
nn.Conv2d(n_feat, self.in_channels, 3, 1, 1), # map to same number of channels as input
)
网络结构的每一层参数如下
# 初始化卷积层
init_conv.conv1.0.weight torch.Size([256, 3, 3, 3])
init_conv.conv1.0.bias torch.Size([256])
init_conv.conv1.1.weight torch.Size([256])
init_conv.conv1.1.bias torch.Size([256])
init_conv.conv2.0.weight torch.Size([256, 256, 3, 3])
init_conv.conv2.0.bias torch.Size([256])
init_conv.conv2.1.weight torch.Size([256])
init_conv.conv2.1.bias torch.Size([256])
# 下采样层一
down1.model.0.conv1.0.weight torch.Size([256, 256, 3, 3])
down1.model.0.conv1.0.bias torch.Size([256])
down1.model.0.conv1.1.weight torch.Size([256])
down1.model.0.conv1.1.bias torch.Size([256])
down1.model.0.conv2.0.weight torch.Size([256, 256, 3, 3])
down1.model.0.conv2.0.bias torch.Size([256])
down1.model.0.conv2.1.weight torch.Size([256])
down1.model.0.conv2.1.bias torch.Size([256])
down1.model.1.conv1.0.weight torch.Size([256, 256, 3, 3])
down1.model.1.conv1.0.bias torch.Size([256])
down1.model.1.conv1.1.weight torch.Size([256])
down1.model.1.conv1.1.bias torch.Size([256])
down1.model.1.conv2.0.weight torch.Size([256, 256, 3, 3])
down1.model.1.conv2.0.bias torch.Size([256])
down1.model.1.conv2.1.weight torch.Size([256])
down1.model.1.conv2.1.bias torch.Size([256])
# 下采样层二
down2.model.0.conv1.0.weight torch.Size([512, 256, 3, 3])
down2.model.0.conv1.0.bias torch.Size([512])
down2.model.0.conv1.1.weight torch.Size([512])
down2.model.0.conv1.1.bias torch.Size([512])
down2.model.0.conv2.0.weight torch.Size([512, 512, 3, 3])
down2.model.0.conv2.0.bias torch.Size([512])
down2.model.0.conv2.1.weight torch.Size([512])
down2.model.0.conv2.1.bias torch.Size([512])
down2.model.1.conv1.0.weight torch.Size([512, 512, 3, 3])
down2.model.1.conv1.0.bias torch.Size([512])
down2.model.1.conv1.1.weight torch.Size([512])
down2.model.1.conv1.1.bias torch.Size([512])
down2.model.1.conv2.0.weight torch.Size([512, 512, 3, 3])
down2.model.1.conv2.0.bias torch.Size([512])
down2.model.1.conv2.1.weight torch.Size([512])
down2.model.1.conv2.1.bias torch.Size([512])
# 时间上下文信息embedding
timeembed1.model.0.weight torch.Size([512, 1])
timeembed1.model.0.bias torch.Size([512])
timeembed1.model.2.weight torch.Size([512, 512])
timeembed1.model.2.bias torch.Size([512])
timeembed2.model.0.weight torch.Size([256, 1])
timeembed2.model.0.bias torch.Size([256])
timeembed2.model.2.weight torch.Size([256, 256])
timeembed2.model.2.bias torch.Size([256])
# 上下文信息的embedding
contextembed1.model.0.weight torch.Size([512, 10])
contextembed1.model.0.bias torch.Size([512])
contextembed1.model.2.weight torch.Size([512, 512])
contextembed1.model.2.bias torch.Size([512])
contextembed2.model.0.weight torch.Size([256, 10])
contextembed2.model.0.bias torch.Size([256])
contextembed2.model.2.weight torch.Size([256, 256])
contextembed2.model.2.bias torch.Size([256])
# 上采样零层,如果不用加上上下文信息,这层完全没有必要,现在是加上了。
up0.0.weight torch.Size([512, 512, 7, 7])
up0.0.bias torch.Size([512])
up0.1.weight torch.Size([512])
up0.1.bias torch.Size([512])
up1.model.0.weight torch.Size([1024, 256, 2, 2])
up1.model.0.bias torch.Size([256])
# 上采样一层
up1.model.1.conv1.0.weight torch.Size([256, 256, 3, 3])
up1.model.1.conv1.0.bias torch.Size([256])
up1.model.1.conv1.1.weight torch.Size([256])
up1.model.1.conv1.1.bias torch.Size([256])
up1.model.1.conv2.0.weight torch.Size([256, 256, 3, 3])
up1.model.1.conv2.0.bias torch.Size([256])
up1.model.1.conv2.1.weight torch.Size([256])
up1.model.1.conv2.1.bias torch.Size([256])
up1.model.2.conv1.0.weight torch.Size([256, 256, 3, 3])
up1.model.2.conv1.0.bias torch.Size([256])
up1.model.2.conv1.1.weight torch.Size([256])
up1.model.2.conv1.1.bias torch.Size([256])
up1.model.2.conv2.0.weight torch.Size([256, 256, 3, 3])
up1.model.2.conv2.0.bias torch.Size([256])
up1.model.2.conv2.1.weight torch.Size([256])
up1.model.2.conv2.1.bias torch.Size([256])
# 上采样二层
up2.model.0.weight torch.Size([512, 256, 2, 2])
up2.model.0.bias torch.Size([256])
up2.model.1.conv1.0.weight torch.Size([256, 256, 3, 3])
up2.model.1.conv1.0.bias torch.Size([256])
up2.model.1.conv1.1.weight torch.Size([256])
up2.model.1.conv1.1.bias torch.Size([256])
up2.model.1.conv2.0.weight torch.Size([256, 256, 3, 3])
up2.model.1.conv2.0.bias torch.Size([256])
up2.model.1.conv2.1.weight torch.Size([256])
up2.model.1.conv2.1.bias torch.Size([256])
up2.model.2.conv1.0.weight torch.Size([256, 256, 3, 3])
up2.model.2.conv1.0.bias torch.Size([256])
up2.model.2.conv1.1.weight torch.Size([256])
up2.model.2.conv1.1.bias torch.Size([256])
up2.model.2.conv2.0.weight torch.Size([256, 256, 3, 3])
up2.model.2.conv2.0.bias torch.Size([256])
up2.model.2.conv2.1.weight torch.Size([256])
up2.model.2.conv2.1.bias torch.Size([256])
# 最终的输出层,将输出的通道进行调整为3
out.0.weight torch.Size([256, 512, 3, 3])
out.0.bias torch.Size([256])
out.1.weight torch.Size([256])
out.1.bias torch.Size([256])
out.3.weight torch.Size([3, 256, 3, 3])
out.3.bias torch.Size([3])
当前网络每一层输出的张量情况
# 输入的图片为[32,3,28,28]=[batch_size,channel,height,width]
# 提取特征,扩充通道数
Layer: ResidualConvBlock
Input shape: torch.Size([32, 3, 28, 28])
Output shape: torch.Size([32, 64, 28, 28])
==============================
# 下采样层一:尺寸减半,通道数不变
Layer: UnetDown
Input shape: torch.Size([32, 64, 28, 28])
Output shape: torch.Size([32, 64, 14, 14])
==============================
# 下采样层二:尺寸减半,通道数翻倍
Layer: UnetDown
Input shape: torch.Size([32, 64, 14, 14])
Output shape: torch.Size([32, 128, 7, 7])
==============================
# 还是对输入的特征图进行下采样,是4*4的方格进行下采样
Layer: Sequential
Input shape: torch.Size([32, 128, 7, 7])
Output shape: torch.Size([32, 128, 1, 1])
==============================
# 下述四层为上下文信息处理层,分别处理上下文类别信息和时间序列信息,分层加入到模型中
# 下述为特征上下文信息,每一个样本都有自己的特征上下文
Layer: EmbedFC
Input shape: torch.Size([32, 5])
Output shape: torch.Size([32, 128])
==============================
# 下述为时间序列上下文,所有样本的时间序列是统一的
Layer: EmbedFC
Input shape: torch.Size([1, 1, 1, 1])
Output shape: torch.Size([1, 128])
==============================
# 下述为经过扩展的样本上下文,用于加到第二个上采样层
Layer: EmbedFC
Input shape: torch.Size([32, 5])
Output shape: torch.Size([32, 64])
==============================
# 下述为经过扩展的时间序列信息,用于加到第二个上采样层
Layer: EmbedFC
Input shape: torch.Size([1, 1, 1, 1])
Output shape: torch.Size([1, 64])
==============================
# 上采样层零:扩展维度,对应两个下采样层下的第一个卷积层
Layer: Sequential
Input shape: torch.Size([32, 128, 1, 1])
Output shape: torch.Size([32, 128, 7, 7])
==============================
# 上采样层一
Layer: UnetUp
Input shape: torch.Size([32, 128, 7, 7])
Output shape: torch.Size([32, 64, 14, 14])
==============================
# 上采样层二
Layer: UnetUp
Input shape: torch.Size([32, 64, 14, 14])
Output shape: torch.Size([32, 64, 28, 28])
==============================
# 输出调整层,将输出的信道调整为原始图层
Layer: Sequential
Input shape: torch.Size([32, 128, 28, 28])
Output shape: torch.Size([32, 3, 28, 28])
==============================
网络各层的连接方式
def forward(self, x, t, c=None):
"""
x : (batch, n_feat, h, w) : input image
t : (batch, n_cfeat) : time step
c : (batch, n_classes) : context label
"""
# x is the input image, c is the context label, t is the timestep, context_mask says which samples to block the context on
'''下采样过程'''
# 将输入的图片传入初始化卷积层中
x = self.init_conv(x)
# 将结果传入下采样层
down1 = self.down1(x) #[10, 256, 8, 8]
down2 = self.down2(down1) #[10, 256, 4, 4]
# 将特征映射为向量
hiddenvec = self.to_vec(down2)
'''上采样过程'''
# mask out context if context_mask == 1
# 判定是否有上下文信息
if c is None:
c = torch.zeros(x.shape[0], self.n_cfeat).to(x)
# 将上下文信息context information还有timestep转为embedding
cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1) # (batch, 2*n_feat, 1,1)
temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)
#print(f"uunet forward: cemb1 {cemb1.shape}. temb1 {temb1.shape}, cemb2 {cemb2.shape}. temb2 {temb2.shape}")
# 上采样过程,分别和对应下采样对应层和对应上下文信息加入到每一个上采样层中
up1 = self.up0(hiddenvec)
up2 = self.up1(cemb1*up1 + temb1, down2) # add and multiply embeddings
up3 = self.up2(cemb2*up2 + temb2, down1)
out = self.out(torch.cat((up3, x), 1))
return out
# helper function: perturbs an image to a specified noise level
def perturb_input(x, t, noise):
# 前向传播公示
return ab_t.sqrt()[t, None, None, None] * x + (1 - ab_t[t, None, None, None]) * noise
下述为具体的训练代码
# training without context code
# set into train mode
nn_model.train()
for ep in range(n_epoch):
print(f'epoch {ep}')
# linearly decay learning rate
# 定义学习率进行线性衰减
optim.param_groups[0]['lr'] = lrate*(1-ep/n_epoch)
# 加载进度条
pbar = tqdm(dataloader, mininterval=2 )
for x, _ in pbar: # x: images
optim.zero_grad()
x = x.to(device)
# perturb data
# 给当前的图片增加噪声
noise = torch.randn_like(x) # 随机生成噪声
t = torch.randint(1, timesteps + 1, (x.shape[0],)).to(device) # 随机生成timestep
x_pert = perturb_input(x, t, noise) # 增加噪声扰动
# use network to recover noise
# 使用网络去预测噪声
pred_noise = nn_model(x_pert, t / timesteps)
# loss is mean squared error between the predicted and true noise
# 使用MSE计算损失
loss = F.mse_loss(pred_noise, noise)
loss.backward()
optim.step()
# save model periodically
# 按照周期保存模型
if ep%4==0 or ep == int(n_epoch-1):
if not os.path.exists(save_dir):
os.mkdir(save_dir)
torch.save(nn_model.state_dict(), save_dir + f"model_{ep}.pth")
print('saved model at ' + save_dir + f"model_{ep}.pth")
# construct DDPM noise schedule
# 构建DDPM的计算模式
# 定义 \beta_t ,表示从零到一的若干均匀分布的小数,有几个时间步骤,就有几个
b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1
# 计算\alpha_t 得值
a_t = 1 - b_t
# 这里是通过取对数,然后再去指数,来避免小数连乘的溢出。
ab_t = torch.cumsum(a_t.log(), dim=0).exp()
# 确保x_0的连续性
ab_t[0] = 1
# helper function; removes the predicted noise (but adds some noise back in to avoid collapse)
# 祛除模型预测的噪声,并且添加一些额外的噪声,避免过拟合
def denoise_add_noise(x, t, pred_noise, z=None):
# 重参数化,实现对特定复杂分布的采样,z是从高斯分布进行的正常采样
if z is None:
z = torch.randn_like(x)
noise = b_t.sqrt()[t] * z
# 公式的前半项,x是当前timestep的情况,这里完全是按照公式进行推倒的
mean = ((x - pred_noise * ((1 - a_t[t]) / (1 - ab_t[t]).sqrt())) # 减去预测噪声
/ a_t[t].sqrt())
# 增加额外的噪声,防止过拟合
return mean + noise
上述方法完全是按照对应的公示进行展开的,看过了推导之后,发现对于整个公式的理解更加明确。
下述为整体的采样过程
# sample using standard algorithm
@torch.no_grad()
def sample_ddpm(n_sample, save_rate=20):
# x_T ~ N(0, 1), sample initial noise
samples = torch.randn(n_sample, 3, height, height).to(device)
# array to keep track of generated steps for plotting
intermediate = []
for i in range(timesteps, 0, -1):
print(f'sampling timestep {i:3d}', end='\r')
# reshape time tensor
t = torch.tensor([i / timesteps])[:, None, None, None].to(device)
# sample some random noise to inject back in. For i = 1, don't add back in noise
z = torch.randn_like(samples) if i > 1 else 0
eps = nn_model(samples, t) # predict noise e_(x_t,t)
samples = denoise_add_noise(samples, i, eps, z)
if i % save_rate ==0 or i==timesteps or i<8:
intermediate.append(samples.detach().cpu().numpy())
intermediate = np.stack(intermediate)
return samples, intermediate
# sample with context using standard algorithm
@torch.no_grad()
def sample_ddpm_context(n_sample, context, save_rate=20):
# x_T ~ N(0, 1), sample initial noise
samples = torch.randn(n_sample, 3, height, height).to(device)
# array to keep track of generated steps for plotting
intermediate = []
for i in range(timesteps, 0, -1):
print(f'sampling timestep {i:3d}', end='\r')
# reshape time tensor
t = torch.tensor([i / timesteps])[:, None, None, None].to(device)
# sample some random noise to inject back in. For i = 1, don't add back in noise
z = torch.randn_like(samples) if i > 1 else 0
# 和之前一样,就是增加了对应的上下文信息
eps = nn_model(samples, t, c=context) # predict noise e_(x_t,t, ctx)
samples = denoise_add_noise(samples, i, eps, z)
if i % save_rate==0 or i==timesteps or i<8:
intermediate.append(samples.detach().cpu().numpy())
intermediate = np.stack(intermediate)
return samples, intermediate
# construct DDPM noise schedule
b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1
a_t = 1 - b_t
ab_t = torch.cumsum(a_t.log(), dim=0).exp()
ab_t[0] = 1
# 下述为根据采样公式写出的采样函数
# t是当前的状态数量
# t-prev是根据当前状态t,需要预测prev向前的内容
def denoise_ddim(x, t, t_prev, pred_noise):
ab = ab_t[t]
ab_prev = ab_t[t_prev]
x0_pred = ab_prev.sqrt() / ab.sqrt() * (x - (1 - ab).sqrt() * pred_noise)
dir_xt = (1 - ab_prev).sqrt() * pred_noise
return x0_pred + dir_xt
# 具体调用采样过程
# sample quickly using DDIM
@torch.no_grad()
def sample_ddim(n_sample, n=20):
# x_T ~ N(0, 1), sample initial noise
samples = torch.randn(n_sample, 3, height, height).to(device)
# array to keep track of generated steps for plotting
intermediate = []
step_size = timesteps // n
for i in range(timesteps, 0, -step_size):
print(f'sampling timestep {i:3d}', end='\r')
# reshape time tensor
t = torch.tensor([i / timesteps])[:, None, None, None].to(device)
eps = nn_model(samples, t) # predict noise e_(x_t,t)
samples = denoise_ddim(samples, i, i - step_size, eps)
intermediate.append(samples.detach().cpu().numpy())
intermediate = np.stack(intermediate)
return samples, intermediate