cs231n'18: Assignment 2 | Batch Normalization

Assignment 2 | Batch Normalization

上文吐槽BN部分讲的太烂,2018年果然更新了这一部分,slides里加了好多内容,详见Lecture 6的slides第54到61页,以及Lecture 7的slides第11到18页,这里结合着原始论文和作业,把BN及其几个变种好好总结一下。

Batch Normalization

Train

前面的作业中已经见识到了,weight初始化时方差的调校真的是很麻烦,小了梯度消失不学习,大了梯度爆炸没法学习。
即使开始初始化的很好,随着学习的深入,网络的加深,每一层的方差已经不再受控;另外,特别是对于刚开始的几层,方差上稍微的变化,都会在forward prop时逐级放大的传递下去。
作业中只是三五层的小网络,要是几十上百层的网络,可以想象学习几乎是不可能的。

既然每一层输入的方差会产生如此多的问题,这就产生了第一个想法,何不将每一层的输入直接标准化为0均值单位方差。由于NN的train多是基于mini-batch的,所以这里标准化也是基于mini-batch。

输入x是包含N个sample的mini-batch,每个sample有D个feature。对每个feature进行标准化,即:

μjσ2j=1Ni=1Nxi,j=1Ni=1N(xi,jμj)2 μ j = 1 N ∑ i = 1 N x i , j σ j 2 = 1 N ∑ i = 1 N ( x i , j − μ j ) 2

标准化后的输出为:
x^=xμjσ2j+ϵ x ^ = x − μ j σ j 2 + ϵ

但是但是但是,这里武断的使输入均值为0,方差为1真的是最好的选择么?不一定。如果不是最好的选择,
设为多少是最好的选择呢?不知道。不知道的话怎么办呢?
那就让NN自己去学习一个最好的去呗。所以才有了下一步:

y=γx^+β y = γ ⋅ x ^ + β

其中, γ γ β β 是要学习的参数,将输入的均值和方差从(0,1)又拉到了 (γ,β) ( γ , β )

所以,通常说起来BN是一层,但是我认为,BN是两层:Normalization Layer和Shift Layer,这两层是紧密相连,不可分割的。其中,Normalization Layer将输入的均值和方差标准化为(0,1),Shift Layer又将其拉到 (γ,β) ( γ , β ) 。这里, (γ,β) ( γ , β ) 和其他的weight、bias一样,都是通过backprop算梯度,然后再用SGD等方法更新学习得到。

好,这里强调两个问题,也是我第一遍看paper时的疑惑,也是2017年视频中那位小姑娘讲课时犯的错误:

  1. 一提到BN层的作用,马上想到的是:将输入映射为0均值单位方差的高斯分布。错!首先它不一定是高斯分布,可以是任意的分布,BN仅仅改变均值方差,不改变分布。其次,均值方差不是(0,1),而是 (γ,β) ( γ , β ) 。说(0, 1)的是忘记了shift这一层。
  2. 原文中有一句,还打了斜体:

    To address this, we make sure that the transformation inserted in the network can represent the identity transform.


当时看的时候就不明白,既然费半天劲减均值除方差,怎么这里又要 “represent the identity transform”? 而且加上后边的 (γ,β) ( γ , β ) 操作,就更看不懂了。其实这里漏看了一个 “can” 。既然 (γ,β) ( γ , β ) 是学习来的,它们当然可以是原始输入的均值方差了,所以BN有表达一个identity transform的能力,而不是必须要表达一个identity transform。 总结一下:
input:
      x: (N, D)
intermediates:
      mean: (1, D)  
          mean = np.mean(x, axis=0)
      var: (1, D)
          var = np.var(x, axis=0)
      xhat: (N, D)
          xhat = (x - mean) / (np.sqrt(var + eps))
learnable params:
      gamma: (1, D)
      beta: (1, D)
输出:
      y = gamma * xhat + beta

Test

在test时,就没有mini-batch可用来算 μ μ σ2 σ 2 了,此时常用的方法是在train的过程中记录一个 μ μ σ2 σ 2 的滑动均值在test的时候使用。 BN通常放在FC/Conv之后,ReLU之前。

Backprop

BN的backprop是这次作业的难点,还要用两种方法做,这里一步一步尽量详细地把推导过程写出来。

dβ d β

