BatchNormalization

目录

Covariate Shift

Internal Covariate Shift

BatchNormalization

Q1:BN的原理

Q2:BN的作用

Q3:BN的缺陷

Q4:BN的均值、方差的计算维度

Q5:BN在训练和测试时有什么区别

Q6:BN的代码实现


Covariate Shift

机器学习中,一般会假设模型的输入数据分布时稳定的。如果这个假设不成立,即模型的输入数据的分布发生变化,则称为协变量偏移(Covariate Shift).例如:模型的训练数据和测试数据分布不一致;模型在训练过程中输入数据发生变化。

Internal Covariate Shift

在深层网络训练的过程中,由于网络中参数变化而引起内部节点数据分布发生变化这一过程被称作Internal Covariate Shift.深度神经网络涉及到很多层的叠加,而每一层的参数更新会导致高层的输入数据分布发生变化,通过层层叠加,高层的输入分布变化会非常剧烈,这就使得高层需要不断地去重新适应底层的参数更新。

Covariate Shift时模型的输入数据的分布发生变化,Internal Covariate Shift时网络内部的节点的输入数据分布发生变化。

ICS带来的问题

  • 高层网络需要不断调整来适应输入数据分布的变化,导致网络学习速度的降低,使得学习的过程变得很不稳定
  • 网络前几层参数的更新,很可能使得后面层输入数据变得过大或者过小,从而陷入梯度饱和区域(比如sigmoid的饱和区),减缓网络收敛速度

在各种Normalization提出之前,解决上述问题的方法是使用较小的学习率(避免参数更新太快),精细的参数初始化,训练时间会很长。

BatchNormalization

BN的提出就是为了解决Interval Covariate Shift问题。BN的作用是确保网络的各层,即使参数发生了变化,其输入数据的分布也不会发生太大的变化,将其拉回到均值0,方差1的正态分布,从而避免ICS。

Q1:BN的原理

BN可以看作带参数的标准化,它有两个需要学习的参数γ和β,称为偏移因子和缩放因子。BN包括两个步骤:第一步相当于标准化,对每层的输入统计均值和方差,然后进行去均值方差标准化处理,使得每层的输入都是均值为0,方差为1.第二步是对规范化的数据进行线性变化,使用两个可以学习的参数γ和β。第一步的标准化虽然缓解了ICS问题,使得每层网络输入都变得稳定,但却导致数据的表达能力减弱,使得底层网络学习的信息丢失(如果激活函数使用的是sigmoid或者tanh,0均值的数据大部分落在激活函数的近似线性区域,没有利用上非线性区域,极大削弱了非线性表达能力);因此加入了两个可学习的参数的线性变换来恢复数据的表达能力。

具体的计算公式如下:

BatchNormalization_第1张图片

Q2:BN的作用

  • 加速网路训练。ICS问题会导致深层网络需要不断去适应底层网络参数的变化,因此训练速度会很慢,BN解决了ICS问题,因此可以加速网络训练。

  • 防止过拟合。BN由于每次在计算均值方差时是依靠一个batch来计算的,引入了随机性,可以缓解过拟合问题,可以用来代替dropout以及降低L2正则的权重系数。
  • 缓解梯度消失和梯度爆炸。BN使得每层的输入不会太大,因此不会梯度爆炸;每层输入绝对值不会太大,就不会落入sigmoid激活函数的饱和区域,从而缓解梯度消失。
  • 调参更容易。之前由于ICS问题的存在,一般会采用更小的学习率,为了防止过拟合也会尝试dropout以及L2正则,并且要求很精细的网络初始化。有了BN后,缓解了ICS问题,就可以使用较大的学习率,对初始化的要求也没有那么高了。

Q3:BN的缺陷

  • 受到Batch Size影响很大,如果batch size较小,每次训练计算的均值方差不具有代表性且不稳定,会降低模型效果。
  • BN比较难用到RNN这种序列模型中。因为BN是batch内计算均值,而句子之间没有很强的语义关系,句子内部有比较强的语义关系,所以句子之间算均值对其再标准化,这种方式效果会不好。

Q4:BN的均值、方差的计算维度

对于全连接层:输入维度是[N,C],在N上计算平均,γ和β的维度是C

