[FRN] Filter Response Normalization

背景

BN依赖于Batch做归一化,在小批量上会出现性能退化;GN虽然通过将特征在Channel上分组来摆脱Batch的 依赖,但是在大批量上性能不如BN。

BN 到GN

机器学习最重要的任务

根据一些以观察到的证据来对感兴趣的位置变量进行估计和推测。

概率模型提高了一种描述框架,将学习任务归结于计算变量。

归一化处理有助于模型的优化。

BN

BN 通过计算batch中所有样本的每个channel上的均值和方差来进行归一化。

计算方式伪代码:

FRN计算步骤伪代码
  1. 计算在(B, H, W)维度上的均值和方差
  2. 在各个通道上进行标准归一化
  3. 对归一化的特征进行放缩 和平移 ,其中两个参数是可学习的

问题

  1. 训练时batch较大,预测时batch通常为1,造成训练和预测时均值 和方差 的计算分布不一致。

    BN的解决方案是 在训练时估计一个均值和方差量来作为测试时的归一化参数,一般对每次mini-batch的均值和方差进行指数加权平均来得到

  2. BN对batch的大小敏感,如果batch太小,模型性能会明显恶化,且受限于显存大小,当前很多模型的batch难以很大。

解决BN问题

1. 避免在batch维度归一化

由上述,我们知道如果避免在batch维度上进行归一化可以避免batch带来的问题。BN的两个主要问题 训练和与测试均值和方差计算分布不一致batch太小模型性能恶化 都是batch维度带来的,显然不在batch上进行归一化,上述问题就迎刃而解了。

基于这一观点,衍生出一系列方法:

Layer NormalizationInstance NormalizationGroup Normalization

LN,IN,GN,BN的区别
BN LN IN GN
处理维度 (B, H, W) (H, W, C) (H, W) (H, W, G)

GN在归一化时需要对C分组,即特征从 [B, H, W, C] 转换成 [B, H, W, G, C/G]

LN,IN,GN都没有在batch维度上进行归一化,所以不会有BN的问题。相比之下,GN更为常用。

GN 和 BN 性能对比
2. 降低训练和测试之间的不一致性

Batch Renomalization

限制训练过程中batch统计量的值范围

3. 多卡BN方法训练

相当于增大batch size。

FRN

FRN层包括 FRN归一化层FRN (Filter Response Normalization)激活层TLU (Threshold Linear Unit)

FRN不仅消除了训练时对batch的依赖,而且当batch size较大时性能由于BN。

FRN结构示意图

原理  FRN的操作是在 (H, W) 维度上的,即对每个样本的每个channel单独进行归一化,这里就是一个N维度的向量,所以没有对batch依赖的问题。FRN没有采样高斯分布进行归一化,而是除以的二次范数的平均值。这种归一化方式类似BN可以消除中间操作(卷积和非线性激活)带来的尺度问题,有助于模型训练。

  防止除0的小正常量。FRN 是在 H,W 两个维度归一化,一般情况下网络的特征图大小 较大,但有时候会出现 的情况。

对于特征图为 的情况, 就比较关键,不同的 正则化效果区别很大。当 值较小时,归一化相当于符号函数,这时候梯度几乎为0,严重影响模型训练;当 值较大时,曲线变得更圆滑,此时梯度有助于学习。对于这种情况,论文建议采用一个可学习的 。

不同eps的梯度对比

IN 也是在 H, W维度上进行归一化,但是会减去均值,对于 的情况归一化结果是 0,但FRN可以避免这个问题。

归一化之后同样需要进行缩放和平移变换,这里 和 也是可学习的参数:

FRN缺少减均值的操作,可能使得归一化的结果任意地偏移0,如果FRN之后是ReLU激活层,可能产生很多0值,这对于模型训练和性能是不利地。

为了解决这个问题,FRN之后采用阈值化的ReLU,即TLU:

其中 是可学习参数。

实验结果

实验结果

代码实现

class FilterResponseNormNd(nn.Module):
    def __init__(self, ndim, num_features, eps=1e-6, learnable_eps=False):
        assert ndim in [3,4,5], \
            'FilterResponseNorm only support 3d, 4d or 5d inputs'
        super(FilterResponseNormNd, self).__init__()
        shape = (1, num_features) + (1, ) * (ndim - 2)
        self.eps = nn.Parameter(torch.ones(*shape) * eps)
        if not learnable_eps:
            self.eps.required_grad_(False)
        self.gamma = nn.Parameter(torch.Tensor(*shape))
        self.beta = nn.Parameter(torch.Tensor(*shape))
        self.tau = nn.Parameter(torch.Tensor(*shape))
        self.reset_parameters()
    def forward(self, x):
        avg_dims = tuple(range(2, x.dim()))
        nu2 = torch.pow(x, 2).mean(dim=avg_dims, keepdim=True)
        x = x * torch.rsqrt(nu2 + torch.abs(self.eps))
        return torch.max(self.gamma * x + self.bata, self.tau)
    def reset_parameters(self):
        nn.init.ones_(self.gamma)
        nn.init.zeros_(self.beta)
        nn.init.zeros_(self.tau)
    

你可能感兴趣的:([FRN] Filter Response Normalization)