dβ d β 用维度分析法:
y=γx^+β y = γ ⋅ x ^ + β
其中 y y 形如(N, D), γ γ β β 形如(D,), x^ x ^ 形如(N, D),所以 dβ d β 必然为:
dbeta = np.sum(dout, axis=0)
这里就不赘述了。

dγ d γ

其实 dγ d γ 也可以用维度分析法得到, dy d y dx^ d x ^ 都形如(N, D),而 dγ d γ 形如(D,),显然 dγ d γ 应为:
dgamma = np.sum(xhat * dout, axis=0)
这里还是把过程写一下吧
y11y21yN1y12y22...yN2............y1Dy2DyND=[γ1γ2...γD]x11x21xN1x12x22...xN2............x1Dx2DxND [ y 11 y 12 . . . y 1 D y 21 y 22 . . . y 2 D . . . . . . y N 1 y N 2 . . . y N D ] = [ γ 1 γ 2 . . . γ D ] ⋅ [ x 11 x 12 . . . x 1 D x 21 x 22 . . . x 2 D . . . . . . x N 1 x N 2 . . . x N D ]
展开可得:
y11=γ1x11,y21=γ1x21,y12=γ2x12,y22=γ1x22,...... y 11 = γ 1 ⋅ x 11 , y 12 = γ 2 ⋅ x 12 , . . . y 21 = γ 1 ⋅ x 21 , y 22 = γ 1 ⋅ x 22 , . . .
由此可得:
Lγq=Lyyγq=i,jLyijyijγq ∂ L ∂ γ q = ∂ L ∂ y ⋅ ∂ y ∂ γ q = ∑ i , j ∂ L ∂ y i j ⋅ ∂ y i j ∂ γ q
而仅当 j=q j = q 时有
yijγq=xiq ∂ y i j ∂ γ q = x i q
其余均为0,故:
Lγq=i=1NLyiqyiqγq=i=1Nxiqdyiq ∂ L ∂ γ q = ∑ i = 1 N ∂ L ∂ y i q ⋅ ∂ y i q ∂ γ q = ∑ i = 1 N x i q ⋅ d y i q

dx d x :第一种方法


先画出forward和backward的计算图,如图所示。forward的代码如下:

x_mean = 1 / N * np.sum(x, axis=0)
x_mean_0 = x - x_mean
x_mean_0_sqr = x_mean_0 ** 2
x_var = 1 / N * np.sum(x_mean_0_sqr, axis=0)
x_std = np.sqrt(x_var + eps)
inv_x_std = 1 / x_std
x_hat = x_mean_0 * inv_x_std

out = gamma * x_hat + beta
cache = (x_mean, x_mean_0, x_mean_0_sqr, x_var, x_std, inv_x_std, x_hat, gamma, eps)
这里需要注意的是 1. 尽量将每一步化成最简单的加、乘操作,并且将每一步等号左边的项全部cache起来。这样做的目的是减少backprop时的计算量,但是相应的存贮量就会增加。所以说NN的内存需求要远远大于weights和bias的数目。 2. 计算mean是,用 1/N * np.sum(),不要用np.mean(),否则在backprop的时候会把 1/N 漏掉。 如果forward的每一步计算分解的足够细的话,backprop可以很清楚:
# out = gamma * x_hat + beta
# (N,D) (D,)    (N,D)   (D,)
Dx_hat = dout * gamma

# x_hat = x_mean_0 * inv_x_std
# (N,D)   (N,D)      (D,)
Dx_mean_0 = Dx_hat * (inv_x_std)
Dinv_x_std = np.sum(Dx_hat * (x_mean_0), axis=0)

# inv_x_std = 1 / x_std
# (D,)            (D,)
Dx_std = Dinv_x_std * (- x_std ** (-2))

# x_std = np.sqrt(x_var + eps)
# (D,)           (D,)
Dx_var = Dx_std * (0.5 * (x_var + eps) ** (-0.5))

# x_var = 1 / N * np.sum(x_mean_0_sqr, axis=0)
# (D,)                   (N,D)
Dx_mean_0_sqr = Dx_var * (1 / N * np.ones_like(x_mean_0_sqr))

# x_mean_0_sqr = x_mean_0 ** 2
# (N,D)          (N,D)
Dx_mean_0 += Dx_mean_0_sqr * (2 * x_mean_0)

