扩散模型近期在图像生成领域很火, 没想到很快就被用在了检测上. 打算对这篇论文做一个笔记.
论文地址: 论文
代码: 代码
首先介绍什么是扩散模型. 我们考虑生成任务, 即encoder-decoder形式的模型, encoder提取输入的抽象信息, 并尝试在decoder中恢复出来. 扩散模型就是这一类中的方法, 其灵感由热力学而来, 基本做法是在输入中逐步加噪, 并学会如何在噪声中恢复出输入. 在加噪和去噪的过程中都假设为Markov过程.
假定原始数据服从分布 x 0 ∼ q ( x ) \textbf{x}_0\sim q(\textbf{x}) x0∼q(x), 现在我们逐步对其加噪, 加入的是高斯噪声. 对于每一步加噪, 我们希望将分布 q q q逐渐向高斯过程靠近, 也即让 q ( x t ∣ x t − 1 ) = N q(\textbf{x}_t|\textbf{x}_{t-1})=\mathcal{N} q(xt∣xt−1)=N. 在每一步, 我们假定高斯分布的均值与过去的值 x t − 1 \textbf{x}_{t-1} xt−1有关, 而协方差为固定值(对角阵):
q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q(\textbf{x}_t|\textbf{x}_{t-1})=\mathcal{N}(\textbf{x}_t;\sqrt{1-\beta_t}\textbf{x}_{t-1}, \beta_t \textbf{I}) q(xt∣xt−1)=N(xt;1−βtxt−1,βtI)
其中 β t \beta_t βt为一小正数, 在0~1之间.
注意我们假定加噪过程为Markov过程, 因此当前状态 x t \textbf{x}_t xt只假定与上一状态 x t − 1 \textbf{x}_{t-1} xt−1有关.
因此, 当前时刻 x t \textbf{x}_t xt是由前一时刻 x t − 1 \textbf{x}_{t-1} xt−1决定的正态分布, 其均值为 1 − β t x t − 1 \sqrt{1-\beta_t}\textbf{x}_{t-1} 1−βtxt−1, 方差为 β t I \beta_t \textbf{I} βtI. 为了表示 x t \textbf{x}_t xt, 这里我们使用一下**重参数化(Re-parametrization)**技巧. 重参数化是说, 如果我们从一个高斯分布中取样, 也等效于从标准高斯分布中取样, 只不过是加上均值, 以及乘以标准差. 这是因为一个 x ∼ N ( μ , σ 2 ) x\sim\mathcal{N}(\mu, \sigma^2) x∼N(μ,σ2)的高斯分布可以等价于 μ + σ ϵ \mu+\sigma \epsilon μ+σϵ, 其中 ϵ ∼ N ( 0 , 1 ) \epsilon\sim\mathcal{N}(0,1) ϵ∼N(0,1).
由高斯分布性质立即得到.
因此, x t \textbf{x}_t xt可以表示为:
x t = 1 − β t x t − 1 + β t ϵ t − 1 \textbf{x}_t=\sqrt{1-\beta_t} \textbf{x}_{t-1}+\sqrt{\beta_t}\epsilon_{t-1} xt=1−βtxt−1+βtϵt−1
其中 ϵ t − 1 ∼ N ( 0 , I ) \epsilon_{t-1}\sim\mathcal{N}(0,I) ϵt−1∼N(0,I). 为了表示方便, 令 1 − β t = α t \sqrt{1-\beta_t}=\sqrt{\alpha_t} 1−βt=αt, 将上式递归展开, 我们有:
x t = α t ( α t − 1 x t − 2 + 1 − α t − 1 ϵ t − 2 ) + β t ϵ t − 1 \textbf{x}_t=\sqrt{\alpha_t}(\sqrt{\alpha_{t-1}}\textbf{x}_{t-2}+\sqrt{1-\alpha_{t-1}}\epsilon_{t-2} )+\sqrt{\beta_t}\epsilon_{t-1} xt=αt(αt−1xt−2+1−αt−1ϵt−2)+βtϵt−1
我们注意到后面两项可以合并为一个新的高斯分布, 其均值为0, 方差为 1 − α t α t − 1 1-\alpha_t\alpha_{t-1} 1−αtαt−1, 按此规律展开, 我们得到:
x t = Π i α i x 0 + 1 − Π i α i ϵ , ϵ ∼ N ( 0 , I ) \textbf{x}_t=\sqrt{\Pi_i\alpha_i}\textbf{x}_0+\sqrt{1-\Pi_i\alpha_i}\epsilon, ~~\epsilon\sim \mathcal{N}(0,I) xt=Πiαix0+1−Πiαiϵ, ϵ∼N(0,I)
所以, 我们可以直接从 x 0 \textbf{x}_0 x0得到 x t \textbf{x}_t xt的分布:
q ( x t ∣ x 0 ) = N ( x t ; Π i α i x 0 , 1 − Π i α i I ) q(\textbf{x}_t|\textbf{x}_0)=\mathcal{N}(\textbf{x}_t;\sqrt{\Pi_i\alpha_i}\textbf{x}_0, \sqrt{1-\Pi_i\alpha_i} \textbf{I}) q(xt∣x0)=N(xt;Πiαix0,1−ΠiαiI)
所以, 随着时间的增加, x t \textbf{x}_t xt会越来越趋向于标准正态分布. 以上就是加噪的过程.
我们假定, 在加噪的正向过程中最后的结果已经近似为标准高斯分布 x T ∼ N ( 0 , I ) \textbf{x}_T \sim \mathcal{N}(0,\textbf{I}) xT∼N(0,I). 我们现在希望从加噪后的高斯分布中恢复出原来的信号, 即, 通过逐步计算 q ( x t − 1 ∣ x t ) q(\textbf{x}_{t-1}|\textbf{x}_t) q(xt−1∣xt)恢复. 然而, 如果这样计算的话, 需要从整个数据集中采样, 计算量非常大(有可能是因为类似于高斯混合模型的过程), 为此, 我们希望学习出一个模型 p θ p_\theta pθ来学习恢复过程中的条件概率:
p θ ( x t − 1 ∣ x t ) = N ( x t − 1 , μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_\theta(\textbf{x}_{t-1}|\textbf{x}_t)=\mathcal{N}(\textbf{x}_{t-1},\mu_\theta(\textbf{x}_t,t), \Sigma_\theta(\textbf{x}_t,t)) pθ(xt−1∣xt)=N(xt−1,μθ(xt,t),Σθ(xt,t))
我们需要做的是让分布 p ( x t − 1 ∣ x t ) p(\textbf{x}_{t-1}|\textbf{x}_t) p(xt−1∣xt)尽可能与 q ( x t − 1 ∣ x t ) q(\textbf{x}_{t-1}|\textbf{x}_t) q(xt−1∣xt)接近.
我们很难计算 q ( x t − 1 ∣ x t ) q(\textbf{x}_{t-1}|\textbf{x}_t) q(xt−1∣xt), 但可以考察以 x 0 为 条 件 的 以 下 概 率 \textbf{x}_0为条件的以下概率 x0为条件的以下概率:
q ( x t − 1 ∣ x t , x 0 ) q(\textbf{x}_{t-1}|\textbf{x}_t, \textbf{x}_0) q(xt−1∣xt,x0)
根据Bays公式, 有:
q ( x t − 1 ∣ x t , x 0 ) = q ( x t , x t − 1 , ∣ x 0 ) q ( x t ∣ x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) q(\textbf{x}_{t-1}|\textbf{x}_t, \textbf{x}_0)=\frac{q(\textbf{x}_t,\textbf{x}_{t-1},|\textbf{x}_0)}{q(\textbf{x}_t|\textbf{x}_0)}=q(\textbf{x}_t|\textbf{x}_{t-1}, \textbf{x}_0)\frac{q(\textbf{x}_{t-1}|\textbf{x}_0)}{q(\textbf{x}_t|\textbf{x}_0)} q(xt−1∣xt,x0)=q(xt∣x0)q(xt,xt−1,∣x0)=q(xt∣xt−1,x0)q(xt∣x0)q(xt−1∣x0)
扩散过程为Markov过程, 因此有:
q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) q(\textbf{x}_{t-1}|\textbf{x}_t, \textbf{x}_0)=q(\textbf{x}_t|\textbf{x}_{t-1})\frac{q(\textbf{x}_{t-1}|\textbf{x}_0)}{q(\textbf{x}_t|\textbf{x}_0)} q(xt−1∣xt,x0)=q(xt∣xt−1)q(xt∣x0)q(xt−1∣x0)
代入高斯分布表达式, 并凑出均值和方差(整理成 exp { 1 2 σ 2 ( x t − μ ) 2 } \exp\{\frac{1}{2\sigma^2}(x_t-\mu)^2\} exp{2σ21(xt−μ)2}的形式), 我们得到 q ( x t − 1 ∣ x t , x 0 ) q(\textbf{x}_{t-1}|\textbf{x}_t, \textbf{x}_0) q(xt−1∣xt,x0)的均值为:
μ = α t ( 1 − Π i T − 1 α i ) 1 − Π i T α i x t + Π i T − 1 α i ( 1 − α t ) 1 − Π i T α i x 0 \mu=\frac{\sqrt{\alpha_t}(1-\Pi_i^{T-1}\alpha_i)}{1-\Pi_i^{T}\alpha_i}\textbf{x}_{t}+\frac{\sqrt{\Pi_i^{T-1}\alpha_i}(1-\alpha_t)}{1-\Pi_i^{T}\alpha_i}\textbf{x}_{0} μ=1−ΠiTαiαt(1−ΠiT−1αi)xt+1−ΠiTαiΠiT−1αi(1−αt)x0
根据前面的重参数化技巧, 有 x t = Π i α i x 0 + 1 − Π i α i ϵ t , ϵ t \textbf{x}_t=\sqrt{\Pi_i\alpha_i}\textbf{x}_0+\sqrt{1-\Pi_i\alpha_i}\epsilon_t, ~~\epsilon_t xt=Πiαix0+1−Πiαiϵt, ϵt为网络在这一步预测的高斯噪声, 代入上式得到:
μ = 1 α t ( x t − 1 − α t 1 − Π i T α i ϵ ) \mu=\frac{1}{\sqrt{\alpha_t}}(\textbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\Pi_i^{T}\alpha_i}}\epsilon) μ=αt1(xt−1−ΠiTαi1−αtϵ)
方差为:
Σ = 1 − Π i T − 1 α i 1 − Π i T α i \Sigma=\frac{1-\Pi_i^{T-1}\alpha_i}{1-\Pi_i^{T}\alpha_i} Σ=1−ΠiTαi1−ΠiT−1αi
所以
q ( x t − 1 ∣ x t , x 0 ) ∼ N ( μ , Σ ) q(\textbf{x}_{t-1}|\textbf{x}_t, \textbf{x}_0) \sim \mathcal{N}(\mu, \Sigma) q(xt−1∣xt,x0)∼N(μ,Σ)
所以, 逆扩散的过程为: 根据网络从 x t \textbf{x}_t xt预测的噪声 ϵ t \epsilon_t ϵt计算出均值与方差, 进而得到 q ( x t − 1 ∣ x t , x 0 ) q(\textbf{x}_{t-1}|\textbf{x}_t, \textbf{x}_0) q(xt−1∣xt,x0), 作为 p θ ( x t − 1 ∣ x t ) p_\theta(\textbf{x}_{t-1}|\textbf{x}_t) pθ(xt−1∣xt)的近似, 如此得到 x t − 1 \textbf{x}_{t-1} xt−1, 再根据 x t − 1 \textbf{x}_{t-1} xt−1预测出下一步的噪声 ϵ t − 1 \epsilon_{t-1} ϵt−1, 如此往复, 如下图所示(图源知乎)
我们得到 q ( x t − 1 ∣ x t , x 0 ) q(\textbf{x}_{t-1}|\textbf{x}_t, \textbf{x}_0) q(xt−1∣xt,x0)的目的是: 让网络学习的 p p p与 q ( x t − 1 ∣ x t , x 0 ) q(\textbf{x}_{t-1}|\textbf{x}_t, \textbf{x}_0) q(xt−1∣xt,x0)尽量接近, 或说用 q ( x t − 1 ∣ x t , x 0 ) q(\textbf{x}_{t-1}|\textbf{x}_t, \textbf{x}_0) q(xt−1∣xt,x0)指导 p p p的训练(可认为最小化二者之间的散度).
具体损失函数再补充.
DiffusionDet的思想非常直接, 既然目标检测是要准确地定位边界框的位置, 那么利用Diffusion Model的强大噪声恢复(学习)能力就可以优化检测的结果. 整体框架如下:
上图中, Image Encoder(为ResNet或Swin Transformer)提取图像的特征, 然后Detection Decoder接受噪声化的边界框, 并恢复边界框的初始值, 同时预测类别. 整体来说, 需要学习一个网络 f θ f_\theta fθ, 从 z T z_T zT中恢复出 z 0 z_0 z0, 其中 z z z为边界框. 损失函数即为恢复的值与初始值的差的2-范数:
L t r a i n = 1 2 ∣ ∣ f θ ( z t , t ) − z 0 ∣ ∣ 2 \mathcal{L}_{train}=\frac{1}{2}||f_\theta (z_t,t)-z_0||^2 Ltrain=21∣∣fθ(zt,t)−z0∣∣2
如上图所示, 为了减少计算量, Diffusion Model从原始图片提取的高级特征中学习. Image Encoder就是提取图像特征的, 作者采用了ResNet和SwinTransformer.
而Detection Decoder接受加噪的bbox和特征图, 并返回恢复的bbox.
训练过程的每次迭代大致分为四步:
伪代码:
训练过程中有几个细节:
推理过程大致分为三步:
伪代码:
在推理过程中值得注意的是bbox更新机制. 由于输入的是固定数量的随机框, 在训练阶段我们也是加入了随机框来使数目一样, 因此输出的有些是对应于GT的bbox, 有些则是随机的. 如果把随机的再一起喂到下一步, 作者说这样就破坏了原本的分布, 因此对于每一步预测的框, 将置信度过低的舍弃, 并以新的随机框补充.
首先看一下./diffusiondet/detector.py
中的DiffusionDet
类, 其是该论文的核心代码. 其中的forward
函数:
def forward(self, batched_inputs, do_postprocess=True):
images, images_whwh = self.preprocess_image(batched_inputs) # 预处理 归一化&填充
if isinstance(images, (list, torch.Tensor)):
images = nested_tensor_from_tensor_list(images)
# Feature Extraction.
src = self.backbone(images.tensor) # Encoder 提取各级特征
features = list()
for f in self.in_features:
feature = src[f]
features.append(feature)
# Prepare Proposals.
if not self.training: # 如果是推理阶段
results = self.ddim_sample(batched_inputs, features, images_whwh, images) # 从T时刻至0时刻 逐步采样恢复
return results
if self.training: # 训练阶段
gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
targets, x_boxes, noises, t = self.prepare_targets(gt_instances) # prepare_targets: 对GT框逐步加噪
t = t.squeeze(-1)
x_boxes = x_boxes * images_whwh[:, None, :]
outputs_class, outputs_coord = self.head(features, x_boxes, t, None) # 经过RCNNhead 预测类别和bbox
output = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
if self.deep_supervision:
output['aux_outputs'] = [{'pred_logits': a, 'pred_boxes': b}
for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
loss_dict = self.criterion(output, targets) # 计算loss
weight_dict = self.criterion.weight_dict
for k in loss_dict.keys():
if k in weight_dict:
loss_dict[k] *= weight_dict[k]
return loss_dict
可以看到, 里面还有两个重点的self.prepare_targets
(训练过程中的加噪)和self.ddim_sample
(推理过程中的采样)
def prepare_targets(self, targets):
new_targets = []
diffused_boxes = []
noises = []
ts = []
for targets_per_image in targets:
target = {}
h, w = targets_per_image.image_size
image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device)
gt_classes = targets_per_image.gt_classes
gt_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy
gt_boxes = box_xyxy_to_cxcywh(gt_boxes) # 以上预处理真值框
d_boxes, d_noise, d_t = self.prepare_diffusion_concat(gt_boxes) # 核心部分 计算加噪后的框
diffused_boxes.append(d_boxes)
noises.append(d_noise)
ts.append(d_t)
target["labels"] = gt_classes.to(self.device)
target["boxes"] = gt_boxes.to(self.device)
target["boxes_xyxy"] = targets_per_image.gt_boxes.tensor.to(self.device)
target["image_size_xyxy"] = image_size_xyxy.to(self.device)
image_size_xyxy_tgt = image_size_xyxy.unsqueeze(0).repeat(len(gt_boxes), 1)
target["image_size_xyxy_tgt"] = image_size_xyxy_tgt.to(self.device)
target["area"] = targets_per_image.gt_boxes.area().to(self.device)
new_targets.append(target) # target为蕴含大小、类别等信息的真值
# 返回真值、加噪后的框、噪声和步长
return new_targets, torch.stack(diffused_boxes), torch.stack(noises), torch.stack(ts)
其中的加噪过程在self.prepare_diffusion_concat(gt_boxes)
, 我们可以看到:
def prepare_diffusion_concat(self, gt_boxes):
"""
:param gt_boxes: (cx, cy, w, h), normalized
:param num_proposals:
"""
t = torch.randint(0, self.num_timesteps, (1,), device=self.device).long() # 确定随机步长
noise = torch.randn(self.num_proposals, 4, device=self.device) # 产生标准正态分布
num_gt = gt_boxes.shape[0] # gt框数目
if not num_gt: # generate fake gt boxes if empty gt boxes
gt_boxes = torch.as_tensor([[0.5, 0.5, 1., 1.]], dtype=torch.float, device=self.device)
num_gt = 1
if num_gt < self.num_proposals: # 如果gt框比预设的固定数目小 则随机再填充一些框
box_placeholder = torch.randn(self.num_proposals - num_gt, 4,
device=self.device) / 6. + 0.5 # 3sigma = 1/2 --> sigma: 1/6
box_placeholder[:, 2:] = torch.clip(box_placeholder[:, 2:], min=1e-4)
x_start = torch.cat((gt_boxes, box_placeholder), dim=0)
elif num_gt > self.num_proposals: # 如果比预设数目多 就随机抹掉一些GT框
select_mask = [True] * self.num_proposals + [False] * (num_gt - self.num_proposals)
random.shuffle(select_mask)
x_start = gt_boxes[select_mask]
else:
x_start = gt_boxes
x_start = (x_start * 2. - 1.) * self.scale
# noise sample
x = self.q_sample(x_start=x_start, t=t, noise=noise) # 前向加噪过程
x = torch.clamp(x, min=-1 * self.scale, max=self.scale) # 限制范围
x = ((x / self.scale) + 1) / 2.
diff_boxes = box_cxcywh_to_xyxy(x)
return diff_boxes, noise, t
最后再来看看推理阶段的self.ddim_sample
函数:
@torch.no_grad()
def ddim_sample(self, batched_inputs, backbone_feats, images_whwh, images, clip_denoised=True, do_postprocess=True):
batch = images_whwh.shape[0]
shape = (batch, self.num_proposals, 4)
total_timesteps, sampling_timesteps, eta, objective = self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
# [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
times = list(reversed(times.int().tolist())) # 时间为倒序 从T到0
time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
img = torch.randn(shape, device=self.device) # 产生标准高斯分布bboxs
ensemble_score, ensemble_label, ensemble_coord = [], [], []
x_start = None
for time, time_next in time_pairs: # 相邻时间两步计算
time_cond = torch.full((batch,), time, device=self.device, dtype=torch.long)
self_cond = x_start if self.self_condition else None
# 预测的噪声、x_0和类别与坐标
preds, outputs_class, outputs_coord = self.model_predictions(backbone_feats, images_whwh, img, time_cond,
self_cond, clip_x_start=clip_denoised)
pred_noise, x_start = preds.pred_noise, preds.pred_x_start
if self.box_renewal: # filter Box reneral机制 将置信度低的边界框用随机框替换
score_per_image, box_per_image = outputs_class[-1][0], outputs_coord[-1][0]
threshold = 0.5
score_per_image = torch.sigmoid(score_per_image)
value, _ = torch.max(score_per_image, -1, keepdim=False)
keep_idx = value > threshold
num_remain = torch.sum(keep_idx)
pred_noise = pred_noise[:, keep_idx, :]
x_start = x_start[:, keep_idx, :]
img = img[:, keep_idx, :]
if time_next < 0:
img = x_start
continue
# 获取\alpha_i的连乘值
alpha = self.alphas_cumprod[time]
alpha_next = self.alphas_cumprod[time_next]
sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c = (1 - alpha_next - sigma ** 2).sqrt()
noise = torch.randn_like(img) # 标准高斯分布中采样
img = x_start * alpha_next.sqrt() + \
c * pred_noise + \
sigma * noise # 通过预测的噪声 计算恢复结果
if self.box_renewal: # filter
# replenish with randn boxes
img = torch.cat((img, torch.randn(1, self.num_proposals - num_remain, 4, device=img.device)), dim=1)
if self.use_ensemble and self.sampling_timesteps > 1:
box_pred_per_image, scores_per_image, labels_per_image = self.inference(outputs_class[-1],
outputs_coord[-1],
images.image_sizes)
ensemble_score.append(scores_per_image)
ensemble_label.append(labels_per_image)
ensemble_coord.append(box_pred_per_image)
if self.use_ensemble and self.sampling_timesteps > 1:
box_pred_per_image = torch.cat(ensemble_coord, dim=0)
scores_per_image = torch.cat(ensemble_score, dim=0)
labels_per_image = torch.cat(ensemble_label, dim=0)
if self.use_nms:
keep = batched_nms(box_pred_per_image, scores_per_image, labels_per_image, 0.5)
box_pred_per_image = box_pred_per_image[keep]
scores_per_image = scores_per_image[keep]
labels_per_image = labels_per_image[keep]
result = Instances(images.image_sizes[0])
result.pred_boxes = Boxes(box_pred_per_image)
result.scores = scores_per_image
result.pred_classes = labels_per_image
results = [result]
else:
output = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
box_cls = output["pred_logits"]
box_pred = output["pred_boxes"]
results = self.inference(box_cls, box_pred, images.image_sizes)
if do_postprocess: # 后处理
processed_results = []
for results_per_image, input_per_image, image_size in zip(results, batched_inputs, images.image_sizes):
height = input_per_image.get("height", image_size[0])
width = input_per_image.get("width", image_size[1])
r = detector_postprocess(results_per_image, height, width)
processed_results.append({"instances": r})
return processed_results