上文吐槽BN部分讲的太烂,2018年果然更新了这一部分,slides里加了好多内容,详见Lecture 6的slides第54到61页,以及Lecture 7的slides第11到18页,这里结合着原始论文和作业,把BN及其几个变种好好总结一下。
前面的作业中已经见识到了,weight初始化时方差的调校真的是很麻烦,小了梯度消失不学习,大了梯度爆炸没法学习。
即使开始初始化的很好,随着学习的深入,网络的加深,每一层的方差已经不再受控;另外,特别是对于刚开始的几层,方差上稍微的变化,都会在forward prop时逐级放大的传递下去。
作业中只是三五层的小网络,要是几十上百层的网络,可以想象学习几乎是不可能的。
既然每一层输入的方差会产生如此多的问题,这就产生了第一个想法,何不将每一层的输入直接标准化为0均值单位方差。由于NN的train多是基于mini-batch的,所以这里标准化也是基于mini-batch。
输入x是包含N个sample的mini-batch,每个sample有D个feature。对每个feature进行标准化,即:
但是但是但是,这里武断的使输入均值为0,方差为1真的是最好的选择么?不一定。如果不是最好的选择,
设为多少是最好的选择呢?不知道。不知道的话怎么办呢?
那就让NN自己去学习一个最好的去呗。所以才有了下一步:
其中, γ γ 和 β β 是要学习的参数,将输入的均值和方差从(0,1)又拉到了 (γ,β) ( γ , β ) 。
所以,通常说起来BN是一层,但是我认为,BN是两层:Normalization Layer和Shift Layer,这两层是紧密相连,不可分割的。其中,Normalization Layer将输入的均值和方差标准化为(0,1),Shift Layer又将其拉到 (γ,β) ( γ , β ) 。这里, (γ,β) ( γ , β ) 和其他的weight、bias一样,都是通过backprop算梯度,然后再用SGD等方法更新学习得到。
好,这里强调两个问题,也是我第一遍看paper时的疑惑,也是2017年视频中那位小姑娘讲课时犯的错误:
To address this, we make sure that the transformation inserted in the network can represent the identity transform.
input:
x: (N, D)
intermediates:
mean: (1, D)
mean = np.mean(x, axis=0)
var: (1, D)
var = np.var(x, axis=0)
xhat: (N, D)
xhat = (x - mean) / (np.sqrt(var + eps))
learnable params:
gamma: (1, D)
beta: (1, D)
输出:
y = gamma * xhat + beta
dbeta = np.sum(dout, axis=0)
这里就不赘述了。
dgamma = np.sum(xhat * dout, axis=0)
这里还是把过程写一下吧
先画出forward和backward的计算图,如图所示。forward的代码如下:
x_mean = 1 / N * np.sum(x, axis=0)
x_mean_0 = x - x_mean
x_mean_0_sqr = x_mean_0 ** 2
x_var = 1 / N * np.sum(x_mean_0_sqr, axis=0)
x_std = np.sqrt(x_var + eps)
inv_x_std = 1 / x_std
x_hat = x_mean_0 * inv_x_std
out = gamma * x_hat + beta
cache = (x_mean, x_mean_0, x_mean_0_sqr, x_var, x_std, inv_x_std, x_hat, gamma, eps)
这里需要注意的是 1. 尽量将每一步化成最简单的加、乘操作,并且将每一步等号左边的项全部cache起来。这样做的目的是减少backprop时的计算量,但是相应的存贮量就会增加。所以说NN的内存需求要远远大于weights和bias的数目。 2. 计算mean是,用 1/N * np.sum(),不要用np.mean(),否则在backprop的时候会把 1/N 漏掉。 如果forward的每一步计算分解的足够细的话,backprop可以很清楚:
# out = gamma * x_hat + beta
# (N,D) (D,) (N,D) (D,)
Dx_hat = dout * gamma
# x_hat = x_mean_0 * inv_x_std
# (N,D) (N,D) (D,)
Dx_mean_0 = Dx_hat * (inv_x_std)
Dinv_x_std = np.sum(Dx_hat * (x_mean_0), axis=0)
# inv_x_std = 1 / x_std
# (D,) (D,)
Dx_std = Dinv_x_std * (- x_std ** (-2))
# x_std = np.sqrt(x_var + eps)
# (D,) (D,)
Dx_var = Dx_std * (0.5 * (x_var + eps) ** (-0.5))
# x_var = 1 / N * np.sum(x_mean_0_sqr, axis=0)
# (D,) (N,D)
Dx_mean_0_sqr = Dx_var * (1 / N * np.ones_like(x_mean_0_sqr))
# x_mean_0_sqr = x_mean_0 ** 2
# (N,D) (N,D)
Dx_mean_0 += Dx_mean_0_sqr * (2 * x_mean_0)
# x_mean_0 = x - x_mean
# (N,D) (N,D) (D,)
Dx = Dx_mean_0 * (1)
Dx_mean = - np.sum(Dx_mean_0, axis=0)
# x_mean = 1 / N * np.sum(x, axis=0)
# (D,) (N,D)
Dx += Dx_mean * (1 / N * np.ones_like(x_hat))
dx = Dx
这里要注意的是: 1. 一定要把每一步计算中每一项的维度搞清楚写下来。注意这一步:
# x_hat = x_mean_0 * inv_x_std
# (N,D) (N,D) (D,)
Dx_mean_0 = Dx_hat * (inv_x_std)
Dinv_x_std = np.sum(Dx_hat * (x_mean_0), axis=0)
因为numpy在进行矩阵运算的时候会进行自动的broadcast,所以这里 inv_x_std 实际是形如 (D,),但是计算是会broadcast成为(N, D)。仅从式子看的话,很容易误写为:
Dinv_x_std = Dx_hat * (x_mean_0)
这时如果进行一下维度分析,会发现 Dinv_x_std 显然要形如 (D,),但是右侧点积的结果形如 (N, D),显然要对 axis=0 进行 sum。同理还有这一行:
# x_mean_0 = x - x_mean
# (N,D) (N,D) (D,)
Dx = Dx_mean_0 * (1)
Dx_mean = np.sum(Dx_mean_0 * (-1), axis=0)
# x_mean = 1 / N * np.sum(x, axis=0)
# (D,) (N,D)
Dx += Dx_mean * (1 / N * np.ones_like(x_hat))
Dx_mean_0 = Dx_hat * (inv_x_std)
Dx_mean_0 += Dx_mean_0_sqr * (2 * x_mean_0)
第二种方法的公式推导实在是太繁了,我再也不想写第二遍了。先来个计算图:
first_part = gamma * inv_x_std / N
second_part = N * dout
third_part = np.sum(dout, axis=0)
forth_part = inv_x_std ** 2 * x_mean_0 * np.sum(dout * x_mean_0, axis=0)
dx = first_part * (second_part - third_part - forth_part)
Inline Question 1:
Describe the results of this experiment. How does the scale of weight initialization affect models with/without batch normalization differently, and why?
BN层的加入,大大降低了训练过程对weight初始化的依赖。
Inline Question 2:
Describe the results of this experiment. What does this imply about the relationship between batch normalization and batch size? Why is this relationship observed?
BN层的加入使得训练收敛的更快,acc更高,但是对test影响不是很大。
另外,如果batch size太小,反而不如没有BN。
Inline Question 3:
Which of these data preprocessing steps is analogous to batch normalization, and which is analogous to layer normalization?
1. Scaling each image in the dataset, so that the RGB channels for each row of pixels within an image sums up to 1.
2. Scaling each image in the dataset, so that the RGB channels for all pixels within an image sums up to 1.
3. Subtracting the mean image of the dataset from each image in the dataset.
4. Setting all RGB values to either 0 or 1 depending on a given threshold.
1、2类似于layer norm,3类似于batch norm。
Layer norm 和 batch norm 很像,都是用在FC层,只不过 batch norm 在 X 的 sample 方向取均值和方差,即将形如 (N, D) 的 X 取为形如 (1, D) 的均值和方差;而 layer norm 是在 X 的feature方向取均值和方差,即将形如 (N, D) 的 X 取为形如 (N, 1) 的均值和方差。因此,方便记忆的话,可以将 batch norm 记为 N norm 或者 axis=0 norm,将 layer norm 记为 D norm 或者 axis=1 norm。
另外,layer norm 在 train 和 test 时计算方法均相同,而不用像 batch norm 那样需要记录一个 running mean 和 running var。
这里还要特别注意的一点是,两者的 gamma 和 beta 都是形如 (1, D) 的。
Layer norm 的实现同 batch norm 相似,只需要将输入转置,就可调用 batch norm 来实现。
Inline Question 4:
When is layer normalization likely to not work well, and why?
1. Using it in a very deep network
2. Having a very small dimension of features
3. Having a high regularization term
从结果看,layer norm 的效果不是很好,特别是当 batch size 很小时。
但是对深层NN来说,layer norm 可以加快训练速度。
注意,reg只是施加于 weights 上的,并不施加于 norm 的参数 gamma 和 beta。如果 reg 很大的话,那么 affine 层的 weights 会被拉向0,输出值的大小也会减小,因此会减小 norm 层的作用。
这里所谓的 spatial batch Normalization ,实际上就是 BN 的CNN banben版本。只不过 BN 是将形如 (N, D) 的 X 取为形如 (1, D) 的均值和方差;而 SBN 是将形如 (N, C, H, W) 的 X 取形如 (1, C, 1, 1) 的均值和方差,需要训练的 gamma 和 beta 也是形如 (1, C, 1, 1) 的。
作业里的 SBN,只需要将输入形如 (N, C, H, W) 的 X,首先转置为 (N, H, W, C),然后 reshape 为 (N * H * W, C),调用 BN 的 forward 和 backward 进行计算,最后将结果再转置会原来的形状即可。
Instance normalization 是将形如 (N, C, H, W) 的 X 取形如 (N, C, 1, 1) 的均值和方差,需要训练的 gamma 和 beta 也是形如 (1, C, 1, 1) 的。
Group Normalization 是 layer normalization 在 CNN 中的改进版本。
Layer normalization 据说在 CNN 中的效果不是很好,所以要加以改进。在 CNN 中,layer normalization 是一次对所有的 C 取均值和方差,而 group normalization 是将 C 分为几组,每次仅在组内取均值和方差。
需要训练的 gamma 和 beta 也是形如 (1, C, 1, 1) 的。
作业中的也没什么好说的,只要把输入适当的 split 就可以了。
end