# x_mean_0 = x - x_mean
# (N,D)     (N,D) (D,)
Dx = Dx_mean_0 * (1)
Dx_mean = - np.sum(Dx_mean_0, axis=0)

# x_mean = 1 / N * np.sum(x, axis=0)
# (D,)                   (N,D)
Dx += Dx_mean * (1 / N * np.ones_like(x_hat))

dx = Dx
这里要注意的是: 1. 一定要把每一步计算中每一项的维度搞清楚写下来。注意这一步:
# x_hat = x_mean_0 * inv_x_std
# (N,D)   (N,D)      (D,)
Dx_mean_0 = Dx_hat * (inv_x_std)
Dinv_x_std = np.sum(Dx_hat * (x_mean_0), axis=0)
因为numpy在进行矩阵运算的时候会进行自动的broadcast,所以这里 inv_x_std 实际是形如 (D,),但是计算是会broadcast成为(N, D)。仅从式子看的话,很容易误写为:
Dinv_x_std = Dx_hat * (x_mean_0)
这时如果进行一下维度分析,会发现 Dinv_x_std 显然要形如 (D,),但是右侧点积的结果形如 (N, D),显然要对 axis=0 进行 sum。同理还有这一行:
# x_mean_0 = x - x_mean
# (N,D)     (N,D) (D,)
Dx = Dx_mean_0 * (1)
Dx_mean = np.sum(Dx_mean_0 * (-1), axis=0)
  1. y=ixi y = ∑ i x i 的求导,这里
    y=x= [y1,y2,...,yD]x11x21xN1x12x22...xN2............x1Dx2DxND y =   [ y 1 , y 2 , . . . , y D ] x = [ x 11 x 12 . . . x 1 D x 21 x 22 . . . x 2 D . . . . . . x N 1 x N 2 . . . x N D ]

    其中
    y1y2=1N(x11+x21+...+xN1)=1N(x12+x22+...+xN2)... y 1 = 1 N ( x 11 + x 21 + . . . + x N 1 ) y 2 = 1 N ( x 12 + x 22 + . . . + x N 2 ) . . .

    所以
    dx11=Lyyx11=iLyiyix11=Ly1y1x11=dy11N d x 11 = ∂ L ∂ y ⋅ ∂ y ∂ x 11 = ∑ i ∂ L ∂ y i ⋅ ∂ y i ∂ x 11 = ∂ L ∂ y 1 ⋅ ∂ y 1 ∂ x 11 = d y 1 ⋅ 1 N

    综上:
    dx=1Ndy1dy1dy1dy2dy2...dy2............dyDdyDdyD=1Ndy11111...1............111N×D d x = 1 N ⋅ [ d y 1 d y 2 . . . d y D d y 1 d y 2 . . . d y D . . . . . . d y 1 d y 2 . . . d y D ] = 1 N ⋅ d y ⋅ [ 1 1 . . . 1 1 1 . . . 1 . . . . . . 1 1 . . . 1 ] N × D
# x_mean = 1 / N * np.sum(x, axis=0)
# (D,)                   (N,D)
Dx += Dx_mean * (1 / N * np.ones_like(x_hat))
  1. 注意到backprop时 Dx_mean_0 两次出现在等式左边,这说明在计算图中有两条路径通向 Dx_mean_0,这两条路径的结果要相加,所以第二次出现时要用 +=:
Dx_mean_0 = Dx_hat * (inv_x_std)
Dx_mean_0 += Dx_mean_0_sqr * (2 * x_mean_0)

dx d x :第二种方法

第二种方法的公式推导实在是太繁了,我再也不想写第二遍了。先来个计算图:

xx^yL x → x ^ → y → L

中间参数分别为:
doutyx^μσ2=Ly=γx^+β=xμσ2+ϵ=1Nn=1Nxn=1Nn=1N(xnμ)2 d o u t = ∂ L ∂ y y = γ ⋅ x ^ + β x ^ = x − μ σ 2 + ϵ μ = 1 N ∑ n = 1 N x n σ 2 = 1 N ∑ n = 1 N ( x n − μ ) 2