对于卷积层:输入维度是[N,C,H,W],在N,H,W上计算平均,γ和β的维度是C。

Q5:BN在训练和测试时有什么区别

训练时,均值、方差分别是该批次内数据相应维度的均值和方差;测试时,均值、方差是基于训练时批次数据均值方差的无偏估计,公司如下:(即在训练时保存所有批次的均值方差,然后计算无偏估计)

BatchNormalization_第2张图片

在推荐过程中BN采用如下公式:

 

这里的E[x]就是我上面那个式子里面的μtest,Var[x]就是我上面式子里面的σ^2test。

这个式子和训练时:

是等价的,不过是做了一些变换。在实际运行的时候,按照这种变体可以减少计算量,为啥呢?因为对于隐藏节点来说:

 

 都是固定值,这样这两个值可以事先存起来,在推荐的时候直接用就行了,这样比原始的公式每一步都现成算少了除法的运算过程,如果隐藏节点个数多的话就会节省很多的计算量。

Q6:BN的代码实现

Batch Normalization里面有一个momentum参数,该参数作用于mean和variance的计算上,这里保留了历史batch里面的mean和variance值,即moving_mean和moving_variance,计算的是移动平均,将历史batch里的mean和variance的作用延续到当前batch。一般momentum的值为0.9,0.99等。多个batch后,即多个0.9连乘后,最早的batch的影响后变弱。

指数移动平均:指数移动平均是以指数式递减加权的移动平均。各数值的加权影响力随时间而指数式递减,越近期的数据加权影响力越重,但较旧的数据也给予一定的加权值。计算公式为:u_t=\beta u_{t-1}+(1-\beta)\theta_t.优点:当想要计算均值的时候,不用保留所有时刻的值。随着时间推移,遥远过去的历史的影响会越来越小。

所以这里要注意实际实现的时候,测试阶段的均值和方差是通过在训练阶段指数加权移动平均来统计均值和方差得到的

代码实现如下:

import torch
from torch import nn 
 
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # 通过is_grad_enabled来判断当前模式是训练模式还是预测模式
    if not torch.is_grad_enabled():
        # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        #判断是全连接层还是卷积层,2代表全连接层,样本数和特征数;4代表卷积层,批量数,通道数,高宽
        if len(X.shape) == 2:
            # 使用全连接层的情况,计算特征维上的均值和方差
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
            # 这里我们需要保持X的形状以便后面可以做广播运算
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            #1*n*高*宽
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # 训练模式下,用当前的均值和方差做标准化
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # 更新移动平均的均值和方差
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # 缩放和移位
    return Y, moving_mean.data, moving_var.data
class BatchNorm(nn.Module):
    # num_features:完全连接层的输出数量或卷积层的输出通道数。
    # num_dims:2表示完全连接层,4表示卷积层
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # 非模型参数的变量初始化为0和1
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)
 
    def forward(self, X):
        # 如果X不在内存上,将moving_mean和moving_var
        # 复制到X所在显存上
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # 保存更新过的moving_mean和moving_var
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.9)

关于BN常见的面试题目整理:

1.BN为什么效果?

讲BN的归一化带来的好处以及mini-batch mean 和mini-batch variance引入正则作用这两个方面

2.为什么BN归一化后还要有scale-shift操作?

这个在文中提到了

3.BN改变了数据分布,为什么效果反而会更好?

虽然会改变数据分布,但是数据之间的关联性是不会变的;由于有目标函数在,所以神经网络自己会朝着分布最优的方向去学习

4.BN用在什么地方

一般用在全连接层+BN+激活函数

5.对于什么激活函数,BN效果会明显?

对于sigmoid或者tanh激活函数,BN效果会好一些。

6.BN中在训练和测试时怎么用?

文中讲到了。

7.BN缺点

小样本时,效果不好,均值和方差是有偏的;在RNN中效果通常不好。其实文中也讲到了

[深度学习基础][面经]Batch Normalization - 知乎

Batch Normalization导读 - 知乎

整理学习之Batch Normalization(批标准化)_笨笨犬牙的博客-CSDN博客

你可能感兴趣的:(面试,机器学习,深度学习,人工智能)