Pre-Norm和Post-Norm之间的对比是一个“老生常谈“的问题,目前也没有一个比较好的结论解释清楚,当前比较明确的结论是:同一设置下,Pre-Norm结构往往更加容易训练,但最终效果不如Post-Norm。Pre Norm更容易训练好理解,因为它的恒等路径更突出,但为什么它效果反而没那么好呢?
Pre Norm和Post Norm的式子分别如下:
Pre Norm: x t + 1 = x t + F t ( Norm ( x t ) ) Post Norm: x t + 1 = Norm ( x t + F t ( x t ) ) \text{Pre Norm: } \quad \boldsymbol{x}_{t+1} = \boldsymbol{x}_t + F_t(\text{Norm}(\boldsymbol{x}_t)) \\\text{Post Norm: }\quad \boldsymbol{x}_{t+1} = \text{Norm}(\boldsymbol{x}_t + F_t(\boldsymbol{x}_t)) Pre Norm: xt+1=xt+Ft(Norm(xt))Post Norm: xt+1=Norm(xt+Ft(xt))
在Transformer中,这里的NormNorm主要指Layer Normalization,但在一般的模型中,它也可以是Batch Normalization、Instance Normalization等,相关结论本质上是通用的。
由上述公式所示:Pre Norm 和 Post Norm 的区别 Layer Norm 和 Residual connections 组合方式的不同。
在原始的 Transformers 论文中,使用的是 Post Norm,如下所示:
即每一层的输入先与 Attention 和 FFN 相加,然后才计算 Layer Norm。早期的很多模型都用的是 Post Norm(BERT):
x t + 1 , = Norm ( x l + A t t n ( x l ) ) x t + 1 = Norm ( x t + 1 , + FFN ( x t + 1 , ) ) x_{t+1}^, = \text{Norm}(x_l+Attn(x_l)) \\ x_{t+1} = \text{Norm}(x_{t+1}^,+\text{FFN}(x_{t+1}^,)) xt+1,=Norm(xl+Attn(xl))xt+1=Norm(xt+1,+FFN(xt+1,))
Post Norm 之所以这么设计,是把 Normalization 放在一个模块的最后,这样下一个模块接收到的总是归一化后的结果。这比较符合 Normalization 的初衷,就是为了降低梯度的方差。但是层层堆叠起来,从上图可以看出,深度学习的基建 ResNet 的结构其实被破坏了。
对于残差连接,如果 x x x 的方差(二阶矩同理)为 σ 1 2 \sigma_1^2 σ12 而 F ( x ) F(x) F(x) 的方差为 σ 2 2 \sigma_2^2 σ22,并且假设两者相互独立,那么 x + F ( x ) x+F(x) x+F(x) 方差为 σ 1 2 + σ 2 2 \sigma_1^2+\sigma_2^2 σ12+σ22。也就是说,残差会进一步放大方差,所以在残差后也使用Normalization可以稳定前向传播中的方差。但事实上已经严重削弱了残差的恒等分支,所以反而失去了残差“易于训练”的优点,通常要warmup并设置足够小的学习率才能使它收敛。
怎么理解这一点呢?假设初始状态下 x , F ( x ) x,F(x) x,F(x)的方差均为1,那么 x + F ( x ) x+F(x) x+F(x)的方差就是2,而Normalization操作负责将方差重新降为1,这就说明初始阶段Post Norm相当于:
x t + 1 = x t + F t ( x t ) 2 \begin{equation}x_{t+1} = \frac{x_t + F_t(x_t)}{\sqrt{2}}\end{equation} xt+1=2xt+Ft(xt)
递归下去,我们得到:
x l = x l − 1 2 + F l − 1 ( x l − 1 ) 2 = x l − 2 2 + F l − 2 ( x l − 2 ) 2 + F l − 1 ( x l − 1 ) 2 = ⋯ = x 0 2 l / 2 + F 0 ( x 0 ) 2 l / 2 + F 1 ( x 1 ) 2 ( l − 1 ) / 2 + F 2 ( x 2 ) 2 ( l − 2 ) / 2 + ⋯ + F l − 1 ( x l − 1 ) 2 1 / 2 \begin{equation}\begin{aligned}x_l =&\, \frac{x_{l-1}}{\sqrt{2}} + \frac{F_{l-1}(x_{l-1})}{\sqrt{2}} \\=&\, \frac{x_{l-2}}{2} + \frac{F_{l-2}(x_{l-2})}{2} + \frac{F_{l-1}(x_{l-1})}{\sqrt{2}} \\ =&\, \cdots \\ =&\,\frac{x_0}{2^{l/2}} + \frac{F_0(x_0)}{2^{l/2}} + \frac{F_1(x_1)}{2^{(l-1)/2}} + \frac{F_2(x_2)}{2^{(l-2)/2}} + \cdots + \frac{F_{l-1}(x_{l-1})}{2^{1/2}} \end{aligned}\end{equation} xl====2xl−1+2Fl−1(xl−1)2xl−2+2Fl−2(xl−2)+2Fl−1(xl−1)⋯2l/2x0+2l/2F0(x0)+2(l−1)/2F1(x1)+2(l−2)/2F2(x2)+⋯+21/2Fl−1(xl−1)
看到问题了没?本来残差的意思是给前面的层搞一条“绿色通道”,让梯度可以更直接地回传,但是在Post Norm中,这条“绿色通道”被严重削弱了,越靠近前面的通道反而权重越小,大模型中的层数一版很深,当l = 32时,前面的系数趋于0,残差“名存实亡”,因此还是不容易训练。相关的分析还可以参考论文《On Layer Normalization in the Transformer Architecture》。
一个针对性的改进称为Pre Norm,它的思想是“要用的时候才去标准化”,其形式为
x t + 1 = x t + F t ( Norm ( x t ) ) \begin{equation}x_{t+1} = x_t + F_t(\text{Norm}(x_t))\end{equation} xt+1=xt+Ft(Norm(xt))
类似地,迭代展开之后我们可以认为初始阶段有
x l = x 0 + F 0 ( x 0 ) + F 1 ( x 1 / 2 ) + F 2 ( x 2 / 3 ) + ⋯ + F l − 1 ( x l − 1 / l ) \begin{equation} x_l = x_0 + F_0(x_0) + F_1(x_1/\sqrt{2}) + F_2(x_2/\sqrt{3}) + \cdots + F_{l-1}(x_{l-1}/\sqrt{l})\end{equation} xl=x0+F0(x0)+F1(x1/2)+F2(x2/3)+⋯+Fl−1(xl−1/l)
这样一来,起码每一条残差通道都是平权的,残差的作用会比Post Norm更加明显,所以它也更好优化。当然,这样最后 x l x_l xl的方差将会很大,所以在接预测层之前xl也还要加个Normalization。
知乎上 @唐翔昊 给出的答案是:**Pre Norm的深度有“水分”!**也就是说,一个 L L L层的Pre Norm模型,其实际等效层数不如 L L L层的Post Norm模型,而层数少了导致效果变差了。
具体怎么理解呢?很简单,对于Pre Norm模型我们迭代得到:
x t + 1 = x t + F t ( Norm ( x t ) ) = x t − 1 + F t − 1 ( Norm ( x t − 1 ) ) + F t ( Norm ( x t ) ) = ⋯ = x 0 + F 0 ( Norm ( x 0 ) ) + ⋯ + F t − 1 ( Norm ( x t − 1 ) ) + F t ( Norm ( x t ) ) \begin{equation}\begin{aligned} \boldsymbol{x}_{t+1} =&\,\boldsymbol{x}_t + F_t(\text{Norm}(\boldsymbol{x}_t)) \\ =&\, \boldsymbol{x}_{t-1} + F_{t-1}(\text{Norm}(\boldsymbol{x}_{t-1})) + F_t(\text{Norm}(\boldsymbol{x}_t)) \\ =&\, \cdots \\ =&\, \boldsymbol{x}_0 + F_0 (\text{Norm}(\boldsymbol{x}_0)) + \cdots + F_{t-1}(\text{Norm}(\boldsymbol{x}_{t-1})) + F_t(\text{Norm}(\boldsymbol{x}_t)) \end{aligned}\end{equation} xt+1====xt+Ft(Norm(xt))xt−1+Ft−1(Norm(xt−1))+Ft(Norm(xt))⋯x0+F0(Norm(x0))+⋯+Ft−1(Norm(xt−1))+Ft(Norm(xt))
其中每一项都是同一量级的,那么有 x t + 1 = O ( t + 1 ) \boldsymbol{x}_{t+1}=\mathcal{O}(t+1) xt+1=O(t+1),也就是说第 t + 1 t+1 t+1层跟第 t t t层的差别就相当于 t + 1 t+1 t+1与 t t t的差别,当 t t t较大时,两者的相对差别是很小的,因此:
F t ( Norm ( x t ) ) + F t + 1 ( Norm ( x t + 1 ) ) ≈ F t ( Norm ( x t ) ) + F t + 1 ( Norm ( x t ) ) = ( 1 1 ) ( F t F t + 1 ) ( Norm ( x t ) ) \begin{equation}\begin{aligned} &\,F_t(\text{Norm}(\boldsymbol{x}_t)) + F_{t+1}(\text{Norm}(\boldsymbol{x}_{t+1})) \\ \approx&\,F_t(\text{Norm}(\boldsymbol{x}_t)) + F_{t+1}(\text{Norm}(\boldsymbol{x}_t)) \\ =&\, \begin{pmatrix} 1 & 1\end{pmatrix}\begin{pmatrix} F_t \\ F_{t+1}\end{pmatrix}(\text{Norm}(\boldsymbol{x}_t)) \end{aligned}\end{equation} ≈=Ft(Norm(xt))+Ft+1(Norm(xt+1))Ft(Norm(xt))+Ft+1(Norm(xt))(11)(FtFt+1)(Norm(xt))
这个意思是说,当 t t t比较大时, x t , x t + 1 \boldsymbol{x}_t,\boldsymbol{x}_{t+1} xt,xt+1相差较小,所以 F t + 1 ( Norm ( x t + 1 ) ) F_{t+1}(\text{Norm}(\boldsymbol{x}_{t+1})) Ft+1(Norm(xt+1))与 F t + 1 ( Norm ( x t ) ) F_{t+1}(\text{Norm}(\boldsymbol{x}_t)) Ft+1(Norm(xt))很接近,因此原本一个 t t t层的模型与 t + 1 t+1 t+1层和,近似等效于一个更宽的 t t t层模型,所以在Pre Norm中多层叠加的结果更多是增加宽度而不是深度,层数越多,这个层就越“虚”。
输入经过Norm之后,基本上能保持同一量级,然后Attention、FFN这些运算,一般不会大幅改动输入数值的量级(否则容易梯度消失或者爆炸),因此输出的范围也大致相同。
说白了,Pre Norm结构无形地增加了模型的宽度而降低了模型的深度,而我们知道深度通常比宽度更重要,所以是无形之中的降低深度导致最终效果变差了。而Post Norm刚刚相反,在上一节中中我们就分析过,每Norm一次就削弱一次恒等分支的权重,所以Post Norm反而是更突出残差分支的,因此Post Norm中的层数更加“足秤”,一旦训练好之后效果更优。