计算对 xij x i j 的导数:
Lxij=n,dLyndyndxij=n,dLyndyndxnd^xnd^xij ∂ L ∂ x i j = ∑ n , d ∂ L ∂ y n d ⋅ ∂ y n d ∂ x i j = ∑ n , d ∂ L ∂ y n d ⋅ ∂ y n d ∂ x n d ^ ⋅ ∂ x n d ^ ∂ x i j

其中:
yndxnd^μdσ2dyndxnd^=γdxnd^+βd=xndμdσ2d+ϵ=1Nn=1Nxnd=1Nn=1N(xndμd)2=γd y n d = γ d ⋅ x n d ^ + β d x n d ^ = x n d − μ d σ d 2 + ϵ μ d = 1 N ∑ n = 1 N x n d σ d 2 = 1 N ∑ n = 1 N ( x n d − μ d ) 2 ∂ y n d ∂ x n d ^ = γ d

下面的工作就是要计算 xnd^xij ∂ x n d ^ ∂ x i j :
xnd^xij=xijxndμdσ2d+ϵ=(σ2d+ϵ)12xij(xndμd)+(xndμd)xij(σ2d+ϵ)12=(σ2d+ϵ)12xij(xndμd)12(σ2d+ϵ)32(xndμd)σ2dxij ∂ x n d ^ ∂ x i j = ∂ ∂ x i j ( x n d − μ d σ d 2 + ϵ ) = ( σ d 2 + ϵ ) − 1 2 ⋅ ∂ ∂ x i j ( x n d − μ d ) + ( x n d − μ d ) ⋅ ∂ ∂ x i j ( σ d 2 + ϵ ) − 1 2 = ( σ d 2 + ϵ ) − 1 2 ⋅ ∂ ∂ x i j ( x n d − μ d ) − 1 2 ( σ d 2 + ϵ ) − 3 2 ( x n d − μ d ) ⋅ ∂ σ d 2 ∂ x i j

下面分别计算,首先:
xij(xndμd)=xij(xnd1Nt=1Nxtd)=xndxij1Nxij(t=1Nxtd) ∂ ∂ x i j ( x n d − μ d ) = ∂ ∂ x i j ( x n d − 1 N ∑ t = 1 N x t d ) = ∂ x n d ∂ x i j − 1 N ∂ ∂ x i j ( ∑ t = 1 N x t d )

第一项,当且仅当 n=i n = i , d=j d = j 时不为0,第二项中仅有 d=j d = j 项不为0,故:
xij(xndμd)=δn,iδd,j1Nδd,j ∂ ∂ x i j ( x n d − μ d ) = δ n , i ⋅ δ d , j − 1 N δ d , j

接着计算:
σ2dxij=xij(1Nn=1N(xndμd)2)=1Nn=1Nxij((xndμd)2)=2Nn=1N(xndμd)xij(xndμd)=2Nn=1N(xndμd)(δn,iδd,j1Nδd,j)=2Nn=1N(xndμd)δn,iδd,j2N2n=1N(xndμd)δd,j ∂ σ d 2 ∂ x i j = ∂ ∂ x i j ( 1 N ∑ n = 1 N ( x n d − μ d ) 2 ) = 1 N ∑ n = 1 N ∂ ∂ x i j ( ( x n d − μ d ) 2 ) = 2 N ∑ n = 1 N ( x n d − μ d ) ∂ ∂ x i j ( x n d − μ d ) = 2 N ∑ n = 1 N ( x n d − μ d ) ⋅ ( δ n , i ⋅ δ d , j − 1 N δ d , j ) = 2 N ∑ n = 1 N ( x n d − μ d ) ⋅ δ n , i ⋅ δ d , j − 2 N 2 ∑ n = 1 N ( x n d − μ d ) ⋅ δ d , j

第一项中,仅有 n=i n = i 一项不为0:
n=1N(xndμd)δn,iδd,j=(xidμd)δd,j ∑ n = 1 N ( x n d − μ d ) ⋅ δ n , i ⋅ δ d , j = ( x i d − μ d ) ⋅ δ d , j

第二项:
n=1N(xndμd)=n=1NxndNμd ∑ n = 1 N ( x n d − μ d ) = ∑ n = 1 N x n d − N ⋅ μ d

μd=1NNn=1xnd μ d = 1 N ∑ n = 1 N x n d ,因此上式为0。
所以:
σ2dxij=2N(xidμd)δd,j ∂ σ d 2 ∂ x i j = 2 N ( x i d − μ d ) ⋅ δ d , j

