前面介绍的三种语言模型(N元语法、log-linear语言模型和神经网络语言模型)尽管在表示能力上逐渐变强,但是仍然没有摆脱一项桎梏:需要给定一个窗口大小,而且通常这个窗口宽度不会太大。但是语言里通常会有长距离依赖现象,例如He doesn’t have very much confidence in himself,这里himself只是受到句首He的影响——如果句首是She,那么末尾这个词就要变成herself了!这样的现象在英语里比较多,在其它语言里甚至可以超级多:一方面,很多语言形态更加丰富,名词有变格,动词有变位。另一方面,一些语言有自己独特的语法结构,例如德语里丧心病狂的句框结构/框型结构:ich muss morgen zur Uni gehen,这里实意动词gehen与其搭配的情态动词muss遥相呼应,在这种情况下,固定小窗口的语言模型能力显得不足
除却语法的层面,语义上也需要对单词的长距离依赖建模。这里最典型的问题是选择依赖/选择约束问题,即从语义上讲,当上文出现某个单词时,它会潜在地约束之后哪些单词更可能出现,哪些单词更不可能出现。例如前文出现eat,后文很可能出现fork(作为eat的器具)或者friend(作为一起eat的人),但是出现wall感觉就不太合理。这种依赖约束也可能横跨多个单词,难以被简单模型所捕捉
最后,文档内,句与句之间(或者句内),也需要遵守主题一致性和风格一致性。一篇讨论体育的文章不适合突然拐到母猪的产后护理,一篇严谨的学术论文(一般情况下)也不应该突然出现乡间俚语
综上所述,需要一种能力更强,能够捕捉词语长距离依赖关系的模型。在这种背景下,循环神经网络(Recurrent Neural Network, RNN)应运而生
普通RNN的核心思想是引入时刻的概念,此时每一时刻 t t t的隐藏单元 h t \boldsymbol{h}_t ht不再只依赖于这一时刻的模型输入 x t \boldsymbol{x}_t xt,也依赖于上一时刻的隐藏单元 h t − 1 \boldsymbol{h}_{t-1} ht−1,即
h t = { tanh ( W x h x t + W h h h t − 1 + b t ) t ≥ 1 0 o t h e r w i s e \boldsymbol{h}_t = \begin{cases} \tanh(\boldsymbol{W}_{xh}\boldsymbol{x}_t + \boldsymbol{W}_{hh}\boldsymbol{h}_{t-1} + \boldsymbol{b}_t) & t\ge 1 \\ \boldsymbol{0} & {\rm otherwise} \end{cases} ht={tanh(Wxhxt+Whhht−1+bt)0t≥1otherwise
这样一来,若干个时刻之前的隐藏单元状态可以依次传输到当前的隐藏单元,使得模型可以对长距离依赖关系建模。引入输入层和softmax层后,模型的结构为
m t = M ⋅ , e t − 1 h t = { tanh ( W m h m t + W h h h t − 1 + b t ) t ≥ 1 0 o t h e r w i s e p t = s o f t m a x ( W h s h t + b s ) \begin{aligned} \boldsymbol{m}_t &= \boldsymbol{M}_{\cdot, e_{t-1}} \\ \boldsymbol{h}_t &= \begin{cases} \tanh(\boldsymbol{W}_{mh}\boldsymbol{m}_t + \boldsymbol{W}_{hh}\boldsymbol{h}_{t-1} + \boldsymbol{b}_t) & t\ge 1 \\ \boldsymbol{0} & {\rm otherwise} \end{cases} \\ \boldsymbol{p}_t &= {\rm softmax}(\boldsymbol{W}_{hs}\boldsymbol{h}_t + \boldsymbol{b}_s) \end{aligned} mthtpt=M⋅,et−1={tanh(Wmhmt+Whhht−1+bt)0t≥1otherwise=softmax(Whsht+bs)
此时模型不需要再显式依赖上文的多个单词 (即在每一时刻 t t t只需要接收一个单词作为输入),上文的全部信息可以看做都包含在了上游传递过来的上一时刻隐藏状态 h t − 1 \boldsymbol{h}_{t-1} ht−1里
对上式记号稍作变化,
x ( t ) = E ⋅ , e ( t − 1 ) s ( t ) = { tanh ( U x ( t ) + W s ( t − 1 ) + b s ) t ≥ 1 0 o t h e r w i s e o ( t ) = V s ( t ) + b o y ^ ( t ) = s o f t m a x ( o ( t ) ) \begin{aligned} \boldsymbol{x}^{(t)} &= \boldsymbol{E}_{\cdot, e^{(t-1)}} \\ \boldsymbol{s}^{(t)} &= \begin{cases} \tanh(\boldsymbol{U}\boldsymbol{x}^{(t)} + \boldsymbol{W}\boldsymbol{s}^{(t-1)} + \boldsymbol{b}_s) & t\ge 1 \\ \boldsymbol{0} & {\rm otherwise} \end{cases} \\ \boldsymbol{o}^{(t)} &= \boldsymbol{V}\boldsymbol{s}^{(t)} + \boldsymbol{b}_o \\ \hat{\boldsymbol{y}}^{(t)} &= {\rm softmax}(\boldsymbol{o}^{(t)}) \end{aligned} x(t)s(t)o(t)y^(t)=E⋅,e(t−1)={tanh(Ux(t)+Ws(t−1)+bs)0t≥1otherwise=Vs(t)+bo=softmax(o(t))
可以用如下示意图来表示上述RNN结构,其中右半部分可以看做是RNN网络的"展开形式"(图片来源:Nature)
可以将上面的模型描述简写为
x ( t ) = E ⋅ , e ( t − 1 ) s ( t ) = R N N ( x ( t ) , s ( t − 1 ) ) o ( t ) = V s ( t ) + b o y ^ ( t ) = s o f t m a x ( o ( t ) ) \begin{aligned} \boldsymbol{x}^{(t)} &= \boldsymbol{E}_{\cdot, e^{(t-1)}} \\ \boldsymbol{s}^{(t)} &= {\rm RNN}\left(\boldsymbol{x}^{(t)}, \boldsymbol{s}^{(t-1)}\right) \\ \boldsymbol{o}^{(t)} &= \boldsymbol{V}\boldsymbol{s}^{(t)} + \boldsymbol{b}_o \\ \hat{\boldsymbol{y}}^{(t)} &= {\rm softmax}\left(\boldsymbol{o}^{(t)}\right) \end{aligned} x(t)s(t)o(t)y^(t)=E⋅,e(t−1)=RNN(x(t),s(t−1))=Vs(t)+bo=softmax(o(t))
(本节参考了《循环神经网络(RNN)模型与前向反向传播算法》和Recurrent Neural Networks Tutorial, Part 3 – Backpropagation Through Time and Vanishing Gradients)
由上面的模型描述可以看出,模型在第 t t t时刻的输出 o ( t ) \boldsymbol{o}^{(t)} o(t)不再只依赖于该时刻的隐藏状态 s ( t ) \boldsymbol{s}^{(t)} s(t),而是还依赖于上游所有时刻的隐藏状态 s ( 1 ) , … , s ( t − 1 ) \boldsymbol{s}^{(1)}, \ldots, \boldsymbol{s}^{(t-1)} s(1),…,s(t−1)。反过来, t t t时刻的隐藏状态 s ( t ) \boldsymbol{s}^{(t)} s(t)不止直接决定该时刻的输出 o ( t ) \boldsymbol{o}^{(t)} o(t),也直接影响了下一时刻隐藏状态 s ( t + 1 ) \boldsymbol{s}^{(t+1)} s(t+1)(此外还间接影响了下游所有隐藏状态,这里先不论),因此误差对 s ( t ) \boldsymbol{s}^{(t)} s(t)的梯度有两个直接来源。又因为 s ( t ) \boldsymbol{s}^{(t)} s(t)由参数 U \boldsymbol{U} U、 W \boldsymbol{W} W和 b \boldsymbol{b} b算出,因此这三个参数的更新不仅受到第 t t t时刻输出 o ( t ) \boldsymbol{o}^{(t)} o(t)的影响,也受到第 t + 1 t+1 t+1时刻隐藏状态 s ( t + 1 ) \boldsymbol{s}^{(t+1)} s(t+1)的影响。即反向传播时梯度会沿"时间"向上游传播,因此RNN的反向传播也被称作为经由时间的反向传播 (Back-Propagation Through Time, BPTT)
下面进行推导:对上面的模型描述引入一个中间记号,并稍作记号上的替换,有
a ( t ) = U x ( t ) + W s ( t − 1 ) + b s ( t ) = tanh ( a ( t ) ) o ( t ) = V s ( t ) + c y ^ ( t ) = s o f t m a x ( o ( t ) ) (1) \begin{aligned} \boldsymbol{a}^{(t)} &= \boldsymbol{U}\boldsymbol{x}^{(t)} + \boldsymbol{W}\boldsymbol{s}^{(t-1)} + \boldsymbol{b} \tag{1} \\ \boldsymbol{s}^{(t)} &= \tanh\left(\boldsymbol{a}^{(t)}\right) \\ \boldsymbol{o}^{(t)} &= \boldsymbol{V}\boldsymbol{s}^{(t)} + \boldsymbol{c} \\ \hat{\boldsymbol{y}}^{(t)} &= {\rm softmax}\left(\boldsymbol{o}^{(t)}\right) \end{aligned} a(t)s(t)o(t)y^(t)=Ux(t)+Ws(t−1)+b=tanh(a(t))=Vs(t)+c=softmax(o(t))(1)
假设使用交叉熵作为模型的损失函数,由于每个时刻都有一个最终输出 y ^ ( t ) \hat{\boldsymbol{y}}^{(t)} y^(t),因此模型的总损失函数为
E ( y , y ^ ) = ∑ t E ( t ) ( y ( t ) , y ^ ( t ) ) E ( t ) ( y ( t ) , y ^ ( t ) ) = − ∑ i = 1 ∣ V ∣ y i ( t ) log y ^ i ( t ) \begin{aligned} E(\boldsymbol{y}, \hat{\boldsymbol{y}}) &= \sum_t E^{(t)}\left(\boldsymbol{y}^{(t)}, \hat{\boldsymbol{y}}^{(t)}\right) \\ E^{(t)}\left(\boldsymbol{y}^{(t)}, \hat{\boldsymbol{y}}^{(t)}\right) &= -\sum_{i=1}^{|V|}y^{(t)}_{i}\log\hat{y}^{(t)}_{i} \end{aligned} E(y,y^)E(t)(y(t),y^(t))=t∑E(t)(y(t),y^(t))=−i=1∑∣V∣yi(t)logy^i(t)
假设模型各参数、输入、输出和中间结果维度为
y ^ ( t ) , o ( t ) , c ∈ R ∣ V ∣ × 1 s ( t ) , a ( t ) , b ∈ R h × 1 x ∈ R d × 1 U ∈ R h × d W ∈ R h × h V ∈ R ∣ V ∣ × h \begin{aligned} \hat{\boldsymbol{y}}^{(t)}, \boldsymbol{o}^{(t)}, \boldsymbol{c} &\in \mathbb{R}^{|V| \times 1} \\ \boldsymbol{s}^{(t)}, \boldsymbol{a}^{(t)}, \boldsymbol{b} &\in \mathbb{R}^{h \times 1} \\ \boldsymbol{x} &\in \mathbb{R}^{d\times 1} \\ \boldsymbol{U} &\in \mathbb{R}^{h \times d} \\ \boldsymbol{W} &\in \mathbb{R}^{h \times h} \\ \boldsymbol{V} &\in \mathbb{R}^{|V| \times h} \\ \end{aligned} y^(t),o(t),cs(t),a(t),bxUWV∈R∣V∣×1∈Rh×1∈Rd×1∈Rh×d∈Rh×h∈R∣V∣×h
由于 V \boldsymbol{V} V和 c \boldsymbol{c} c不依赖于之前时刻的内容,因此比较容易计算。考虑到
∂ E ∂ o ( t ) = ∂ E ∂ E ( t ) ∂ E ( t ) ∂ o ( t ) = ∂ E ( t ) ∂ o ( t ) = y ^ ( t ) − y ( t ) : = δ 1 ( t ) \frac{\partial E}{\partial \boldsymbol{o}^{(t)}} = \frac{\partial E}{\partial E^{(t)}}\frac{\partial E^{(t)}}{\partial \boldsymbol{o}^{(t)}} =\frac{\partial E^{(t)}}{\partial \boldsymbol{o}^{(t)}} = \hat{\boldsymbol{y}}^{(t)} - \boldsymbol{y}^{(t)} := \boldsymbol{\delta}^{(t)}_1 ∂o(t)∂E=∂E(t)∂E∂o(t)∂E(t)=∂o(t)∂E(t)=y^(t)−y(t):=δ1(t)
可得
∂ E ∂ c = ∑ t = 1 T ∂ E ( t ) ∂ o ( t ) ∂ o ( t ) ∂ c = ∑ t = 1 T δ 1 ( t ) ∂ E ∂ V = ∑ t = 1 T ∂ E ( t ) ∂ o ( t ) ∂ o ( t ) ∂ V = δ 1 ( t ) ( s ( t ) ) T = ∑ t = 1 T δ 1 ( t ) × s ( t ) \begin{aligned} \frac{\partial E}{\partial \boldsymbol{c}} &= \sum_{t=1}^T\frac{\partial E^{(t)}}{\partial \boldsymbol{o}^{(t)}}\frac{\partial \boldsymbol{o}^{(t)}}{\partial \boldsymbol{c}} = \sum_{t=1}^T\boldsymbol{\delta}_1^{(t)} \\ \frac{\partial E}{\partial \boldsymbol{V}} &= \sum_{t=1}^T\frac{\partial E^{(t)}}{\partial \boldsymbol{o}^{(t)}} \frac{\partial \boldsymbol{o}^{(t)}}{\partial \boldsymbol{V}} = \boldsymbol{\delta}_1^{(t)}\left(\boldsymbol{s}^{(t)}\right)^\mathsf{T} = \sum_{t=1}^T\boldsymbol{\delta}_1^{(t)}\times \boldsymbol{s}^{(t)} \end{aligned} ∂c∂E∂V∂E=t=1∑T∂o(t)∂E(t)∂c∂o(t)=t=1∑Tδ1(t)=t=1∑T∂o(t)∂E(t)∂V∂o(t)=δ1(t)(s(t))T=t=1∑Tδ1(t)×s(t)
对于 W \boldsymbol{W} W、 U \boldsymbol{U} U、 b \boldsymbol{b} b,由于它们的梯度更新均依赖于 s ( t ) \boldsymbol{s}^{(t)} s(t),因此可以定义一个中间变量
δ ( t ) = ∂ E ∂ s ( t ) = ∑ i = t T ∂ E ( i ) ∂ s ( t ) = ∂ E ( t ) ∂ o ( t ) ∂ o ( t ) ∂ s ( t ) + ∂ E ∂ s ( t + 1 ) ∂ s ( t + 1 ) ∂ a ( t + 1 ) ∂ a ( t + 1 ) ∂ s ( t ) = V T ( y ^ ( t ) − y ( t ) ) + W T d i a g ( 1 − tanh a ( t + 1 ) ⊙ tanh a ( t + 1 ) ) δ ( t + 1 ) \begin{aligned} \boldsymbol{\delta}^{(t)} &= \frac{\partial E}{\partial \boldsymbol{s}^{(t)}} = \sum_{i=t}^{T} \frac{\partial E^{(i)}}{\partial \boldsymbol{s}^{(t)}} \\ &= \frac{\partial E^{(t)}}{\partial \boldsymbol{o}^{(t)}}\frac{\partial \boldsymbol{o}^{(t)}}{\partial \boldsymbol{s}^{(t)}} + \frac{\partial E}{\partial \boldsymbol{s}^{(t+1)}}\frac{\partial \boldsymbol{s}^{(t+1)}}{\partial \boldsymbol{a}^{(t+1)}}\frac{\partial \boldsymbol{a}^{(t+1)}}{\partial \boldsymbol{s}^{(t)}} \\ &= \boldsymbol{V}^{\mathsf{T}}\left(\hat{\boldsymbol{y}}^{(t)} - \boldsymbol{y}^{(t)}\right) + \boldsymbol{W}^\mathsf{T}{\rm diag}\left(\boldsymbol{1}-\tanh\boldsymbol{a}^{(t+1)} \odot \tanh\boldsymbol{a}^{(t+1)}\right)\boldsymbol{\delta}^{(t+1)} \end{aligned} δ(t)=∂s(t)∂E=i=t∑T∂s(t)∂E(i)=∂o(t)∂E(t)∂s(t)∂o(t)+∂s(t+1)∂E∂a(t+1)∂s(t+1)∂s(t)∂a(t+1)=VT(y^(t)−y(t))+WTdiag(1−tanha(t+1)⊙tanha(t+1))δ(t+1)
假设RNN展开序列长度为 T T T,则 δ ( T ) \boldsymbol{\delta}^{(T)} δ(T)不会有上式的第二项(因为没有下游状态可以传递上来)
在有了这个中间变量后,可以有
∂ E ∂ W = ∑ t = 1 T ∂ E ∂ s ( t ) ∂ s ( t ) ∂ a ( t ) ∂ a ( t ) ∂ W = ∑ t = 1 T d i a g ( 1 − tanh a ( t ) ⊙ tanh a ( t ) ) δ ( t ) ( s ( t − 1 ) ) T ∂ E ∂ b = ∑ t = 1 T ∂ E ∂ s ( t ) ∂ s ( t ) ∂ a ( t ) ∂ a ( t ) ∂ b = ∑ t = 1 T d i a g ( 1 − tanh a ( t ) ⊙ tanh a ( t ) ) δ ( t ) ∂ E ∂ U = ∑ t = 1 T ∂ E ∂ s ( t ) ∂ s ( t ) ∂ a ( t ) ∂ a ( t ) ∂ U = ∑ t = 1 T d i a g ( 1 − tanh a ( t ) ⊙ tanh a ( t ) ) δ ( t ) ( x ( t ) ) T \begin{aligned} \frac{\partial E}{\partial \boldsymbol{W}} &= \sum_{t=1}^T \frac{\partial E}{\partial \boldsymbol{s}^{(t)}}\frac{\partial \boldsymbol{s}^{(t)}}{\partial \boldsymbol{a}^{(t)}}\frac{\partial \boldsymbol{a}^{(t)}}{\partial \boldsymbol{W}} = \sum_{t=1}^T{\rm diag}\left(\boldsymbol{1}-\tanh \boldsymbol{a}^{(t)} \odot \tanh \boldsymbol{a}^{(t)}\right)\boldsymbol{\delta}^{(t)}\left(\boldsymbol{s}^{(t-1)}\right)^\mathsf{T} \\ \frac{\partial E}{\partial \boldsymbol{b}} &= \sum_{t=1}^T \frac{\partial E}{\partial \boldsymbol{s}^{(t)}}\frac{\partial \boldsymbol{s}^{(t)}}{\partial \boldsymbol{a}^{(t)}}\frac{\partial \boldsymbol{a}^{(t)}}{\partial \boldsymbol{b}} = \sum_{t=1}^T{\rm diag}\left(\boldsymbol{1}-\tanh \boldsymbol{a}^{(t)} \odot \tanh\boldsymbol{a}^{(t)}\right)\boldsymbol{\delta}^{(t)} \\ \frac{\partial E}{\partial \boldsymbol{U}} &= \sum_{t=1}^T \frac{\partial E}{\partial \boldsymbol{s}^{(t)}}\frac{\partial \boldsymbol{s}^{(t)}}{\partial \boldsymbol{a}^{(t)}}\frac{\partial \boldsymbol{a}^{(t)}}{\partial \boldsymbol{U}} = \sum_{t=1}^T{\rm diag}\left(\boldsymbol{1}-\tanh\boldsymbol{a}^{(t)} \odot \tanh\boldsymbol{a}^{(t)}\right)\boldsymbol{\delta}^{(t)}\left(\boldsymbol{x}^{(t)}\right)^\mathsf{T} \end{aligned} ∂W∂E∂b∂E∂U∂E=t=1∑T∂s(t)∂E∂a(t)∂s(t)∂W∂a(t)=t=1∑Tdiag(1−tanha(t)⊙tanha(t))δ(t)(s(t−1))T=t=1∑T∂s(t)∂E∂a(t)∂s(t)∂b∂a(t)=t=1∑Tdiag(1−tanha(t)⊙tanha(t))δ(t)=t=1∑T∂s(t)∂E∂a(t)∂s(t)∂U∂a(t)=t=1∑Tdiag(1−tanha(t)⊙tanha(t))δ(t)(x(t))T
尽管期望RNN能够捕捉到文本中的长距离依赖,但是真正使用时却会发现事与愿违,发生梯度消失或者梯度爆炸的现象,前者指的是训练过程中向远距离传播的梯度范数指数下降到接近为0,使得远端参数无法被更新 (更准确地说,由于RNN中一些参数来自各个时刻梯度的和,因此这些参数本身的梯度不会太小,仍然可以被更新。只不过主导更新的梯度来源于比较近的时刻,远端的时刻贡献不大);后者指的是训练过程中向远距离传播的梯度范数指数增长至一个巨大的数值(有时只能用NaN表示),无法收敛。[Pascanu2013]对此问题做了一个分析:如果从另一个角度对BPTT进行推导,不引进中间变量,展开递推关系,并将所有依赖于隐藏层状态的参数都记为 θ \boldsymbol{\theta} θ,有
∂ E ∂ θ = ∑ 1 ≤ t ≤ T ∂ E ( t ) ∂ θ ∂ E ( t ) ∂ θ = ∑ 1 ≤ k ≤ t ( ∂ E ( t ) ∂ a ( t ) ∂ a ( t ) ∂ a ( k ) ∂ + a ( k ) ∂ θ ) ∂ a ( t ) ∂ a ( k ) = ∏ t ≥ i > k ∂ a ( i ) ∂ a ( i − 1 ) = ∏ t ≥ i > k W T d i a g ( 1 − tanh a ( i − 1 ) ⊙ tanh a ( i − 1 ) ) (2) \begin{aligned} \frac{\partial E}{\partial \boldsymbol{\theta}} &= \sum_{1\le t\le T}\frac{\partial E^{(t)}}{\partial \boldsymbol{\theta}} \\ \frac{\partial E^{(t)}}{\partial \boldsymbol{\theta}} &= \sum_{1\le k \le t}\left(\frac{\partial E^{(t)}}{\partial \boldsymbol{a}^{(t)}}\frac{\partial \boldsymbol{a}^{(t)}}{\partial \boldsymbol{a}^{(k)}}\frac{\partial^+ \boldsymbol{a}^{(k)}}{\partial \boldsymbol{\theta}}\right) \\ \frac{\partial \boldsymbol{a}^{(t)}}{\partial \boldsymbol{a}^{(k)}} &= \prod_{t \ge i > k} \frac{\partial \boldsymbol{a}^{(i)}}{\partial \boldsymbol{a}^{(i-1)}} = \prod_{t \ge i > k}\boldsymbol{W}^\mathsf{T}{\rm diag}\left(\boldsymbol{1}-\tanh\boldsymbol{a}^{(i-1)}\odot \tanh\boldsymbol{a}^{(i-1)}\right)\tag{2} \end{aligned} ∂θ∂E∂θ∂E(t)∂a(k)∂a(t)=1≤t≤T∑∂θ∂E(t)=1≤k≤t∑(∂a(t)∂E(t)∂a(k)∂a(t)∂θ∂+a(k))=t≥i>k∏∂a(i−1)∂a(i)=t≥i>k∏WTdiag(1−tanha(i−1)⊙tanha(i−1))(2)
这里 ∂ + a ( k ) / ∂ θ \partial^+ \boldsymbol{a}^{(k)}/\partial \boldsymbol{\theta} ∂+a(k)/∂θ按文章中的说法是"immediate partial derivative",指的意思是将 a ( k − 1 ) \boldsymbol{a}^{(k-1)} a(k−1)看作是一个关于 θ \boldsymbol{\theta} θ的常量,即 ∂ + a ( k ) / ∂ W = s ( k − 1 ) = tanh ( a ( k − 1 ) ) \partial^+ \boldsymbol{a}^{(k)}/\partial \boldsymbol{W} = \boldsymbol{s}^{(k-1)}= \tanh\left(\boldsymbol{a}^{(k-1)}\right) ∂+a(k)/∂W=s(k−1)=tanh(a(k−1)),不再做进一步展开。如果展开的话, a ( t ) \boldsymbol{a}^{(t)} a(t)将会是一个对 W \boldsymbol{W} W的 t t t次多项式,求导起来麻烦且不直观
以 W \boldsymbol{W} W为例,其在RNN展开以后的表达式里会多次出现,所以损失函数对 W \boldsymbol{W} W的每一次出现单独求导,再求和,本质上还是对两个函数相乘的求导,即 ( u v ) ′ = u ′ v + u v ′ (uv)' = u'v + uv' (uv)′=u′v+uv′。其中 u ′ v u'v u′v就是immediate partial derivative, u v ′ uv' uv′是递归项, W \boldsymbol{W} W在 u u u和 v v v里都有,所以最后的导数根据全微分定理需要分别求导再求和
——大神同事Towser对immediate partial derivative的进一步解释
注意到(2)式里 d i a g \rm diag diag函数的参数实际上是激活函数 tanh \tanh tanh的导数。如果将激活函数泛化为 σ \sigma σ,则(2)式可以写为
∂ a ( t ) ∂ a ( k ) = ∏ t ≥ i > k ∂ a ( i ) ∂ a ( i − 1 ) = ∏ t ≥ i > k W T d i a g ( σ ′ ( a ( i − 1 ) ) ) (3) \frac{\partial \boldsymbol{a}^{(t)}}{\partial \boldsymbol{a}^{(k)}} = \prod_{t \ge i > k} \frac{\partial \boldsymbol{a}^{(i)}}{\partial \boldsymbol{a}^{(i-1)}} = \prod_{t \ge i > k}\boldsymbol{W}^\mathsf{T}{\rm diag}\left(\sigma'(\boldsymbol{a}^{(i-1)})\right)\tag{3} ∂a(k)∂a(t)=t≥i>k∏∂a(i−1)∂a(i)=t≥i>k∏WTdiag(σ′(a(i−1)))(3)
[Pascanu2013]的arxiv版本附录里证明了,当激活函数 σ \sigma σ是恒等函数时,有
进一步地,假设 ∣ σ ′ ( a ) ∣ |\sigma'(a)| ∣σ′(a)∣有界,那么由于 d i a g ( σ ′ ( a ( k ) ) {\rm diag}(\sigma'(\boldsymbol{a}^{(k)}) diag(σ′(a(k))是一个对角矩阵,且矩阵的2-范数是其最大奇异值(矩阵奇异值与矩阵范数之间有什么联系? - 边际函数的回答),因此 ∥ d i a g ( σ ′ ( a ( k ) ) ∥ \|{\rm diag}(\sigma'(\boldsymbol{a}^{(k)})\| ∥diag(σ′(a(k))∥就是对角线上绝对值最大的元素,因此 ∥ d i a g ( σ ′ ( a ( k ) ) ) ∥ \|{\rm diag}(\sigma'(\boldsymbol{a}^{(k)}))\| ∥diag(σ′(a(k)))∥也是有界的,记为 γ \gamma γ。假设 ∥ W ∥ \|\boldsymbol{W}\| ∥W∥的最大奇异值 λ 1 < 1 / γ \lambda_1 < 1/\gamma λ1<1/γ,由于 ∥ W ∥ = λ 1 \|\boldsymbol{W}\| = \lambda_1 ∥W∥=λ1,因此 ∥ W ∥ < 1 / γ \|\boldsymbol{W}\| < 1/\gamma ∥W∥<1/γ
由矩阵范数的两个性质
可得
∀ k , ∥ ∂ a ( k + 1 ) ∂ a ( k ) ∥ = ∥ W T d i a g ( σ ′ ( a ( k ) ) ) ∥ ≤ ∥ W T ∥ ∥ d i a g ( σ ′ ( a ( k ) ) ) ∥ < 1 γ γ = 1 \begin{aligned} \forall k, \left\|\frac{\partial \boldsymbol{a}^{(k+1)}}{\partial \boldsymbol{a}^{(k)}}\right\| &= \left\|\boldsymbol{W}^\mathsf{T}{\rm diag}\left(\sigma'(\boldsymbol{a}^{(k)})\right)\right\| \le \left\|\boldsymbol{W}^\mathsf{T}\right\|\left\|{\rm diag}\left(\sigma'(\boldsymbol{a}^{(k)})\right)\right\| < \frac{1}{\gamma}\gamma = 1 & \end{aligned} ∀k,∥∥∥∥∂a(k)∂a(k+1)∥∥∥∥=∥∥∥WTdiag(σ′(a(k)))∥∥∥≤∥∥∥WT∥∥∥∥∥∥diag(σ′(a(k)))∥∥∥<γ1γ=1
因此, ∃ η ∈ R → ∀ k , ∥ ∂ a ( k + 1 ) ∂ a ( k ) ∥ ≤ η < 1 \exists \eta \in \mathbb{R} \rightarrow \forall k, \left\|\frac{\partial \boldsymbol{a}^{(k+1)}}{\partial \boldsymbol{a}^{(k)}}\right\| \le \eta < 1 ∃η∈R→∀k,∥∥∥∂a(k)∂a(k+1)∥∥∥≤η<1,这意味着
∥ ∂ E ( t ) ∂ a ( t ) ( ∏ i = k t − 1 ∂ a ( i + 1 ) ∂ a ( i ) ) ∥ ≤ η t − k ∥ ∂ E ( t ) ∂ a ( t ) ∥ \left\|\frac{\partial E^{(t)}}{\partial \boldsymbol{a}^{(t)}}\left(\prod_{i=k}^{t-1}\frac{\partial \boldsymbol{a}^{(i+1)}}{\partial \boldsymbol{a}^{(i)}}\right)\right\| \le \eta^{t-k}\left\|\frac{\partial E^{(t)}}{\partial \boldsymbol{a}^{(t)}}\right\| ∥∥∥∥∥∂a(t)∂E(t)(i=k∏t−1∂a(i)∂a(i+1))∥∥∥∥∥≤ηt−k∥∥∥∥∂a(t)∂E(t)∥∥∥∥
由于 η < 1 \eta < 1 η<1,因此由上式可知,当 t − k t-k t−k比较大时,梯度会很快下降到0,即若 ∥ W ∥ < 1 / γ \|\boldsymbol{W}\| < 1/\gamma ∥W∥<1/γ,其中 γ \gamma γ是 ∥ d i a g ( σ ′ ( a ( k ) ) ) ∥ \|{\rm diag}(\sigma'(\boldsymbol{a}^{(k)}))\| ∥diag(σ′(a(k)))∥的上界,则一定会发生梯度消失。反过来,如果发生梯度爆炸,则 ∥ W ∥ = λ 1 > 1 / γ \|\boldsymbol{W}\|= \lambda_1 > 1/\gamma ∥W∥=λ1>1/γ。对于tanh函数, γ = 1 \gamma = 1 γ=1,对于sigmoid函数, γ = 1 / 4 \gamma = 1/4 γ=1/4
关于梯度爆炸,[Pascanu2013]在分析了一个极简单的网络 (只有一个隐藏层节点) 后得出一个结论:如果发生了梯度爆炸,那么在某个方向上误差表面的曲率会非常大,形成一堵墙 (或者也可以称为是一个悬崖) 。SGD优化算法将参数优化到这堵墙时,会向着某个方向迈出很大一步,再次进入损失函数值比较大的区域,破坏学习过程。因此,该文章给出了一个比较有效,在现在也通用的算法——梯度(范数)截断法 (norm clipping),核心思想是如果梯度的范数大于某个指定阈值时,将梯度按照指定阈值与当前梯度范数的比例缩放
g ^ ← ∂ E ∂ θ g ^ ← t h r e s h o l d ∥ g ^ ∥ g ^ i f ∥ g ^ ∥ ≥ t h r e s h o l d \begin{aligned} \hat{\boldsymbol{g}} &\leftarrow \frac{\partial E}{\partial \boldsymbol{\theta}} \\ \hat{\boldsymbol{g}} &\leftarrow \frac{\rm threshold}{\|\hat{\boldsymbol{g}}\|}\hat{\boldsymbol{g}}\ \ {\rm if\ }\|\hat{\boldsymbol{g}}\| \ge {\rm threshold} \end{aligned} g^g^←∂θ∂E←∥g^∥thresholdg^ if ∥g^∥≥threshold
这样,当参数到达墙附近时,截断过的梯度会把参数推回到墙边比较光滑,曲率比较小的区域,如下图 (来自于花书图10.17) 所示
在TensorFlow中使用梯度截断时,注意不能直接调用优化器的minimize
方法,而是要先compute_gradients
获得梯度,然后调用tf.clip_by_global_norm
方法做截断,再调用apply_gradients
方法来更新参数,例如
optimizer = tf.train.AdamOptimizer(1e-3)
gradients, variables = zip(*optimizer.compute_gradients(loss))
gradients, _ = tf.clip_by_global_norm(gradients, NORM_THRESHOLD)
optimize = optimizer.apply_gradients(zip(gradients, variables))
TensorFlow提供了若干种方法来做截断(1.x有4种,2.0有3种),但是严格来说,只有clip_by_global_norm
方法实现了[Pascanu2013]的梯度截断方法。不过这个方法要收集所有张量,所以比另一个方法clip_by_norm
要慢一些,更具体的解释可以参看这个StackOverflow的回答。clip_by_global_norm
的实现大致如下
def global_norm(tensors):
half_squared_norms = [l2_loss(t) for t in tensors] # l2_loss = sum(t_i ** 2) / 2
return sqrt(reduce_sum(half_squared_norms) * 2.)
def clip_by_global_norm(tensors, threshold):
global_norm = global_norm(tensors)
scale = threshold * (min(1. / global_norm, 1. / threshold))
return [t * scale for t in tensors], global_norm
对于RNN训练过程中出现的梯度消失 (更严格地说,是长距离依赖时刻的梯度消失),CS224讲义里给出了一些简单的处理办法,例如激活函数用ReLU,参数初始化成正交矩阵等。但是更常见的做法是使用RNN的一个改进方案LSTM (及其变种GRU),这两种结构的核心都是基于了门控单元
(在继续介绍基于门控单元的RNN之前,先来看一下将矩阵初始化成正交矩阵的好处。首先,将矩阵 W \boldsymbol{W} W做特征值分解,可以得到 W = Q Λ Q T \boldsymbol{W} = \boldsymbol{Q\Lambda Q}^\mathsf{T} W=QΛQT。由于 Q T Q = I \boldsymbol{Q}^\mathsf{T}\boldsymbol{Q} = \boldsymbol{I} QTQ=I,因此 W n = Q Λ n Q T \boldsymbol{W}^n = \boldsymbol{Q\Lambda}^n\boldsymbol{Q}^\mathsf{T} Wn=QΛnQT。因此,当矩阵的所有模长的绝对值都为1时,不会发生梯度消失或梯度爆炸。而正交矩阵正好满足了这个条件)
基于门控单元的RNN最早由Hochreiter和Schmidhuber在1997年提出[Hochreiter1997],称为"长短期记忆网络" (Long-Short Term Memory Network,简称LSTM)。LSTM在传统RNN的基础上,加入了两个门控单元:输入门防止当前隐藏状态不被无关的输入干扰,学习应在何时释放误差;输出门防止当前隐藏状态干扰与之无关的单元,学习当前单元应该捕获何种误差。但是,每个单元的状态通常都会单调增长 (只是不接受无关输入而已),导致激活值最后进入饱和区域,因此梯度消失。此外,这还会导致单元的输出近似等于输出门的输出,将整个单元退化为普通RNN单元。因此,在[Gers2000]里同组研究人员加入了一个新的门遗忘门,来重置单元中过时而无用的内容
具体地,LSTM的模型结构如下所示。此处将前面隐藏层状态 s \boldsymbol{s} s记为 h \boldsymbol{h} h,sigmoid函数写为 σ \sigma σ
如果不考虑三个门的具体计算过程,并将 σ \sigma σ的结果0-1二值化,则LSTM的结构如下图所示 (改编自张皓@知乎,三次简化一张图:一招理解LSTM/GRU门控机制)
可见当输入门和输出门闭合,遗忘门打开时,LSTM退化为普通RNN
LSTM隐藏层参数数量是普通RNN的四倍(假设输入 x ∈ R d × 1 \boldsymbol{x} \in \mathbb{R}^{d \times 1} x∈Rd×1,隐藏节点数,即LSTM单元数为 h h h,则LSTM隐藏层共有 4 h ( h + d + 1 ) 4h(h +d+1) 4h(h+d+1)个参数),反向传播时梯度计算更加复杂。为简单起见,将各个门激活前的值作如下表示,并忽略偏置项
c ~ ( t ) = tanh ( W c h ( t − 1 ) + U c x ( t ) ) = tanh ( c ^ ( t ) ) i ( t ) = σ ( W i h ( t − 1 ) + U i x ( t ) ) = σ ( i ^ ( t ) ) f ( t ) = σ ( W f h ( t − 1 ) + U f x ( t ) ) = σ ( f ^ ( t ) ) o ( t ) = σ ( W o h ( t − 1 ) + U o x ( t ) ) = σ ( o ^ ( t ) ) \begin{aligned} \tilde{\boldsymbol{c}}^{(t)} &= \tanh\left(\boldsymbol{W}_c\boldsymbol{h}^{(t-1)} + \boldsymbol{U}_c\boldsymbol{x}^{(t)}\right) = \tanh\left(\boldsymbol{\hat{c}}^{(t)}\right) \\ \boldsymbol{i}^{(t)} &= \sigma\left(\boldsymbol{W}_i\boldsymbol{h}^{(t-1)} + \boldsymbol{U}_i\boldsymbol{x}^{(t)}\right) = \sigma\left(\boldsymbol{\hat{i}}^{(t)}\right) \\ \boldsymbol{f}^{(t)} &= \sigma\left(\boldsymbol{W}_f\boldsymbol{h}^{(t-1)} + \boldsymbol{U}_f\boldsymbol{x}^{(t)}\right) = \sigma\left(\boldsymbol{\hat{f}}^{(t)}\right) \\ \boldsymbol{o}^{(t)} &= \sigma\left(\boldsymbol{W}_o\boldsymbol{h}^{(t-1)} + \boldsymbol{U}_o\boldsymbol{x}^{(t)}\right) = \sigma\left(\boldsymbol{\hat{o}}^{(t)}\right) \end{aligned} c~(t)i(t)f(t)o(t)=tanh(Wch(t−1)+Ucx(t))=tanh(c^(t))=σ(Wih(t−1)+Uix(t))=σ(i^(t))=σ(Wfh(t−1)+Ufx(t))=σ(f^(t))=σ(Woh(t−1)+Uox(t))=σ(o^(t))
引入一个中间向量 z ( t ) \boldsymbol{z}^{(t)} z(t),有
z ( t ) = [ c ^ ( t ) i ^ ( t ) f ^ ( t ) o ^ ( t ) ] = [ W c U c W i U i W f U f W o U o ] [ h ( t − 1 ) x ( t ) ] : = M s ( t ) \boldsymbol{z}^{(t)} = \left[\begin{matrix}\hat{\boldsymbol{c}}^{(t)} \\ \hat{\boldsymbol{i}}^{(t)} \\ \hat{\boldsymbol{f}}^{(t)} \\ \hat{\boldsymbol{o}}^{(t)} \\ \end{matrix}\right] = \left[\begin{matrix}\boldsymbol{W}_c & \boldsymbol{U}_c \\ \boldsymbol{W}_i & \boldsymbol{U}_i \\ \boldsymbol{W}_f & \boldsymbol{U}_f \\ \boldsymbol{W}_o & \boldsymbol{U}_o \\\end{matrix}\right]\left[\begin{matrix}\boldsymbol{h}^{(t-1)} \\ \boldsymbol{x}^{(t)}\end{matrix}\right] := \boldsymbol{M}\boldsymbol{s}^{(t)} z(t)=⎣⎢⎢⎢⎡c^(t)i^(t)f^(t)o^(t)⎦⎥⎥⎥⎤=⎣⎢⎢⎡WcWiWfWoUcUiUfUo⎦⎥⎥⎤[h(t−1)x(t)]:=Ms(t)
记损失函数对 h ( t ) \boldsymbol{h}^{(t)} h(t)的偏导数为 δ h ( t ) \delta h^{(t)} δh(t),有
δ o ( t ) = : ∂ E ∂ o ( t ) = ∂ E ∂ h ( t ) ∂ h ( t ) ∂ o ( t ) = δ h ( t ) ⊙ tanh c ( t ) δ c ( t ) = δ c c u r ( t ) + δ c n e x t ( t ) δ c c u r ( t ) = : ∂ E ∂ c ( t ) = ∂ E ∂ h ( t ) ∂ h ( t ) ∂ c ( t ) = δ h ( t ) ⊙ o ( t ) ⊙ ( 1 − tanh c ( t ) ⊙ tanh c ( t ) ) δ i ( t ) = : ∂ E ∂ i ( t ) = ∂ E ∂ c ( t ) ∂ c ( t ) ∂ i ( t ) = δ c ( t ) ⊙ c ~ ( t ) δ f ( t ) = : ∂ E ∂ f ( t ) = ∂ E ∂ c ( t ) ∂ c ( t ) ∂ f ( t ) = δ c ( t ) ⊙ c ( t − 1 ) δ c ~ ( t ) = : ∂ E ∂ c ~ ( t ) = ∂ E ∂ c ( t ) ∂ c ( t ) ∂ c ~ ( t ) = δ c ( t ) ⊙ i ( t ) δ c n e x t ( t − 1 ) = : ∂ E ∂ c ( t − 1 ) = ∂ E ∂ c ( t ) ∂ c ( t ) ∂ c ( t − 1 ) = δ c ( t ) ⊙ f ( t ) δ c ^ ( t ) = : ∂ E ∂ c ^ ( t ) = δ c ~ ( t ) ⊙ ( 1 − tanh c ^ ( t ) ⊙ tanh c ^ ( t ) ) δ i ^ ( t ) = : ∂ E ∂ i ^ ( t ) = δ i ( t ) ⊙ i ^ ( t ) ⊙ ( 1 − i ^ ( t ) ) δ f ^ ( t ) = : ∂ E ∂ f ^ ( t ) = δ f ( t ) ⊙ f ^ ( t ) ⊙ ( 1 − f ^ ( t ) ) δ o ^ ( t ) = : ∂ E ∂ o ^ ( t ) = δ o ( t ) ⊙ o ^ ( t ) ⊙ ( 1 − o ^ ( t ) ) δ z ( t ) = : ∂ E ∂ z ( t ) = [ δ c ^ ( t ) T δ i ^ ( t ) T δ f ^ ( t ) T δ o ^ ( t ) T ] T δ s ( t ) = : ∂ E ∂ s ( t ) = M T δ z ( t ) δ h ( t − 1 ) = δ s ( t ) [ : d ] δ M ( t ) = δ z ( t ) s ( t ) T δ M = ∑ t = 1 T δ M ( t ) \begin{aligned} \delta o^{(t)} &=: \frac{\partial E}{\partial \boldsymbol{o}^{(t)}} = \frac{\partial E}{\partial \boldsymbol{h}^{(t)}}\frac{\partial \boldsymbol{h}^{(t)}}{\partial \boldsymbol{o}^{(t)}} = \delta h^{(t)} \odot \tanh \boldsymbol{c}^{(t)} \\ \delta c^{(t)} &= \delta c^{(t)}_{\rm cur} + \delta c_{\rm next}^{(t)} \\ \delta c^{(t)}_{\rm cur} &=: \frac{\partial E}{\partial \boldsymbol{c}^{(t)}} = \frac{\partial E}{\partial \boldsymbol{h}^{(t)}}\frac{\partial \boldsymbol{h}^{(t)}}{\partial \boldsymbol{c}^{(t)}} = \delta h^{(t)} \odot \boldsymbol{o}^{(t)} \odot \left(\boldsymbol{1}-\tanh \boldsymbol{c}^{(t)} \odot \tanh \boldsymbol{c}^{(t)}\right) \\ \delta i^{(t)} &=: \frac{\partial E}{\partial \boldsymbol{i}^{(t)}} = \frac{\partial E}{\partial \boldsymbol{c}^{(t)}}\frac{\partial \boldsymbol{c}^{(t)}}{\partial \boldsymbol{i}^{(t)}} = \delta c^{(t)}\odot \tilde{\boldsymbol{c}}^{(t)} \\ \delta f^{(t)} &=: \frac{\partial E}{\partial \boldsymbol{f}^{(t)}} = \frac{\partial E}{\partial \boldsymbol{c}^{(t)}}\frac{\partial \boldsymbol{c}^{(t)}}{\partial \boldsymbol{f}^{(t)}} = \delta c^{(t)}\odot {\boldsymbol{c}}^{(t-1)} \\ \delta \tilde{c}^{(t)} &=: \frac{\partial E}{\partial \boldsymbol{\tilde{c}}^{(t)}} = \frac{\partial E}{\partial \boldsymbol{c}^{(t)}}\frac{\partial \boldsymbol{c}^{(t)}}{\partial \boldsymbol{\tilde{c}}^{(t)}} = \delta c^{(t)}\odot \boldsymbol{i}^{(t)} \\ \delta c^{(t-1)}_{\rm next} &=: \frac{\partial E}{\partial \boldsymbol{c}^{(t-1)}} = \frac{\partial E}{\partial \boldsymbol{c}^{(t)}}\frac{\partial \boldsymbol{c}^{(t)}}{\partial \boldsymbol{c}^{(t-1)}} = \delta c^{(t)}\odot \boldsymbol{f}^{(t)} \\ \delta \hat{c}^{(t)} &=: \frac{\partial E}{\partial \boldsymbol{\hat{c}}^{(t)}} = \delta\tilde{c}^{(t)}\odot\left(\boldsymbol{1}-\tanh \boldsymbol{\hat{c}}^{(t)} \odot \tanh \boldsymbol{\hat{c}}^{(t)}\right) \\ \delta \hat{i}^{(t)} &=: \frac{\partial E}{\partial \boldsymbol{\hat{i}}^{(t)}} = \delta i^{(t)}\odot \hat{\boldsymbol{i}}^{(t)} \odot \left(\boldsymbol{1}-\hat{\boldsymbol{i}}^{(t)}\right) \\ \delta \hat{f}^{(t)} &=: \frac{\partial E}{\partial \boldsymbol{\hat{f}}^{(t)}} = \delta f^{(t)}\odot \hat{\boldsymbol{f}}^{(t)} \odot \left(\boldsymbol{1}-\hat{\boldsymbol{f}}^{(t)}\right) \\ \delta \hat{o}^{(t)} &=: \frac{\partial E}{\partial \boldsymbol{\hat{o}}^{(t)}} = \delta o^{(t)}\odot \hat{\boldsymbol{o}}^{(t)} \odot \left(\boldsymbol{1}-\hat{\boldsymbol{o}}^{(t)}\right) \\ \delta z^{(t)} &=: \frac{\partial E}{\partial \boldsymbol{z}^{(t)}} = \left[\begin{matrix}\delta {\hat{c}^{(t)}}^\mathsf{T} & \delta{\hat{i}^{(t)}}^\mathsf{T} & \delta{\hat{f}^{(t)}}^\mathsf{T} & \delta{\hat{o}^{(t)}}^\mathsf{T}\end{matrix}\right]^\mathsf{T} \\ \delta s^{(t)} &=: \frac{\partial E}{\partial \boldsymbol{s}^{(t)}} = \boldsymbol{M}^\mathsf{T}\delta z^{(t)} \\ \delta h^{(t-1)} &= \delta s^{(t)}[:d] \\ \delta M^{(t)} &= \delta z^{(t)}{\boldsymbol{s}^{(t)}}^{\mathsf{T}} \\ \delta M &= \sum_{t=1}^T\delta M^{(t)} \end{aligned} δo(t)δc(t)δccur(t)δi(t)δf(t)δc~(t)δcnext(t−1)δc^(t)δi^(t)δf^(t)δo^(t)δz(t)δs(t)δh(t−1)δM(t)δM=:∂o(t)∂E=∂h(t)∂E∂o(t)∂h(t)=δh(t)⊙tanhc(t)=δccur(t)+δcnext(t)=:∂c(t)∂E=∂h(t)∂E∂c(t)∂h(t)=δh(t)⊙o(t)⊙(1−tanhc(t)⊙tanhc(t))=:∂i(t)∂E=∂c(t)∂E∂i(t)∂c(t)=δc(t)⊙c~(t)=:∂f(t)∂E=∂c(t)∂E∂f(t)∂c(t)=δc(t)⊙c(t−1)=:∂c~(t)∂E=∂c(t)∂E∂c~(t)∂c(t)=δc(t)⊙i(t)=:∂c(t−1)∂E=∂c(t)∂E∂c(t−1)∂c(t)=δc(t)⊙f(t)=:∂c^(t)∂E=δc~(t)⊙(1−tanhc^(t)⊙tanhc^(t))=:∂i^(t)∂E=δi(t)⊙i^(t)⊙(1−i^(t))=:∂f^(t)∂E=δf(t)⊙f^(t)⊙(1−f^(t))=:∂o^(t)∂E=δo(t)⊙o^(t)⊙(1−o^(t))=:∂z(t)∂E=[δc^(t)Tδi^(t)Tδf^(t)Tδo^(t)T]T=:∂s(t)∂E=MTδz(t)=δs(t)[:d]=δz(t)s(t)T=t=1∑TδM(t)
关于LSTM解决了什么问题,本文更认同这样一种观点:LSTM并没有解决梯度爆炸,也没有解决梯度消失问题。由上面推导过程可见反向传播过程中仍然有矩阵 M \boldsymbol{M} M相乘的操作,因此多个时刻之后理论上仍会出现远端的梯度消失/爆炸。LSTM的贡献在于使用如下两种方法更好地捕捉了长期依赖关系
LSTM另一个需要注意的点是,其在隐藏层间相互传递的信息与普通RNN不同。普通RNN的隐藏状态 h ( t ) \boldsymbol{h}^{(t)} h(t) (或者与前文一致记为 s ( t ) \boldsymbol{s}^{(t)} s(t)) 既会传递给下一个时刻,作为下一个时刻隐藏状态的输入之一,也会向上传递给输出层。而LSTM传递给输出层的信息不变,但是传给下一个时刻隐藏状态的内容又新加了一个单元状态 c ( t ) \boldsymbol{c}^{(t)} c(t)。因此
不严格地讲, h ( t ) \boldsymbol{h}^{(t)} h(t)是 t t t时刻单元"可以传播"的所有信息,因此能够输出和传递下去是无可厚非的;而 c ( t ) \boldsymbol{c}^{(t)} c(t)则是其所有了解的信息,有一些可能不适合在 t t t时刻输出,但是难免会在后面某个时刻用到,所以需要传递也是可以理解的。尽管如此,这两者可能仍然会给人造成迷惑,也许也正是门控循环单元 (Gated Recurrent Units, GRU) 提出的原因之一
GRU[Cho2014]的提出直接受到了LSTM的启发,目标是设计一种更容易计算和实现的隐藏单元。GRU将控制门的数量减少到了两个,具体结构如下
按照之前LSTM的简化方法,GRU的结构图可以简化成如下所示。此时更新门退化为一个单刀双掷开关,图中给出的是 z ( t ) \boldsymbol{z}^{(t)} z(t)为 1 \boldsymbol{1} 1的情况
GRU的重置门可以看做是起了LSTM输入门的功能,而更新门同时承担了LSTM遗忘门和输出门的责任,相当于输出门恒为 1 − z ( t ) \boldsymbol{1} - \boldsymbol{z}^{(t)} 1−z(t)
前面所述的RNN (包含了普通RNN的变种LSTM和GRU) 每个时刻的隐藏状态都是由之前所有时刻的状态所影响。这种网络结构用来建立语言模型是非常合适的,但是对于某些NLP任务,这样做实际上并没有捕捉足够的信息。例如假设要利用RNN做词性标注,则一个词对应的标签不仅与其前面的单词有关,也与其后面的单词有关----否则怎么很好地标注句首第一个单词呢?在这种情况下,就可以采用双向RNN这一结构。此时对每个时刻的输入 (在NLP领域里比较典型的是一个单词),模型不仅能看到足够远的所有过去状态,也可以看到足够远的所有未来状态
具体说,双向RNN会在原有RNN的基础上再加一个隐藏层,该隐藏层的状态是从后向前传播的,从序列的终点开始读取,因此称为后向层;而原有的从序列起点开始读取的隐藏层称为前向层。这样,对每个输入位置 i i i,双向RNN会获得两个独立状态,包括前向状态 h f ( i ) \boldsymbol{h}^{(i)}_f hf(i)和 h b ( i ) \boldsymbol{h}^{(i)}_b hb(i)。前者来自于 x ( 1 ) , … , x ( i ) \boldsymbol{x}^{(1)},\ldots,\boldsymbol{x}^{(i)} x(1),…,x(i),后者来自于 x ( i ) , … , x ( T ) \boldsymbol{x}^{(i)},\ldots, \boldsymbol{x}^{(T)} x(i),…,x(T) (假设输入序列长为 T T T)。在获得这两个独立的状态以后,双向RNN会将其合并作为最终的隐藏状态,传递给输出层。合并方法通常有两种,其一是简单将两个向量拼接,其二是学习一个线性变换将其转换为单向RNN隐藏状态的维度。后者常见于将多个双向RNN堆叠起来的模型
双向RNN的形式化描述如下。此处的双向状态合并采用了拼接法。注意两个隐藏层有不同的 U \boldsymbol{U} U、 W \boldsymbol{W} W和 V \boldsymbol{V} V
h → ( t ) = R N N F W ( h → ( t − 1 ) , x ( t ) ) h ← ( t ) = R N N B W ( h ← ( t + 1 ) , x ( t ) ) h = [ h → ( t ) ; h ← ( t ) ] \begin{aligned} \overrightarrow{\boldsymbol{h}}^{(t)} &= {\rm RNN_{FW}}\left(\overrightarrow{\boldsymbol{h}}^{(t-1)}, \boldsymbol{x}^{(t)}\right) \\ \overleftarrow{\boldsymbol{h}}^{(t)} &= {\rm RNN_{BW}}\left(\overleftarrow{\boldsymbol{h}}^{(t+1)}, \boldsymbol{x}^{(t)}\right) \\ \boldsymbol{h} &= \left[\overrightarrow{\boldsymbol{h}}^{(t)}; \overleftarrow{\boldsymbol{h}}^{(t)}\right] \end{aligned} h(t)h(t)h=RNNFW(h(t−1),x(t))=RNNBW(h(t+1),x(t))=[h(t);h(t)]
需要再次说明的是,只有能拿到整个输入句子时才能使用双向RNN
RNN可以有多个隐藏层,此时下层的隐藏状态的输出可以作为上层隐藏状态的输入。这样的RNN被称为堆叠RNN (stacked RNN) 或多层RNN (multi-layer RNN),此时底层RNN预计可以捕捉底层特征,高层RNN预计可以捕捉高层特征。下图给出了一个三层RNN作为示例
对应的模型形式化描述为
h 1 ( t ) = R N N 1 ( x ( t ) , h 1 ( t − 1 ) ) h 2 ( t ) = R N N 2 ( h 1 ( t ) , h 2 ( t − 1 ) ) h 3 ( t ) = R N N 3 ( h 2 ( t ) , h 3 ( t − 1 ) ) \begin{aligned} \boldsymbol{h}_{1}^{(t)} &= {\rm RNN}_1\left(\boldsymbol{x}^{(t)}, \boldsymbol{h}_1^{(t-1)}\right) \\ \boldsymbol{h}_{2}^{(t)} &= {\rm RNN}_2\left(\boldsymbol{h}_1^{(t)}, \boldsymbol{h}_2^{(t-1)}\right) \\ \boldsymbol{h}_{3}^{(t)} &= {\rm RNN}_3\left(\boldsymbol{h}_2^{(t)}, \boldsymbol{h}_3^{(t-1)}\right) \\ \end{aligned} h1(t)h2(t)h3(t)=RNN1(x(t),h1(t−1))=RNN2(h1(t),h2(t−1))=RNN3(h2(t),h3(t−1))
当堆叠多个双向RNN时,理想的实现方式如下图所示。此时第 l l l层正向RNN和反向RNN的输出会组合,分别成为第 l + 1 l+1 l+1层正向RNN和反向RNN各自的输入,以此类推
但是TensorFlow 1.x提供的APItf.nn.bidirectional_dynamic_rnn
并没有实现上图所示的结构,而是对于所有隐藏层,正向RNN的输出只作为下一层正向RNN的输入,反向RNN的输出只作为下一层反向RNN的输入,直到要进入输出层才组合,即是下图所示的结构
若要构造理想的多层双向RNN结构,需要使用tf.contrib.rnn.stack_bidirectional_dynamic_rnn
开源神经翻译框架Nematus [Sennrich2017] 实现的多层双向RNN结构更加奇特,是第 l l l层的正向RNN输出作为第 l + 1 l+1 l+1层反向RNN的输入,第 l l l层的反向RNN输出作为第 l + 1 l+1 l+1层正向RNN的输入,即
需要注意,当RNN被堆叠足够多层时,在网络的纵向也会发生前面介绍的梯度消失问题。[Britz2017]通过实验证明,使用编码器-解码器结构(在下一章介绍)训练神经翻译模型时,如果解码器超过8层,在不加其他手段干涉的情况下无法训出有意义的结果。这种问题在深层网络中尤为普遍,尤其是在CV领域堆叠了足够多的卷积层时。为了解决这种问题,一种常见的方法是引入残差连接 (residual connection) [He2016a]
在继续介绍残差连接的方法之前,有必要再探讨一下残差连接提出的背景。尽管如前所述,残差连接可以解决梯度消失问题,但是残差连接提出的目的并非如此。一方面,残差连接是首先用在CV领域,首先用在CNN结构中,而此时CNN已经基本都在使用ReLU做激活函数,基本可以解决梯度消失问题;另一方面,批归一化和He初始化方法也可以解决梯度消失问题。
那么残差连接要解决的是什么问题呢?作者发现随着(卷积)神经网络的 (文章中给出了56层的例子) 层数变大,神经网络的性能发生了严重的退化,其不仅测试误差变大,连训练误差也变大了。这种现象既不是过拟合 (否则训练误差应该非常小),也不是欠拟合 (毕竟让模型变得复杂以后理论上算力应该更强,但继续训练训练误差没有下降) ,其背后是深层网络可学习性变差的缘故。为了不让深层网络的效果还不如浅层网络,作者提出可让各层拟合一个残差映射:假设期望得到的底层映射是 H ( x ) \mathcal{H}(\boldsymbol{x}) H(x),则让堆叠起来的非线性变换学到 F ( x ) : = H ( x ) − x \mathcal{F}(\boldsymbol{x}) := \mathcal{H}(\boldsymbol{x}) - \boldsymbol{x} F(x):=H(x)−x,此时原始映射实际上是 F ( x ) + x \mathcal{F}(\boldsymbol{x}) + \boldsymbol{x} F(x)+x。原文认为这样有两个好处
如果最优映射不是恒等映射,那么拟合 F ( x ) \mathcal{F}(\boldsymbol{x}) F(x)和 H ( x ) \mathcal{H}(\boldsymbol{x}) H(x)难度差不多
如果最优映射是或者接近是恒等映射,那么拟合残差映射 F ( x ) \mathcal{F}(\boldsymbol{x}) F(x)更容易,让这些层学着什么也不做就行。
下图给出了一个残差学习块,可以看作是在原有结构的基础上使用恒等映射新增一条"短路连接",即将输入再加到原有的输出上。短路连接没有增加模型的参数,也没有增加模型的复杂性,比较经济,不过此时要求 F \mathcal{F} F和 x \boldsymbol{x} x的维度必须相同。如果这个条件不能满足,那么可以对短路连接再加入一个线性变换来对齐,此时残差块的输出是 F ( x ) + W s x \mathcal{F}(\boldsymbol{x}) + \boldsymbol{W}_s\boldsymbol{x} F(x)+Wsx
说到这里,很自然会有这么一个问题:如果浅层网络就足够管用,为什么还要费力引进这些奇技淫巧,只为了让深层网络的一些层什么也不做,走短路连接呢?一方面,原始论文通过实证,说明引入了残差连接的深层网络效果不仅好过了没有引入的退化版本,也好过了之前表现不错的浅层版本;另一方面,[Veit2016]认为引入残差连接后得到的深层网络 (残差网络) 可以看作是若干浅层网络的集成 (ensemble) 。例如下图给出了一个三层残差网络展开后的示意图,从图中可以看到从输入到输出有 O ( 2 n ) O(2^n) O(2n)条路径,每加入一个残差块都会让路径数翻倍,因此残差网络可以看做是很多条路径的集成
[Veit2016]在训练好残差网络后,在测试时随机删掉了若干层,或者打乱了若干层的顺序,发现这样捣乱以后对模型的影响是平滑变化的 (例如删除 1 , … , 20 1, \ldots, 20 1,…,20层时,模型误差平滑上升) ,而传统的深层网络如果经历这样的操作,效果会明显变差,因此各条路径之间的依赖关系并不强,这意味着残差网络更像是这些路径的集成。进一步实验还验证了在残差网络里长度在15左右的路径对梯度贡献更大。综合前面各种实验的结果,就可以得到文章提出的观点:残差网络是若干浅层网络的集成,效果也会更好
[Neubig2017]给出了在RNN上使用残差连接的示意图,如下所示
对应的描述为
h 1 ( t ) = R N N 1 ( x ( t ) , h 1 ( t − 1 ) ) + x ( t ) h 2 ( t ) = R N N 2 ( h 1 ( t ) , h 2 ( t − 1 ) ) + h 1 ( t ) h 3 ( t ) = R N N 3 ( h 2 ( t ) , h 3 ( t − 1 ) ) + h 2 ( t ) \begin{aligned} \boldsymbol{h}_{1}^{(t)} &= {\rm RNN}_1\left(\boldsymbol{x}^{(t)}, \boldsymbol{h}_1^{(t-1)}\right) + \boldsymbol{x}^{(t)} \\ \boldsymbol{h}_{2}^{(t)} &= {\rm RNN}_2\left(\boldsymbol{h}_1^{(t)}, \boldsymbol{h}_2^{(t-1)}\right) + \boldsymbol{h}_{1}^{(t)} \\ \boldsymbol{h}_{3}^{(t)} &= {\rm RNN}_3\left(\boldsymbol{h}_2^{(t)}, \boldsymbol{h}_3^{(t-1)}\right) + \boldsymbol{h}_{2}^{(t)} \\ \end{aligned} h1(t)h2(t)h3(t)=RNN1(x(t),h1(t−1))+x(t)=RNN2(h1(t),h2(t−1))+h1(t)=RNN3(h2(t),h3(t−1))+h2(t)
[Britz2017]还提供了另一种残差连接的实现方法,其核心思想来自于DenseNet [Huang2017],做法是对浅层 l l l,不止增加一条其到第 l + 1 l+1 l+1层的短路连接,而是将其与其后所有层连接起来。对于三层RNN,描述为
h 1 ( t ) = R N N 1 ( x ( t ) , h 1 ( t − 1 ) ) + x ( t ) h 2 ( t ) = R N N 2 ( h 1 ( t ) , h 2 ( t − 1 ) ) + x ( t ) + h 1 ( t ) h 3 ( t ) = R N N 3 ( h 2 ( t ) , h 3 ( t − 1 ) ) + x ( t ) + h 1 ( t ) + h 2 ( t ) \begin{aligned} \boldsymbol{h}_{1}^{(t)} &= {\rm RNN}_1\left(\boldsymbol{x}^{(t)}, \boldsymbol{h}_1^{(t-1)}\right) + \boldsymbol{x}^{(t)} \\ \boldsymbol{h}_{2}^{(t)} &= {\rm RNN}_2\left(\boldsymbol{h}_1^{(t)}, \boldsymbol{h}_2^{(t-1)}\right) + \boldsymbol{x}^{(t)} + \boldsymbol{h}_{1}^{(t)} \\ \boldsymbol{h}_{3}^{(t)} &= {\rm RNN}_3\left(\boldsymbol{h}_2^{(t)}, \boldsymbol{h}_3^{(t-1)}\right) + \boldsymbol{x}^{(t)} + \boldsymbol{h}_{1}^{(t)}\boldsymbol + \boldsymbol{h}_{2}^{(t)} \\ \end{aligned} h1(t)h2(t)h3(t)=RNN1(x(t),h1(t−1))+x(t)=RNN2(h1(t),h2(t−1))+x(t)+h1(t)=RNN3(h2(t),h3(t−1))+x(t)+h1(t)+h2(t)
实验表明对于深层RNN,DenseNet的效果比普通残差网络要更好
如前所述,训练时一批送入若干样本比每次只训练一条样本更有效率。但是这里存在一个问题,就是文本数据一般都不是定长的:一批数据里第一个句子可能只有三个词,第二个句子可能有20个,第三个可能有10个……等等
一个基本的应对方法是设置一个固定值。如果某个给定的句子长于这个值,则将这个句子截断;如果某个给定的句子短于这个值,那么就使用某个特殊的符号 (例如
) 补齐。无论是截断还是补齐,都有两种方法:在句子前面补齐/将句子前部截掉;或在句子后面补齐/将句子后部截掉。在补齐时需要注意一点,就是为了让补齐用的符号对任务的损失值没有贡献,需要对其权重做mask,即计算损失值时将这部分参数贡献的部分去掉
如果输入句子的长度分布不太集中,则有可能对大量句子补齐太多
。对应的策略是可以应用分桶机制,将要补齐的长度分成若干个桶,每个句子补到某个桶规定的长度即可。例如假设各桶长度是5、10、20、30,类似"I love you"这样的句子补齐到5个单词 (不考虑
/
的情况下加2个
) 就可以
由于RNN本身的特性,其很适用于对序列建模。Andrej Karpathy的The Unreasonable Effectiveness of Recurrent Neural Networks一文给出了RNN网络的五种模式,每种模式都对应若干典型任务。各模式可参见下图所示,其中红色矩形是输入向量,绿色是RNN隐藏层状态,蓝色是输出向量
从左到右各模式及典型任务分别为
[Koehn2017], Philipp Koehn. (2017, September). Statistical Machine Translation, Draft of Chapter 13: Neural Machine Translation
[Neubig2017], Graham Neubig. (2017, March). Neural Machine Translation and Sequence-to-Sequence Models: A Tutorial
[Pascanu2013], Pascanu, R., Mikolov, T., & Bengio, Y. (2013, February). On the difficulty of training recurrent neural networks. In International conference on machine learning, ICML 2013 (pp. 1310-1318).
Explaining and Illustrating Orthogonal Initialization for Recurrent Neural Networks
Colah, Understanding LSTM Networks
[Hochreiter1997], Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. Neural computation, 9(8), 1735-1780.
[Gers2000], Gers, F. A., Schmidhuber, J. A., & Cummins, F. A. (2000). Learning to Forget: Continual Prediction with LSTM. Neural Computation, 12(10), 2451-2471.
Forward and backward pass of LSTM
Towser@知乎: RNN中学习长期依赖的三种机制
Towser@知乎: 从信息隐匿的角度谈LSTM:从Stack到Nest
张皓@知乎: 三次简化一张图:一招理解LSTM/GRU门控机制
LSTM如何来避免梯度弥散和梯度爆炸? - 过拟合的回答 - 知乎
[Cho2014] Cho, K., van Merrienboer, B., Gulcehre, C., Bahdanau, D., Bougares, F., Schwenk, H., & Bengio, Y. (2014). Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation. In Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing (EMNLP 2014) (pp. 1724-1734).
Written Memories: Understanding, Deriving and Extending the LSTM
[Schuster1997] Schuster, M., & Paliwal, K. K. (1997). Bidirectional recurrent neural networks. IEEE Transactions on Signal Processing, 45(11), pp. 2673-2681.
[Britz2017] Britz, D., Goldie, A., Luong, M. T., & Le, Q. (2017). Massive Exploration of Neural Machine Translation Architectures. In Proceedings of the 2017 Conference on Empirical Methods in Natural Language Processing, EMNLP 2017 (pp. 1442-1451).
[He2016a] He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, CVPR 2016 (pp. 770-778).
[Veit2016] Veit, A., Wilber, M. J., & Belongie, S. (2016). Residual networks behave like ensembles of relatively shallow networks. In Advances in neural information processing systems, NeurIPS 2016 (pp. 550-558).
[He2016b] He, K., Zhang, X., Ren, S., & Sun, J. (2016, October). Identity mappings in deep residual networks. In European conference on computer vision, ECCV 2016 (pp. 630-645). Springer, Cham.
[Huang2017] Huang, G., Liu, Z., Van Der Maaten, L., & Weinberger, K. Q. (2017). Densely connected convolutional networks. In Proceedings of the IEEE conference on computer vision and pattern recognition, CVPR 2017 (pp. 4700-4708).
[Sennrich2017] Sennrich, R., Firat, O., Cho, K., Birch, A., Haddow, B., Hitschler, J., … & Nadejde, M. (2017, April). Nematus: a Toolkit for Neural Machine Translation. In Proceedings of the Software Demonstrations of the 15th Conference of the European Chapter of the Association for Computational Linguistics (pp. 65-68).