
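1. Model definition

The gated recurrent unit (GRU) network [1] is a simplified variant developed from the LSTM. It usually reaches results similar to those of an LSTM model at a faster computation speed [2].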
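One rough way to see where the speed advantage comes from is to compare parameter counts. The sketch below uses PyTorch's built-in `nn.GRU` and `nn.LSTM`; the sizes `d` and `h` and the helper `count_parameters` are illustrative assumptions, and parameter count is only a proxy for runtime, not a benchmark.

```python
import torch
from torch import nn


def count_parameters(module: nn.Module) -> int:
    """Total number of trainable scalars in a module."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


d, h = 32, 64  # illustrative input and hidden sizes (assumed values)
gru = nn.GRU(input_size=d, hidden_size=h, batch_first=True)
lstm = nn.LSTM(input_size=d, hidden_size=h, batch_first=True)

# PyTorch stores 3 weight blocks per GRU layer and 4 per LSTM layer,
# each block holding d*h + h*h weights plus two bias vectors of size h.
gru_params = count_parameters(gru)
lstm_params = count_parameters(lstm)
print(gru_params, lstm_params, gru_params / lstm_params)
```

With equal sizes the GRU carries three quarters of the LSTM's parameters, because it has three weight blocks (reset gate, update gate, candidate state) where the LSTM has four.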

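2. Model structure and forward-propagation formulas

The GRU's hidden-state computation introduces no additional memory cell and reduces the logic gates to a reset gate and an update gate. Its structure and forward-propagation formulas are as follows: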


$$
\left\{
\begin{array}{lll}
\text{Input:} & X_t \in \mathbb{R}^{m \times d},\quad H_{t-1} \in \mathbb{R}^{m \times h} & \\[6pt]
\text{Reset gate:} & R_t = \sigma(X_t W_{xr} + H_{t-1} W_{hr} + b_r), & W_{xr} \in \mathbb{R}^{d \times h},\ W_{hr} \in \mathbb{R}^{h \times h} \\[6pt]
\text{Candidate hidden state:} & \tilde{H}_t = \tanh(X_t W_{xh} + (R_t \odot H_{t-1}) W_{hh} + b_h), & W_{xh} \in \mathbb{R}^{d \times h},\ W_{hh} \in \mathbb{R}^{h \times h} \\[6pt]
\text{Update gate:} & Z_t = \sigma(X_t W_{xz} + H_{t-1} W_{hz} + b_z), & W_{xz} \in \mathbb{R}^{d \times h},\ W_{hz} \in \mathbb{R}^{h \times h} \\[6pt]
\text{Hidden state:} & H_t = Z_t \odot H_{t-1} + (1 - Z_t) \odot \tilde{H}_t & \\[6pt]
\text{Output:} & \hat{Y}_t = H_t W_{hy} + b_y, & W_{hy} \in \mathbb{R}^{h \times q} \\[6pt]
\text{Loss function:} & L = \dfrac{1}{T} \sum_{t=1}^{T} l(\hat{Y}_t, Y_t) &
\end{array}
\right. \tag{2.1}
$$
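The hidden-state update in Eq. (2.1) can be sketched directly with tensor operations. This is a minimal illustration, not the author's implementation: the names `gru_step` and `init_params`, the initialization scale, and the bias shape `[h]` are assumptions, and `m`, `d`, `h` are read as batch size, input dimension and hidden dimension.

```python
import torch


def gru_step(X_t, H_prev, params):
    """One GRU time step following Eq. (2.1). X_t: [m, d], H_prev: [m, h]."""
    W_xr, W_hr, b_r, W_xz, W_hz, b_z, W_xh, W_hh, b_h = params
    R_t = torch.sigmoid(X_t @ W_xr + H_prev @ W_hr + b_r)            # reset gate
    Z_t = torch.sigmoid(X_t @ W_xz + H_prev @ W_hz + b_z)            # update gate
    H_tilde = torch.tanh(X_t @ W_xh + (R_t * H_prev) @ W_hh + b_h)   # candidate state
    return Z_t * H_prev + (1.0 - Z_t) * H_tilde                      # new hidden state


def init_params(d, h, scale=0.1):
    """Small random weights and zero biases (an assumed initialization)."""
    def triple():
        return torch.randn(d, h) * scale, torch.randn(h, h) * scale, torch.zeros(h)
    return (*triple(), *triple(), *triple())


torch.manual_seed(0)
m, d, h, T = 4, 5, 3, 6
params = init_params(d, h)
X = torch.randn(m, T, d)   # [batch, time, features]
H = torch.zeros(m, h)      # H_0
for t in range(T):
    H = gru_step(X[:, t], H, params)
print(H.shape)
```

Because `H_t` is an element-wise convex combination of `H_{t-1}` and a `tanh` output, every entry stays within [-1, 1] when `H_0` is zero.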

3. Backpropagation in the GRU


$$
\frac{\partial L}{\partial \hat{Y}_t} = \frac{\partial l(\hat{Y}_t, Y_t)}{T \cdot \partial \hat{Y}_t} \tag{3.1}
$$

$$
\frac{\partial L}{\partial \hat{Y}_t} \Rightarrow
\begin{cases}
\dfrac{\partial L}{\partial W_{hy}} = \dfrac{\partial L}{\partial \hat{Y}_t} \dfrac{\partial \hat{Y}_t}{\partial W_{hy}} \\[10pt]
\dfrac{\partial L}{\partial H_t} =
\begin{cases}
\dfrac{\partial L}{\partial \hat{Y}_t} \dfrac{\partial \hat{Y}_t}{\partial H_t}, & t = T \\[10pt]
\dfrac{\partial L}{\partial \hat{Y}_t} \dfrac{\partial \hat{Y}_t}{\partial H_t} + \dfrac{\partial L}{\partial H_{t+1}} \dfrac{\partial H_{t+1}}{\partial H_t}, & t < T
\end{cases}
\end{cases} \tag{3.2}
$$
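Eq. (3.1) and the first line of Eq. (3.2) can be checked numerically against autograd. The sketch below makes several assumptions: the per-step loss `l` is a sum of squared errors, the loss is averaged over the `T` steps as the factor `T` in Eq. (3.1) implies, and the hidden states are random stand-ins instead of real GRU outputs. Since `W_hy` is shared across time steps, the per-step contributions are summed over `t`.

```python
import torch

torch.manual_seed(0)
m, h, q, T = 2, 3, 4, 5

# Stand-in hidden states H_1..H_T; in a real GRU they would come from Eq. (2.1).
H = torch.randn(T, m, h)
Y = torch.randn(T, m, q)
W_hy = torch.randn(h, q, requires_grad=True)
b_y = torch.zeros(q, requires_grad=True)

Y_hat = H @ W_hy + b_y                               # output layer, [T, m, q]
per_step_loss = ((Y_hat - Y) ** 2).sum(dim=(1, 2))   # assumed l(Y_hat_t, Y_t)
L = per_step_loss.mean()                             # average over the T steps
L.backward()

# Eq. (3.1): dL/dY_hat_t = (1/T) * dl/dY_hat_t, here 2 * (Y_hat_t - Y_t) / T
dL_dYhat = 2.0 * (Y_hat.detach() - Y) / T
# Eq. (3.2), first line, accumulated over t: sum_t H_t^T @ dL/dY_hat_t
manual_grad = torch.einsum('tmh,tmq->hq', H, dL_dYhat)
print(torch.allclose(manual_grad, W_hy.grad, atol=1e-5))
```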

$$
\begin{cases}
\dfrac{\partial L}{\partial W_{xz}} = \dfrac{\partial L}{\partial Z_t} \dfrac{\partial Z_t}{\partial W_{xz}} \\[10pt]
\dfrac{\partial L}{\partial W_{hz}} = \dfrac{\partial L}{\partial Z_t} \dfrac{\partial Z_t}{\partial W_{hz}} \\[10pt]
\dfrac{\partial L}{\partial b_z} = \dfrac{\partial L}{\partial Z_t} \dfrac{\partial Z_t}{\partial b_z}
\end{cases}
\qquad\qquad
\begin{cases}
\dfrac{\partial L}{\partial W_{xh}} = \dfrac{\partial L}{\partial \tilde{H}_t} \dfrac{\partial \tilde{H}_t}{\partial W_{xh}} \\[10pt]
\dfrac{\partial L}{\partial W_{hh}} = \dfrac{\partial L}{\partial \tilde{H}_t} \dfrac{\partial \tilde{H}_t}{\partial W_{hh}} \\[10pt]
\dfrac{\partial L}{\partial b_h} = \dfrac{\partial L}{\partial \tilde{H}_t} \dfrac{\partial \tilde{H}_t}{\partial b_h}
\end{cases} \tag{3.3}
$$

$$
\begin{cases}
\dfrac{\partial L}{\partial W_{xr}} = \dfrac{\partial L}{\partial R_t} \dfrac{\partial R_t}{\partial W_{xr}} \\[10pt]
\dfrac{\partial L}{\partial W_{hr}} = \dfrac{\partial L}{\partial R_t} \dfrac{\partial R_t}{\partial W_{hr}} \\[10pt]
\dfrac{\partial L}{\partial b_r} = \dfrac{\partial L}{\partial R_t} \dfrac{\partial R_t}{\partial b_r}
\end{cases} \tag{3.4}
$$

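As with the LSTM, the key to deriving the GRU backpropagation formulas is the gradient passed between time steps, $\partial H_{t+1} / \partial H_t$. The method is the same as for the LSTM, so it is not repeated here. The same qualitative conclusion also holds: the GRU mitigates the long-term dependency problem in much the same way as the LSTM, by regulating the multipliers of the high-order power terms and by adding low-order power terms. In particular, the reset gate helps capture short-term dependencies in a sequence, and the update gate helps capture long-term dependencies. (For details, see the proof in the author's article "Time-Series Models: Long Short-Term Memory Networks (LSTM)".)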
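To make the qualitative claim slightly more concrete, the differential of the hidden state with respect to the previous hidden state can be expanded from Eq. (2.1) with $X_t$ held fixed. This is a sketch derived here from the forward equations, not a reproduction of the proof in the LSTM article.

```latex
\begin{aligned}
dH_t &= Z_t \odot dH_{t-1}
      + (1 - Z_t) \odot d\tilde{H}_t
      + (H_{t-1} - \tilde{H}_t) \odot dZ_t \\
dZ_t &= Z_t \odot (1 - Z_t) \odot \left( dH_{t-1} W_{hz} \right) \\
d\tilde{H}_t &= \left( 1 - \tilde{H}_t \odot \tilde{H}_t \right) \odot
      \left[ \left( R_t \odot dH_{t-1} + H_{t-1} \odot dR_t \right) W_{hh} \right] \\
dR_t &= R_t \odot (1 - R_t) \odot \left( dH_{t-1} W_{hr} \right)
\end{aligned}
```

The first term, $Z_t \odot dH_{t-1}$, reaches the previous state without passing through any weight matrix, so an update gate close to 1 carries gradient across many steps. The remaining terms all pass through $W_{hz}$, $W_{hh}$ or $W_{hr}$, and the reset gate $R_t$ scales how much of $H_{t-1}$ enters the candidate state.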

4. Code implementation of the model

4.1 TensorFlow implementation

4.2 PyTorch implementation

# v2.0: fixed the vanishing gradient across time steps caused by improper RNN parameter initialization.   2022.04.28
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter

class LSTM_Cell(nn.Module):
    """LSTM cell whose gates and candidate cell are each built from a Simple_RNN_Cell.

    Note: Simple_RNN_Cell is not shown in this section.
    """

    # noinspection PyTypeChecker
    def __init__(self, token_dim, hidden_dim
                 , input_act=nn.ReLU()
                 , forget_act=nn.ReLU()
                 , output_act=nn.ReLU()
                 , hatcell_act=nn.Tanh()
                 , hidden_act=nn.Tanh()
                 , device="cpu"):
        super().__init__()  # must run before sub-modules are registered
        self.hidden_dim = hidden_dim
        self.device = device
        # Input gate, forget gate, output gate and candidate cell state
        self.InputG = Simple_RNN_Cell(
            token_dim, hidden_dim, activation=input_act, device=device
        )
        self.ForgetG = Simple_RNN_Cell(
            token_dim, hidden_dim, activation=forget_act, device=device
        )
        self.OutputG = Simple_RNN_Cell(
            token_dim, hidden_dim, activation=output_act, device=device
        )
        self.HatCell = Simple_RNN_Cell(
            token_dim, hidden_dim, activation=hatcell_act, device=device
        )
        self.HiddenActivation = hidden_act.to(self.device)

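    def forward(self, inputs, last_state):
        """
        inputs:      the word vector of the token at this time step.
        last_state:  last_state = [last_cell, last_hidden_state]
        """
        # Simple_RNN_Cell is not shown in this section. Each call below is assumed
        # to return a tensor of shape [batch_size, hidden_dim]; if it returns a
        # state list instead, take its last element.
        Ig = self.InputG(inputs, last_state)
        Fg = self.ForgetG(inputs, last_state)
        Og = self.OutputG(inputs, last_state)
        hat_cell = self.HatCell(inputs, last_state)
        # New cell state: keep part of the old cell, add part of the candidate
        cell = Fg * last_state[0] + Ig * hat_cell
        hidden = Og * self.HiddenActivation(cell)
        return [cell, hidden]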

    def zero_initialization(self, batch_size):
        init_cell = torch.zeros([batch_size, self.hidden_dim]).to(self.device)
        init_state = torch.zeros([batch_size, self.hidden_dim]).to(self.device)
        return [init_cell, init_state]

class RNN_Layer(nn.Module):
    """
    rnn_cell:       The cell module applied at every time step.
    bidirectional:  If ``True``, becomes a bidirectional RNN network. Default: ``False``.
    pad_position:   String, 'pre' or 'post' (optional, defaults to 'post'): pad either before or after each sequence.
    """

    def __init__(self, rnn_cell, bidirectional=False, pad_position='post'):
        super().__init__()  # must run before sub-modules are registered
        self.RNNCell = rnn_cell
        self.bidirectional = bidirectional
        self.padding = pad_position

    def forward(self, inputs, mask=None, initial_state=None):
        """
        inputs:   its shape is [batch_size, time_steps, token_dim]
        mask:     its shape is [batch_size, time_steps]
        :return
        sequence:    the hidden state sequence; its shape is [batch_size, time_steps, hidden_dim]
        last_state:  the hidden state of the input sequences at the last time step.
                     Note that the last token may be a padding token,
                     so this last state is not necessarily the real last state of the input sequences;
                     to get the real last state, use utils.get_rnn_last_state(hidden state sequence).
        """
        batch_size, time_steps, token_dim = inputs.shape
        if initial_state is None:
            initial_state = self.RNNCell.zero_initialization(batch_size)
        if mask is None:
            if batch_size == 1:
                # A single sequence has no padding, so every step is valid
                mask = torch.ones([1, time_steps], device=inputs.device)
            elif self.padding == 'pre':
                raise ValueError('Please provide the mask matrix (mask)')
            elif self.padding == 'post' and self.bidirectional is True:
                raise ValueError('Please provide the mask matrix (mask)')

        # Forward loop over time steps
        hidden_list = []
        hidden_state = initial_state
        last_state = None
        for i in range(time_steps):
            hidden_state = self.RNNCell(inputs[:, i], hidden_state)
            hidden_list.append(hidden_state[-1])
            if i == time_steps - 1:
                last_state = hidden_state
            if self.padding == 'pre':
                # Padding sits at the start of each sequence, so the forward pass
                # has to apply the mask.
                hidden_state = [
                    hidden_state[j] * mask[:, i:i + 1] + initial_state[j] * (1 - mask[:, i:i + 1])  # reset to the initial state at padded steps
                    for j in range(len(hidden_state))
                ]
        # Stack the per-step hidden states into [batch_size, time_steps, hidden_dim]
        sequence = torch.stack(hidden_list, dim=1)

        # Backward loop over time steps
        if self.bidirectional is True:
            hidden_list = []
            hidden_state = initial_state
            for i in range(time_steps, 0, -1):
                hidden_state = self.RNNCell(inputs[:, i - 1], hidden_state)
                hidden_list.insert(0, hidden_state[-1])
                if i == time_steps:
                    # Concatenate forward and backward states along the feature axis
                    last_state = [
                        torch.concat([last_state[j], hidden_state[j]], dim=1)
                        for j in range(len(hidden_state))
                    ]
                if self.padding == 'post':
                    # Padding sits at the end of each sequence, so the backward pass
                    # has to apply the mask.
                    hidden_state = [
                        hidden_state[j] * mask[:, i - 1:i] + initial_state[j] * (1 - mask[:, i - 1:i])  # reset to the initial state at padded steps
                        for j in range(len(hidden_state))
                    ]
            # Append the backward hidden states to the forward ones: [batch_size, time_steps, 2 * hidden_dim]
            sequence = torch.concat([
                sequence,
                torch.stack(hidden_list, dim=1)
            ], dim=-1)

        return sequence, last_state
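The listing above shows an LSTM cell and a generic RNN layer but no GRU cell. A GRU cell with the same interface (a `forward(inputs, last_state)` that returns a state list, plus `zero_initialization`) could look roughly like the sketch below. `GRUCellSketch` and `run_with_pre_padding` are hypothetical names, the use of `nn.Linear` with its default initialization is an assumption, and none of this is the author's implementation. The small driver also illustrates the mask trick for 'pre' padding: resetting the state at padded steps gives the same final state as running the unpadded sequence.

```python
import torch
from torch import nn


class GRUCellSketch(nn.Module):
    """Hypothetical GRU cell following Eq. (2.1); not the author's implementation."""

    def __init__(self, token_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Input-side maps carry the bias; hidden-side maps are bias-free.
        self.x_r = nn.Linear(token_dim, hidden_dim)
        self.h_r = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.x_z = nn.Linear(token_dim, hidden_dim)
        self.h_z = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.x_h = nn.Linear(token_dim, hidden_dim)
        self.h_h = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, inputs, last_state):
        """inputs: [batch, token_dim]; last_state: [hidden], hidden is [batch, hidden_dim]."""
        h_prev = last_state[-1]
        r = torch.sigmoid(self.x_r(inputs) + self.h_r(h_prev))        # reset gate
        z = torch.sigmoid(self.x_z(inputs) + self.h_z(h_prev))        # update gate
        h_tilde = torch.tanh(self.x_h(inputs) + self.h_h(r * h_prev)) # candidate state
        return [z * h_prev + (1.0 - z) * h_tilde]

    def zero_initialization(self, batch_size):
        weight = next(self.parameters())  # reuse the parameters' dtype and device
        return [weight.new_zeros(batch_size, self.hidden_dim)]


def run_with_pre_padding(cell, inputs, mask):
    """Unroll over time; at padded steps (mask == 0) reset the state to its initial value."""
    batch_size, time_steps, _ = inputs.shape
    initial_state = cell.zero_initialization(batch_size)
    state = initial_state
    outputs = []
    for i in range(time_steps):
        state = cell(inputs[:, i], state)
        step_mask = mask[:, i:i + 1]
        state = [s * step_mask + s0 * (1 - step_mask) for s, s0 in zip(state, initial_state)]
        outputs.append(state[-1])
    return torch.stack(outputs, dim=1), state


torch.manual_seed(0)
cell = GRUCellSketch(token_dim=5, hidden_dim=3)
tokens = torch.randn(1, 4, 5)                               # one unpadded sequence
padded = torch.cat([torch.zeros(1, 2, 5), tokens], dim=1)   # two padding steps in front
seq_plain, state_plain = run_with_pre_padding(cell, tokens, torch.ones(1, 4))
seq_padded, state_padded = run_with_pre_padding(
    cell, padded, torch.tensor([[0., 0., 1., 1., 1., 1.]]))
print(torch.allclose(state_plain[-1], state_padded[-1]))
```

Returning the state as a one-element list keeps the cell interchangeable with cells that carry more than one state tensor, such as the `[cell, hidden]` pair of the LSTM cell.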

  1. Cho, K., Van Merriënboer, B., Bahdanau, D., & Bengio, Y. (2014). On the properties of neural machine translation: Encoder-decoder approaches. arXiv preprint arXiv:1409.1259.

  2. Chung, J., Gulcehre, C., Cho, K., & Bengio, Y. (2014). Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555.