综上:
xnd^xij=(σ2d+ϵ)12xij(xndμd)12(σ2d+ϵ)32(xndμd)σ2dxij=(σ2d+ϵ)12(δn,iδd,j1Nδd,j)1N(σ2d+ϵ)32(xndμd)(xidμd)δd,j ∂ x n d ^ ∂ x i j = ( σ d 2 + ϵ ) − 1 2 ⋅ ∂ ∂ x i j ( x n d − μ d ) − 1 2 ( σ d 2 + ϵ ) − 3 2 ( x n d − μ d ) ⋅ ∂ σ d 2 ∂ x i j = ( σ d 2 + ϵ ) − 1 2 ⋅ ( δ n , i ⋅ δ d , j − 1 N δ d , j ) − 1 N ( σ d 2 + ϵ ) − 3 2 ( x n d − μ d ) ( x i d − μ d ) ⋅ δ d , j

最后,计算对 xij x i j 的导数
Lxij=n,dLyndγdxnd^xij=n,dγdLynd(σ2d+ϵ)12(δn,iδd,j1Nδd,j)1Nn,dγdLynd(σ2d+ϵ)32(xndμd)(xidμd)δd,j ∂ L ∂ x i j = ∑ n , d ∂ L ∂ y n d ⋅ γ d ⋅ ∂ x n d ^ ∂ x i j = ∑ n , d γ d ⋅ ∂ L ∂ y n d ⋅ ( σ d 2 + ϵ ) − 1 2 ⋅ ( δ n , i ⋅ δ d , j − 1 N δ d , j ) − 1 N ∑ n , d γ d ⋅ ∂ L ∂ y n d ⋅ ( σ d 2 + ϵ ) − 3 2 ( x n d − μ d ) ( x i d − μ d ) ⋅ δ d , j

第一项
n,dγdLynd(σ2d+ϵ)12(δn,iδd,j1Nδd,j)=nγjLynj(σ2j+ϵ)12(δn,i1N)=γjLyij(σ2j+ϵ)121NnγjLynj(σ2j+ϵ)12=γj(σ2j+ϵ)12(Lyij1NnLynj) ∑ n , d γ d ⋅ ∂ L ∂ y n d ⋅ ( σ d 2 + ϵ ) − 1 2 ⋅ ( δ n , i ⋅ δ d , j − 1 N δ d , j ) = ∑ n γ j ⋅ ∂ L ∂ y n j ⋅ ( σ j 2 + ϵ ) − 1 2 ⋅ ( δ n , i − 1 N ) = γ j ⋅ ∂ L ∂ y i j ⋅ ( σ j 2 + ϵ ) − 1 2 − 1 N ∑ n γ j ⋅ ∂ L ∂ y n j ⋅ ( σ j 2 + ϵ ) − 1 2 = γ j ⋅ ( σ j 2 + ϵ ) − 1 2 ( ∂ L ∂ y i j − 1 N ∑ n ∂ L ∂ y n j )

第二项
1Nn,dγdLynd(σ2d+ϵ)32(xndμd)(xidμd)δd,j=1NnγjLynj(σ2j+ϵ)32(xnjμj)(xijμj)=1Nγj(σ2j+ϵ)32(xijμj)nLynj(xnjμj) − 1 N ∑ n , d γ d ⋅ ∂ L ∂ y n d ⋅ ( σ d 2 + ϵ ) − 3 2 ( x n d − μ d ) ( x i d − μ d ) ⋅ δ d , j = − 1 N ∑ n γ j ⋅ ∂ L ∂ y n j ⋅ ( σ j 2 + ϵ ) − 3 2 ( x n j − μ j ) ( x i j − μ j ) = − 1 N γ j ( σ j 2 + ϵ ) − 3 2 ( x i j − μ j ) ∑ n ∂ L ∂ y n j ( x n j − μ j )

最后的结果为
Lx=γN(σ2+ϵ)12(NLynLyn(σ2+ϵ)1(xμ)nLyn(xnμ)) ∂ L ∂ x = γ N ( σ 2 + ϵ ) − 1 2 ( N ∂ L ∂ y − ∑ n ∂ L ∂ y n − ( σ 2 + ϵ ) − 1 ( x − μ ) ∑ n ∂ L ∂ y n ( x n − μ ) )

