[cs231n] Batchnorm and Its Backward Pass

Table of Contents

    • BatchNormalization
        • Backward Pass
    • Other Normalization Methods
        • LayerNormalization
        • InstanceNormalization
        • GroupNormalization

A neural network stacks many layers, and the distribution of the data changes as it passes through each layer, which makes training the following layers harder; this phenomenon is called Internal Covariate Shift. Before batch normalization, it was usually handled by lowering the learning rate, careful weight initialization, and carefully tuned training strategies.

BatchNormalization

BatchNormalization pulls the data at each layer back to a normal distribution with mean 0 and variance 1. To preserve the features the layer has learned, the normalized data is then scaled and shifted. Assume the input $x$ has shape (N, D); the algorithm is:

  1. Compute the mean of each feature over the batch: $\mu_{\mathcal{B}} = \frac{1}{N} \sum_{i=1}^{N} x_i$
  2. Compute the variance of each feature over the batch: $\sigma_{\mathcal{B}}^2 = \frac{1}{N} \sum_{i=1}^{N} \left( x_i - \mu_{\mathcal{B}} \right)^2$
  3. Normalize the data: $\widehat{x}_i = \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}}$
  4. Scale and shift: $y_i = \gamma \widehat{x}_i + \beta$
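
A quick worked example of the four steps on a toy batch (two samples, one feature; the values of gamma and beta below are made up purely for illustration):

import numpy as np

x = np.array([[1.0], [3.0]])            # toy input, shape (N, D) = (2, 1)
gamma, beta, eps = 2.0, 0.5, 1e-5

mu = x.mean(axis=0)                     # step 1: [2.0]
var = x.var(axis=0)                     # step 2: [1.0]
x_hat = (x - mu) / np.sqrt(var + eps)   # step 3: approximately [[-1.0], [1.0]]
y = gamma * x_hat + beta                # step 4: approximately [[-1.5], [2.5]]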

During training, the forward pass of batchnorm is as shown above. At test time, however, the batch size may be 1 (N = 1), so the batch statistics are unreliable; instead, the mean and variance used at test time are the exponential moving (running) averages accumulated during training. The code is below.

def batchnorm_forward(x, gamma, beta, bn_param):
    """
    Input:
    - x: Data of shape (N, D)
    - gamma: Scale parameter of shape (D,)
    - beta: Shift parameter of shape (D,)
    - bn_param: Dictionary with the following keys:
      - mode: 'train' or 'test'; required
      - eps: Constant for numeric stability
      - momentum: Constant for running mean / variance.
      - running_mean: Array of shape (D,) giving running mean of features
      - running_var: Array of shape (D,) giving running variance of features
    Returns a tuple of:
    - out: of shape (N, D)
    - cache: A tuple of values needed in the backward pass
    """
    mode = bn_param['mode']
    eps = bn_param.get('eps', 1e-5)
    momentum = bn_param.get('momentum', 0.9)

    N, D = x.shape
    running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))
    running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))

    out, cache = None, None
    if mode == 'train':
        sample_mean = np.mean(x,axis=0)
        sample_var = np.var(x,axis=0)
        sample_sqrtvar = np.sqrt(sample_var+eps)
        x_norm = (x-sample_mean)/sample_sqrtvar
        out = x_norm*gamma+beta
        cache = (x,x_norm,gamma,beta,eps,sample_mean,sample_var,sample_sqrtvar)
        running_mean = momentum * running_mean + (1 - momentum) * sample_mean
        running_var = momentum * running_var + (1 - momentum) * sample_var
        
    elif mode == 'test':
        x_norm = (x-running_mean)/np.sqrt(running_var+eps)
        out = x_norm*gamma+beta

    # store the updated running mean and variance back into bn_param
    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var

    return out, cache
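
A minimal usage sketch (the array shapes and number of iterations here are arbitrary, chosen only for illustration): run a few training-mode passes so the running statistics accumulate, then switch to test mode, where even a single sample can be normalized.

np.random.seed(0)
bn_param = {'mode': 'train'}
gamma, beta = np.ones(4), np.zeros(4)

for _ in range(50):                       # accumulate running mean / variance
    x = np.random.randn(32, 4) * 3 + 5
    out, _ = batchnorm_forward(x, gamma, beta, bn_param)

bn_param['mode'] = 'test'                 # test mode uses the running statistics
x_test = np.random.randn(1, 4) * 3 + 5    # a batch of one still works
out_test, _ = batchnorm_forward(x_test, gamma, beta, bn_param)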

In a convolutional network the input X may have shape (N, W, H, C), in which case the mean and variance are computed as

sample_mean = np.mean(x,axis=(0,1,2), keepdims=True)
sample_var = np.var(x,axis=(0,1,2), keepdims=True)
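
Equivalently, the (N, D) implementation above can be reused for 4-D inputs by flattening the batch and spatial dimensions into a single axis, so that each channel is normalized over N, W and H. A sketch of that idea, assuming the channels-last (N, W, H, C) layout used in this post:

def spatial_batchnorm_forward(x, gamma, beta, bn_param):
    # x: (N, W, H, C); gamma, beta: (C,)
    N, W, H, C = x.shape
    x_flat = x.reshape(-1, C)              # (N*W*H, C)
    out_flat, cache = batchnorm_forward(x_flat, gamma, beta, bn_param)
    return out_flat.reshape(N, W, H, C), cache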

Backward Pass

A relatively tricky part of cs231n is the backward pass of batchnorm. The computation graph of batchnorm is shown below. We need to compute $\frac{\partial L}{\partial \gamma}$, $\frac{\partial L}{\partial \beta}$ and $\frac{\partial L}{\partial x}$.
(Figure 1: the computation graph of batch normalization)
First compute the easier gradients, $\frac{\partial L}{\partial \gamma}$ and $\frac{\partial L}{\partial \beta}$:

  • $\frac{\partial L}{\partial \gamma} = \sum_{i=1}^{N} \frac{\partial L}{\partial y_i} \, \widehat{x}_i$
  • $\frac{\partial L}{\partial \beta} = \sum_{i=1}^{N} \frac{\partial L}{\partial y_i}$

Then compute $\frac{\partial L}{\partial x}$ by the chain rule. In the computation graph, $(x-\mu)_1$ and $(x-\mu)_2$ denote the two branches of $(x-\mu)$ (the one feeding the numerator of $\widehat{x}$ directly and the one feeding the variance), and $x_1$, $x_2$ denote the two paths of $x$ (through $\mu$ and through $x-\mu$); $\mathbf{1}_{N \times D}$ is the $N \times D$ matrix of ones that broadcasts a $(1, D)$ row gradient to all $N$ rows.

  1. $\frac{\partial L}{\partial \widehat{x}} = \frac{\partial L}{\partial y} \, \gamma$
  2. $\frac{\partial L}{\partial (x-\mu)_1} = \frac{\partial L}{\partial \widehat{x}} \cdot \frac{1}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}} = \frac{\partial L}{\partial y} \cdot \frac{\gamma}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}}$, and $\frac{\partial L}{\partial \frac{1}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}}} = \sum_{i=1}^{N} \frac{\partial L}{\partial y_i} \, \gamma \, (x_i - \mu)$
  3. $\frac{\partial L}{\partial \sqrt{\sigma^2 + \epsilon}} = -\frac{\partial L}{\partial \frac{1}{\sqrt{\sigma^2 + \epsilon}}} \cdot \frac{1}{\sigma^2 + \epsilon}$
  4. $\frac{\partial L}{\partial \sigma^2} = \frac{1}{2} \, \frac{\partial L}{\partial \sqrt{\sigma^2 + \epsilon}} \, (\sigma^2 + \epsilon)^{-\frac{1}{2}} = -\frac{1}{2} \sum_{i=1}^{N} \frac{\partial L}{\partial y_i} \, \gamma \, (x_i - \mu) \, (\sigma^2 + \epsilon)^{-\frac{3}{2}}$
  5. $\frac{\partial L}{\partial (x-\mu)_2} = \frac{2}{N} \, \mathbf{1}_{N \times D} \, \frac{\partial L}{\partial \sigma^2} \, (x - \mu)$
  6. $\frac{\partial L}{\partial \mu} = -\frac{\gamma}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}} \sum_{i=1}^{N} \frac{\partial L}{\partial y_i} - \frac{2}{N} \, \frac{\partial L}{\partial \sigma^2} \sum_{i=1}^{N} (x_i - \mu)$
  7. $\frac{\partial L}{\partial x_1} = \frac{\partial \mu}{\partial x} \, \frac{\partial L}{\partial \mu} = \frac{1}{N} \, \mathbf{1}_{N \times D} \, \frac{\partial L}{\partial \mu}$
  8. $\frac{\partial L}{\partial x_2} = \frac{\partial L}{\partial (x-\mu)_1} + \frac{\partial L}{\partial (x-\mu)_2} = \frac{\partial L}{\partial y} \, \frac{\gamma}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}} + \frac{2}{N} \, \mathbf{1}_{N \times D} \, \frac{\partial L}{\partial \sigma^2} \, (x - \mu)$
  9. $\frac{\partial L}{\partial x} = \frac{\partial L}{\partial x_1} + \frac{\partial L}{\partial x_2}$

The code for the batchnorm backward pass:

def batchnorm_backward_alt(dout, cache):
    dx, dgamma, dbeta = None, None, None

    N,D = dout.shape
    x,x_norm,gamma,beta,eps,sample_mean,sample_var,sample_sqrtvar = cache
    dbeta = np.sum(dout, axis=0)                 # dL/dbeta
    dgamma = np.sum(dout * x_norm, axis=0)       # dL/dgamma
    dx_norm = dout * gamma                       # step 1: dL/dx_hat
    # steps 2-4: gradient with respect to the batch variance, shape (D,)
    dvar = np.sum(dx_norm * (x - sample_mean) * (-0.5) * np.power(sample_var + eps, -1.5), axis=0)
    # step 6: gradient with respect to the batch mean, shape (D,)
    dmean = np.sum(dx_norm * (-1) * np.power(sample_var + eps, -0.5), axis=0)
    dmean += dvar * np.sum(-2 * (x - sample_mean), axis=0) / N
    # steps 8-9: sum the contributions of the three paths to x
    dx = dx_norm * np.power(sample_var + eps, -0.5) + dvar * 2 * (x - sample_mean) / N + dmean / N

    return dx, dgamma, dbeta
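
To verify the derivation, the analytic gradient dx can be compared against a numerical gradient. A rough sketch using a small centered-difference helper written here from scratch (rather than the cs231n checking utilities); the shapes and seeds are arbitrary:

def numeric_grad(f, x, df, h=1e-5):
    # centered finite differences of sum(f(x) * df) with respect to x
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'])
    while not it.finished:
        i = it.multi_index
        old = x[i]
        x[i] = old + h; pos = f(x)
        x[i] = old - h; neg = f(x)
        x[i] = old
        grad[i] = np.sum((pos - neg) * df) / (2 * h)
        it.iternext()
    return grad

np.random.seed(1)
x = np.random.randn(5, 3)
gamma, beta = np.random.randn(3), np.random.randn(3)
dout = np.random.randn(5, 3)
bn_param = {'mode': 'train'}

_, cache = batchnorm_forward(x, gamma, beta, bn_param)
dx, dgamma, dbeta = batchnorm_backward_alt(dout, cache)
dx_num = numeric_grad(lambda x: batchnorm_forward(x, gamma, beta, bn_param)[0], x, dout)
print(np.max(np.abs(dx - dx_num)))   # should be tiny, roughly 1e-8 or smaller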

Other Normalization Methods

(Figure 2: comparison of the dimensions normalized by the different methods)

LayerNormalization

The main difference between LayerNormalization and BatchNormalization is the dimensions over which the statistics are computed. Suppose the input X has shape (N, W, H, C); then the difference between LN and BN is as shown in the figure above. BN computes one mean per channel (a vector of size C, averaged over the remaining dimensions), whereas LN computes one mean per sample (a vector of size N). As a result, unlike BN, LN still works when the batch size is 1.
The main differences between LN and BN are:

  • In BN, each neuron input (i.e., each channel) has its own mean, while different samples within a batch share the same mean.
  • In LN, all neuron inputs within a sample share the same mean, while the means differ across samples.

A minimal layer-norm forward pass for (N, D) inputs is below; a numerical check of the two bullet points follows the code.

def layernorm_forward(x, gamma, beta, ln_param):
    """
    Input:
    - x: Data of shape (N, D)
    - gamma: Scale parameter of shape (D,)
    - beta: Shift parameter of shape (D,)
    - ln_param: Dictionary with the following keys:
        - eps: Constant for numeric stability
    Returns a tuple of:
    - out: of shape (N, D)
    - cache: A tuple of values needed in the backward pass
    """
    out, cache = None, None
    eps = ln_param.get('eps', 1e-5)
    
    sample_mean = np.mean(x,axis=1,keepdims=True)
    sample_var = np.var(x,axis=1,keepdims=True)
    sample_sqrtvar = np.sqrt(sample_var+eps)
    x_norm = (x-sample_mean)/sample_sqrtvar
    out = x_norm*gamma+beta
    cache = (x,x_norm,gamma,beta,eps,sample_mean,sample_var,sample_sqrtvar)
    
    return out, cache
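
The two bullet points above can be checked numerically: batchnorm makes every feature (column) of the output zero-mean across the batch, while layernorm makes every sample (row) zero-mean across its features. A small check (shapes and values are illustrative):

np.random.seed(2)
x = np.random.randn(4, 6) * 2 + 3
gamma, beta = np.ones(6), np.zeros(6)

bn_out, _ = batchnorm_forward(x, gamma, beta, {'mode': 'train'})
ln_out, _ = layernorm_forward(x, gamma, beta, {})

print(bn_out.mean(axis=0))   # ~0 for each of the D features
print(ln_out.mean(axis=1))   # ~0 for each of the N samples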

In a convolutional network the input X may have shape (N, W, H, C), in which case the mean and variance are computed as

sample_mean = np.mean(x,axis=(1,2,3), keepdims=True)
sample_var = np.var(x,axis=(1,2,3), keepdims=True)

InstanceNormalization

InstanceNormalization computes the mean over the H and W dimensions only, so each image instance is normalized on its own; this is very useful in image style transfer. The code is as follows:

def instance_forward(x, gamma, beta, ln_param):
    out, cache = None, None
    eps = ln_param.get('eps', 1e-5)
    
    sample_mean = np.mean(x,axis=(1,2), keepdims=True)
    sample_var = np.var(x,axis=(1,2), keepdims=True)
    sample_sqrtvar = np.sqrt(sample_var+eps)
    x_norm = (x-sample_mean)/sample_sqrtvar
    out = x_norm*gamma+beta
    cache = (x,x_norm,gamma,beta,eps,sample_mean,sample_var,sample_sqrtvar)
    
    return out, cache
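
A quick check of the instance-norm statistics (channels-last (N, W, H, C) input as above; gamma and beta broadcast over the channel dimension; the numbers are illustrative):

np.random.seed(3)
x = np.random.randn(2, 8, 8, 3)
gamma, beta = np.ones(3), np.zeros(3)

out, _ = instance_forward(x, gamma, beta, {})
print(out.mean(axis=(1, 2)))   # ~0 for every (sample, channel) pair, shape (2, 3)
print(out.var(axis=(1, 2)))    # ~1 for every (sample, channel) pair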

GroupNormalization

As mentioned above, BN performs poorly with small batch sizes, so GN splits the channel dimension C into several groups and normalizes within each group. The code is as follows:

def group_forward(x, gamma, beta, G, ln_param):
    out, cache = None, None
    eps = ln_param.get('eps', 1e-5)
    N, W, H, C = x.shape
    # split the C channels into G groups of size C // G
    x_group = np.reshape(x, (N, W, H, G, C // G))
    sample_mean = np.mean(x_group, axis=(1, 2, 4), keepdims=True)
    sample_var = np.var(x_group, axis=(1, 2, 4), keepdims=True)
    sample_sqrtvar = np.sqrt(sample_var + eps)
    x_norm = (x_group - sample_mean) / sample_sqrtvar
    # merge the groups back before the per-channel scale and shift
    x_norm = np.reshape(x_norm, (N, W, H, C))
    out = x_norm * gamma + beta
    cache = (x, x_norm, gamma, beta, eps, sample_mean, sample_var, sample_sqrtvar, G)
    
    return out, cache
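
Usage sketch (G must divide C; here C = 6 channels are split into G = 3 groups of 2; with gamma = 1 and beta = 0 each (sample, group) slice of the output is zero-mean):

np.random.seed(4)
x = np.random.randn(2, 4, 4, 6)
gamma, beta = np.ones(6), np.zeros(6)

out, _ = group_forward(x, gamma, beta, 3, {})
print(out.shape)                             # (2, 4, 4, 6), same as the input
out_groups = out.reshape(2, 4, 4, 3, 2)
print(out_groups.mean(axis=(1, 2, 4)))       # ~0 per (sample, group), shape (2, 3)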

References
https://arxiv.org/pdf/1502.03167.pdf
https://blog.csdn.net/qq_25737169/article/details/79048516
https://blog.csdn.net/liuxiao214/article/details/81037416
https://zhuanlan.zhihu.com/p/33173246
https://blog.csdn.net/xiaojiajia007/article/details/54924959 (a translated article; the original could not be found)
