BN、CBN、CmBN 的对比与总结

BN、CBN、CmBN 的对比与总结

最近看到了关于 Yolo 系列 trick 的总结文章 【Make YOLO Great Again】YOLOv1-v7全系列大解析(Tricks篇),其中提到了 YoloV4 中使用了 CmBN,这是对 CBN 的改进,可以较好的适应小 batch 的情形。论文中给出了一个简要的对比图:

BN、CBN、CmBN 的对比与总结_第1张图片

这里结合此图对 BN 和其两种改进策略进行说明。所以需要注意的是,这里存在两个 batch 相关的概念:

  • batch:指代与 BN 层的统计量 实际想要相对应的数据池,也就是图片样本数。
  • mini-batch:由于整个 batch 独立计算时,受到资源限制可能不现实,于是需要将 batch 拆分成数个 mini-batch,每个 mini-batch 单独计算后汇总得到整个 batch 的统计量。从而归一化特征。

我们日常在分割或者检测中使用 BN 时,此时如果不使用特殊的设定,那么 batch 与 mini-batch 是一样的。CBN 和 CmBN 所做的就是如何使用多个独立的 mini-batch 的数据获得一个近似于更大 batch 的统计量以提升学习效果。

CBN 与 CmBN

CmBN(Cross mini-Batch Normalization)是 CBN(Cross-Iteration Batch Normalization)的修改版。

CBN 主要用来解决在 Batch-Size 较小时,BN 的效果不佳问题。CBN 连续利用多个迭代的数据来变相扩大 batch size 从而改进模型的效果。这种用前几个 iteration 计算好的统计量来计算当前迭代的 BN 统计量的方法会有一个问题:过去的 BN 参数是由过去的网络参数计算出来的特征而得到的,而本轮迭代中计算 BN 时,它们的模型参数其实已经过时了

假定 batch=4*mini batch,CBN 在 t t t 次迭代:

  • 模型基于之前的梯度被更新。此时的 BN 的仿射参数也是最新的。
  • 除了本次迭代的统计量,也会使用通过补偿后的前 3 次迭代得到的统计量。这 4 次的统计量会被一起用来得到近似于整个窗口的近似 batch 的 BN 的统计量。
  • 使用得到的近似统计量归一化特征。
  • 使用当前版本的仿射参数放缩和偏移。

CmBN 是基于 CBN 改进的,按照论文的图示的意思,主要的差异在于从滑动窗口变为固定窗口。每个 batch 中的统计不会使用 batch 之前的迭代的信息,仅会累积该窗口内的 4 次迭代以用于最后一次迭代的更新。这一策略基本与梯度累积策略仍有不同,梯度累加仅仅累加了梯度,但是前面的图中明显可以看到 BN 的统计量实际上也累积了起来,而图 4 中的展现的 BN 似乎更像是梯度累积。

CBN 的实现

# https://github.com/Howal/Cross-iterationBatchNorm/blob/f6d35301789c96e52699a9cbc8d2de8681547770/mmdet/models/utils/CBN.py#L74
def forward(self, input, weight):
    # deal with wight and grad of self.pre_dxdw!
    self._check_input_dim(input)
    y = input.transpose(0, 1)
    return_shape = y.shape
    y = y.contiguous().view(input.size(1), -1)

    # burnin
    if self.training and self.burnin > 0:
        self.iter_count += 1
        self._update_buffer_num()

    if self.buffer_num > 0 and self.training and input.requires_grad:  # some layers are frozen!
        # cal current batch mu and sigma
        cur_mu = y.mean(dim=1)
        cur_meanx2 = torch.pow(y, 2).mean(dim=1)
        cur_sigma2 = y.var(dim=1)
        # cal dmu/dw dsigma2/dw
        dmudw = torch.autograd.grad(cur_mu, weight, self.ones, retain_graph=True)[0]
        dmeanx2dw = torch.autograd.grad(cur_meanx2, weight, self.ones, retain_graph=True)[0]
        # update cur_mu and cur_sigma2 with pres
        mu_all = torch.stack([cur_mu, ] + [tmp_mu + (self.rho * tmp_d * (weight.data - tmp_w)).sum(1).sum(1).sum(1) for tmp_mu, tmp_d, tmp_w in zip(self.pre_mu, self.pre_dmudw, self.pre_weight)])
        meanx2_all = torch.stack([cur_meanx2, ] + [tmp_meanx2 + (self.rho * tmp_d * (weight.data - tmp_w)).sum(1).sum(1).sum(1) for tmp_meanx2, tmp_d, tmp_w in zip(self.pre_meanx2, self.pre_dmeanx2dw, self.pre_weight)])
        sigma2_all = meanx2_all - torch.pow(mu_all, 2)

        # with considering count
        re_mu_all = mu_all.clone()
        re_meanx2_all = meanx2_all.clone()
        re_mu_all[sigma2_all < 0] = 0
        re_meanx2_all[sigma2_all < 0] = 0
        count = (sigma2_all >= 0).sum(dim=0).float()
        mu = re_mu_all.sum(dim=0) / count
        sigma2 = re_meanx2_all.sum(dim=0) / count - torch.pow(mu, 2)

        self.pre_mu = [cur_mu.detach(), ] + self.pre_mu[:(self.buffer_num - 1)]
        self.pre_meanx2 = [cur_meanx2.detach(), ] + self.pre_meanx2[:(self.buffer_num - 1)]
        self.pre_dmudw = [dmudw.detach(), ] + self.pre_dmudw[:(self.buffer_num - 1)]
        self.pre_dmeanx2dw = [dmeanx2dw.detach(), ] + self.pre_dmeanx2dw[:(self.buffer_num - 1)]

        tmp_weight = torch.zeros_like(weight.data)
        tmp_weight.copy_(weight.data)
        self.pre_weight = [tmp_weight.detach(), ] + self.pre_weight[:(self.buffer_num - 1)]

    else:
        x = y
        mu = x.mean(dim=1)
        cur_mu = mu
        sigma2 = x.var(dim=1)
        cur_sigma2 = sigma2

    if not self.training or self.FROZEN:
        y = y - self.running_mean.view(-1, 1)
        # TODO: outside **0.5?
        if self.out_p:
            y = y / (self.running_var.view(-1, 1) + self.eps)**.5
        else:
            y = y / (self.running_var.view(-1, 1)**.5 + self.eps)
        
    else:
        if self.track_running_stats is True:
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * cur_mu
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * cur_sigma2
        y = y - mu.view(-1, 1)
        # TODO: outside **0.5?
        if self.out_p:
            y = y / (sigma2.view(-1, 1) + self.eps)**.5
        else:
            y = y / (sigma2.view(-1, 1)**.5 + self.eps)

    y = self.weight.view(-1, 1) * y + self.bias.view(-1, 1)
    return y.view(return_shape).transpose(0, 1)

你可能感兴趣的:(深度学习,深度学习,目标检测)