代码如下

first_part = gamma * inv_x_std / N
second_part = N * dout
third_part = np.sum(dout, axis=0)
forth_part = inv_x_std ** 2 * x_mean_0 * np.sum(dout * x_mean_0, axis=0)

dx = first_part * (second_part - third_part - forth_part)

Inline Question 1:

Describe the results of this experiment. How does the scale of weight initialization affect models with/without batch normalization differently, and why?

BN层的加入,大大降低了训练过程对weight初始化的依赖。

Inline Question 2:

Describe the results of this experiment. What does this imply about the relationship between batch normalization and batch size? Why is this relationship observed?

BN层的加入使得训练收敛的更快,acc更高,但是对test影响不是很大。
另外,如果batch size太小,反而不如没有BN。

Inline Question 3:

Which of these data preprocessing steps is analogous to batch normalization, and which is analogous to layer normalization?
1. Scaling each image in the dataset, so that the RGB channels for each row of pixels within an image sums up to 1.
2. Scaling each image in the dataset, so that the RGB channels for all pixels within an image sums up to 1.
3. Subtracting the mean image of the dataset from each image in the dataset.
4. Setting all RGB values to either 0 or 1 depending on a given threshold.

1、2类似于layer norm,3类似于batch norm。

Layer normalization

Layer norm 和 batch norm 很像,都是用在FC层,只不过 batch norm 在 X 的 sample 方向取均值和方差,即将形如 (N, D) 的 X 取为形如 (1, D) 的均值和方差;而 layer norm 是在 X 的feature方向取均值和方差,即将形如 (N, D) 的 X 取为形如 (N, 1) 的均值和方差。因此,方便记忆的话,可以将 batch norm 记为 N norm 或者 axis=0 norm,将 layer norm 记为 D norm 或者 axis=1 norm。

另外,layer norm 在 train 和 test 时计算方法均相同,而不用像 batch norm 那样需要记录一个 running mean 和 running var。

这里还要特别注意的一点是,两者的 gamma 和 beta 都是形如 (1, D) 的。

Layer norm 的实现同 batch norm 相似,只需要将输入转置,就可调用 batch norm 来实现。

Inline Question 4:

When is layer normalization likely to not work well, and why?
1. Using it in a very deep network
2. Having a very small dimension of features
3. Having a high regularization term

从结果看,layer norm 的效果不是很好,特别是当 batch size 很小时。
但是对深层NN来说,layer norm 可以加快训练速度。
注意,reg只是施加于 weights 上的,并不施加于 norm 的参数 gamma 和 beta。如果 reg 很大的话,那么 affine 层的 weights 会被拉向0,输出值的大小也会减小,因此会减小 norm 层的作用。

Spatial batch Normalization

这里所谓的 spatial batch Normalization ,实际上就是 BN 的CNN banben版本。只不过 BN 是将形如 (N, D) 的 X 取为形如 (1, D) 的均值和方差;而 SBN 是将形如 (N, C, H, W) 的 X 取形如 (1, C, 1, 1) 的均值和方差,需要训练的 gamma 和 beta 也是形如 (1, C, 1, 1) 的。

作业里的 SBN,只需要将输入形如 (N, C, H, W) 的 X,首先转置为 (N, H, W, C),然后 reshape 为 (N * H * W, C),调用 BN 的 forward 和 backward 进行计算,最后将结果再转置会原来的形状即可。

Instance Normalization

Instance normalization 是将形如 (N, C, H, W) 的 X 取形如 (N, C, 1, 1) 的均值和方差,需要训练的 gamma 和 beta 也是形如 (1, C, 1, 1) 的。

Group Normalization

Group Normalization 是 layer normalization 在 CNN 中的改进版本。
Layer normalization 据说在 CNN 中的效果不是很好,所以要加以改进。在 CNN 中,layer normalization 是一次对所有的 C 取均值和方差,而 group normalization 是将 C 分为几组,每次仅在组内取均值和方差。
需要训练的 gamma 和 beta 也是形如 (1, C, 1, 1) 的。

作业中的也没什么好说的,只要把输入适当的 split 就可以了。

end

你可能感兴趣的:(cs231n,Stanford,cs231n'18,课程及作业详细解读)