有任何的书写错误、排版错误、概念错误等,希望大家包含指正。
在阅读本篇之前建议先学习:
RNN 讲解
LSTM 讲解
GRU 是循环神经网络的一种,和 LSTM 一样,是为了解决长期依赖问题。GRU 单元结构如下。
图 1 GRU 模型单元
其中,黄色矩形表示一层神经网络,包含权重和激活函数,矩形中的符号表明激活函数的类型, σ σ σ 对应 sigmoid 函数, tanh \tanh tanh 对应 tanh 函数;粉色(椭)圆表示逐元素操作,比如粉色(椭)圆中为乘号表明矩阵进行对应元素相乘(点乘)操作,加号表明对应位置相加,亦矩阵加法。
GRU 有两个门,重置门和更新门。重置门,对应图中的 [ h t − 1 , x t ] → r t [h_{t-1}, x_t]\to r_t [ht−1,xt]→rt ;更新门,对应图中的 [ h t − 1 , x t ] → z t [h_{t-1},x_t]\to z_t [ht−1,xt]→zt 。
r t = σ ( W h r h t − 1 + W x r x t + b r ) z t = σ ( W h z h t − 1 + W x z x t + b z ) h ~ t = tanh ( W h h ( r t ⊙ h t − 1 ) + W x h x t + b h ) h t = z t ⊙ h t − 1 + ( 1 − z t ) ⊙ h ~ t r_t = \sigma (W_{hr} h_{t-1} + W_{xr} x_t + b_r) \\ z_t = \sigma (W_{hz} h_{t-1} + W_{xz} x_t + b_z) \\ \tilde h_t = \tanh \big( W_{hh}(r_t⊙h_{t-1})+ W_{xh}x_t + b_h \big) \\ h_t = z_t⊙h_{t-1} + (1-z_t) ⊙ \tilde h_t rt=σ(Whrht−1+Wxrxt+br)zt=σ(Whzht−1+Wxzxt+bz)h~t=tanh(Whh(rt⊙ht−1)+Wxhxt+bh)ht=zt⊙ht−1+(1−zt)⊙h~t
其实更严谨地来说,前向传播的公式应该表示成矩阵乘法:
r t = σ ( W r ⋅ [ h t − 1 , x t ] ) z t = σ ( W z ⋅ [ h t − 1 , x t ] ) h ~ t = tanh ( W h ⋅ [ r t ⊙ h t − 1 , x t ] ) h t = z t ⊙ h t − 1 + ( 1 − z t ) ⊙ h ~ t r_t = \sigma(W_r·[h_{t-1}, x_t]) \\ z_t = \sigma (W_z·[h_{t-1}, x_t]) \\ \tilde h_t = \tanh (W_h · [r_t⊙h_{t-1}, x_t]) \\ h_t = z_t⊙h_{t-1} + (1-z_t) ⊙ \tilde h_t rt=σ(Wr⋅[ht−1,xt])zt=σ(Wz⋅[ht−1,xt])h~t=tanh(Wh⋅[rt⊙ht−1,xt])ht=zt⊙ht−1+(1−zt)⊙h~t
其中忽略了偏置。
两组公式本质上的等价的。比如在第二组公式中, h t − 1 h_{t-1} ht−1 是 1 × 2 1\times 2 1×2 矩阵, x t x_t xt 是 1 × 3 1\times 3 1×3 矩阵,那么 [ h t − 1 , x t ] [h_{t-1},x_t] [ht−1,xt] 为 1 × 5 1\times 5 1×5 矩阵;又因为 h t h_t ht 由按位点乘计算得来,所以 z t z_t zt、 h t h_t ht 和 h ~ t \tilde h_t h~t 三者同型,故 W r W_r Wr、 W x W_x Wx 和 W h W_h Wh 均为 5 × 2 5\times 2 5×2 矩阵。等价地,在第一组公式中, h t − 1 h_{t-1} ht−1 和 x t x_t xt 均为 1 × 5 1\times 5 1×5 矩阵,只不过 h t − 1 h_{t-1} ht−1 的后三维元素为 0 0 0, x t x_{t} xt 的前两维元素为 0 0 0; W h r W_{hr} Whr、 W h z W_{hz} Whz 和 W h h W_{hh} Whh 是 5 × 2 5\times 2 5×2 矩阵,后三行元素为 0 0 0; W x r W_{xr} Wxr、 W x z W_{xz} Wxz 和 W x h W_{xh} Wxh 是 5 × 2 5\times 2 5×2 矩阵,前两行元素为 0 0 0。如此,在不考虑偏置的前提下,二者完全等价。
GRU 总体结构与 RNN 相近,但 GRU 的思想却与 LSTM 更加相似。
当 r t = 1 r_t = 1 rt=1, z t = 0 z_t=0 zt=0 时,根据前向传播公式可知,GRU 退化成 RNN,这意味着此时的 GRU 失去捕捉长期依赖关系的能力。
**GRU 与 LSTM 在思想上均采用“门”控制历史信息的遗忘程度和输入信息的保留程度,从而实现了 RNN 不具备的捕获长期依赖关系的能力。**对于具体单元结构而言,GRU 和 LSTM 最大的区别在于 GRU 将 LSTM 中的当前单元状态 c c c 和输出 h h h 的信息融合在一起,用 h h h 表示,也就是说在 GRU 中 h h h 包含了 LSTM 的两部分信息。另外,有很多处小局部相似,但都不足为道,这里仅简单对比二者最相似的部件。正如很多资料所说,GRU 的更新门是 LSTM 的遗忘门和输入门的融合。如图 2 2 2 所示,红色部分和蓝色部分分别为两个模型相似的部件,如果按照上面信息混合的思想对 GRU 中的 h h h 进行理解,那么LSTM 的遗忘门周围的部件确实可以与 GRU 的更新门周围的部件相对应,但是由于 GRU 中 h ~ \tilde h h~ 的获取与重置门 r r r 相关,所以这部分其实是无法与 LSTM 的输入门完全对应的,仅看部件结构还是比较相似的。所以说,详细地对比 GRU 和 LSTM 的单元结构得到的收益很少,更推荐从整体(思想)上把握二者的相似性。
图 2 LSTM (左)与 GRU (右)对比
分析重置门。在 RNN 的讲解中我们提到过,RNN 无法实现长期信息传递,这主要是因为在计算偏导后发现,随着时间的流逝,历史信息的贡献越来越小,最终几乎消失。RNN 总共涉及两个公式,一个历史状态和输入共同决定当前状态的计算公式,另一个是当前状态决定输出的计算公式,涉及到历史信息的仅有当前状态的计算公式 s t = f ( W s t − 1 + U x t ) s_t = f(Ws_{t-1}+Ux_t) st=f(Wst−1+Uxt),其中 f f f 为激活函数。可见,此公式无法实现长期信息传递。对应于 GRU 的公式 h ~ t = tanh ( W h h ( r t ⊙ h t − 1 ) + W x h x t ) \tilde h_t = \tanh \big( W_{hh}(r_t⊙h_{t-1})+ W_{xh}x_t \big) h~t=tanh(Whh(rt⊙ht−1)+Wxhxt), h ~ t \tilde h_t h~t 的公式仅仅是对 s t s_t st 公式的推广,当 r t = 1 r_t=1 rt=1 时两个公式等价,显然,在历史信息 h t − 1 h_{t-1} ht−1 前加上系数 r t r_t rt 不会影响该公式无法用于控制信息长期传递的性质,换句话说,GRU 中的 h ~ t \tilde h_t h~t 公式不起到控制信息长期传递的作用。进一步, r t r_t rt 也就与长期信息传递无关,它仅用于调节输入信息 x t x_t xt 的占比,间接地控制短期信息的传递量。可见,重置门的值越小,输入信息保留的越多。重置门有助于捕捉时间序列里的短期依赖关系。
分析更新门。仅仅添加重置门依然无法控制长期信息的传递,因此需要添加更新门来处理长期信息。在 LSTM 的讲解中提到,当前单元状态 c t c_t ct 由公式 c t = c t − 1 ⊙ f t + g t ⊙ i t c_t = c_{t-1}⊙f_t+ g_t⊙i_t ct=ct−1⊙ft+gt⊙it 决定,其中 f t f_t ft 为遗忘门对应由 0 ∼ 1 0\sim1 0∼1 构成的矩阵,用于控制历史信息 c t − 1 c_{t-1} ct−1 的遗忘程度, g t g_t gt 用于控制当前输入的保留程度。此公式中的遗忘门真正起到了控制历史信息遗忘程度的作用,保证了 LSTM 能够捕获长期依赖关系,是 LSTM 捕获长期依赖关系的关键。对应 GRU 的公式 h t = z t ⊙ h t − 1 + ( 1 − z t ) ⊙ h ~ t h_t = z_t⊙h_{t-1} + (1-z_t) ⊙ \tilde h_t ht=zt⊙ht−1+(1−zt)⊙h~t, z t z_t zt 越接近 1 1 1,历史信息 h t − 1 h_{t-1} ht−1 的保留程度越大。与 h ~ t \tilde h_t h~t 的公式相比,尽管 h ~ t \tilde h_t h~t 公式也涉及历史信息,但是从偏导中中我们发现, h ~ t \tilde h_t h~t 公式中的历史信息无法向后长期传递,所以需要一个新公式保证历史信息长期向后传递,这个有效的新公式就是 h t h_t ht 公式,此公式直接用重置门 z t z_t zt 控制历史信息的保留程度。更新门的值越大,流入的历史信息越多。更新门有助于捕捉时间序列里的长期依赖关系。
[1] 56 门控循环单元(GRU)【动手学深度学习v2】- bilibili
[2] DL之GRU:GRU算法相关论文、建立过程(基于TF)、相关思路配图集合、TF代码实现 - CSDN
[3] 【机器学习】详解 GRU - CSDN
[4] 【机器学习】RNN 讲解 - CSDN
[5] 【机器学习】LSTM 讲解 - CSDN