神经翻译笔记4扩展a第一部分. RNN在TF1.x中的实现方法略览

神经翻译笔记4扩展a第一部分. RNN在TF1.x中的实现方法略览

    • RNN cell的实现
      • `keras.layers.Layer`
      • `layers.Layer`
      • `nn.rnn_cell.RNNCell`
      • `LayerRNNCell`
      • `BasicRNNCell`
      • `GRUCell`
      • `BasicLSTMCell`
      • `LSTMCell`
      • `MultiRNNCell`
    • RNN的实现
      • 静态机制
      • 动态机制
    • 参考文献

本文主要讨论TF1.14对RNN的实现。尽管更老的版本实现可能有差别,但考虑其有些过时,因此这里略过

RNN cell的实现

TF1.x对RNN cell的实现总体遵循如下所示之类图
神经翻译笔记4扩展a第一部分. RNN在TF1.x中的实现方法略览_第1张图片

keras.layers.Layer

代码位置:tensorflow/python/keras/engine/base_layer.py

该类是所有层(“layers”)的基类。按照文档,所谓“层”实现了神经网络的一些常用操作。既然如此,那么一个“层”应该有如下特点:

  • 首先,既然定义了“操作”,那么肯定会有一个“宾语”。这意味着每个层都会接收输入,作为操作对象
  • 其次,“操作”本身一定有逻辑,因此每个层内部会定义该层对输入应如何处理,执行怎样的计算过程。通常情况下,每个层都需要设置一些可训练的参数
  • 最后,该操作通常会产生输出,将输出传递给下一个层,作为下一个层的输入

由于层的核心是计算过程,因此在这样的基类中,最核心的方法是覆写的__call__方法,使得所有继承该类的子类对象是可调用的。此类已经对__call__方法做了比较好的包装,因此所有具体的层只需要实现该类提供的接口buildcall即可。这两个接口会在__call__方法中调用。基类实现的__call__方法逻辑大致为

def __call__(inputs):
    将inputs中所有numpy类型数据转换为tensor
    build_graph = 判断是否建图
    previous_mask = 从前层获取mask值
    with base_layer_utils.call_context(build_graph):
        if build_graph:
            判断输入是否满足input_spec
        # 下述核心逻辑在静态图(build_graph is True)和eager模式下是相同的
        # 但因为两者背后实现逻辑有区别,所以原始代码分开做了实现。这里做了合并
        self._maybe_build(inputs)  # 实际就是调用build。该方法会在开始检查self._built是否为True
                                   # 如果为True直接返回。在最后会显式将self._built设置为True
        outputs = self.call(inputs)
        # 对输出做正则,并将正则损失加到loss中
        self._handle_activity_regularization(inputs, outputs)
        # 计算并设置mask
        self._set_mask_metadata(inputs, outputs, previous_mask)
        return outputs

此外,子类一般还需要实现__init__方法。子类所需要实现的三个方法大致分工如下

  • __init__来创建并初始化(一部分)成员变量,但是不指定训练参数
  • build内通常调用add_weight方法,根据输入、类型、形状(shape)、使用的初始化方法等信息创建要训练的参数
  • call内部实现具体的计算逻辑,返回outputs

如果损失计算或权重更新与输入有关,可通过让子类实现add_lossadd_update以达到此目的

layers.Layer

代码位置:tensorflow/python/layers/base.py

该类存在的主要目的是向下兼容静态图模式的代码。在TF1.x的早期版本(至少到1.5)该类事实上是所有“层”类的基类(直接继承自object),不过在1.x的后期版本该类的核心逻辑已被移动到前述tf.keras.layers.Layer中,官方也不再推荐开发者继承该类开发

nn.rnn_cell.RNNCell

代码位置:tensorflow/python/ops/rnn_cell_impl.py。至另有标注为止,后续各类的实现均在该文件中

该类为所有具体的RNN实现提供了一个共同的抽象表示。通过重写__call__方法,将其签名改为__call__(self, inputs, state, scope=None),使得该类的所有子类对象以函数的方式被“调用”时,参数除了inputs以外还需要再带一个state参数(通常是上一个时间步传递来的状态)。对应地,具体实现call时也需要传入这两个参数(当然,实际上仍然可以只写明传入inputs,而state通过args, **kwargs传入。但是这样看上去像是在杠)。此外,该类还提供了get_initial_statezero_state方法,后者常用来初始化RNN的初始状态

LayerRNNCell

为了向layers.Layer靠拢插进来的新类。其作用是将变量创建(build该做的事情)从call中剥离出来。具体的解(tu)释(cao)可以参看为什么感觉tensorflow的源码写的很多余? - Towser的回答 - 知乎

BasicRNNCell

BasicRNNCell实现的是没有任何门控的Vanilla RNN。我一直不很倾向在文字里放大段原始代码,但是由于该类背后逻辑比较简单,实现相对简短,因此这里将会给出原始实现的具体细节

__init__函数设置自身的input_spec、单元数num_units和激活函数种类,默认激活函数为tanh。如代码文档所说,这里的“cell”和文献里的“cell”不同。文献里的cell输出一个标量,但是这里的一个cell相当于一组文献中的cell,共num_units个。考虑之前文章贴的示意图
神经翻译笔记4扩展a第一部分. RNN在TF1.x中的实现方法略览_第2张图片
其展开图如下
神经翻译笔记4扩展a第一部分. RNN在TF1.x中的实现方法略览_第3张图片这里RNN cell的num_units即为3(每个时间步黄色圆圈的数量)
__init__的具体实现如下:

def __init__(self, num_units, activation=None, reuse=None, 
             name=None, dtype=None, **kwargs):
    super(BasicRNNCell, self).__init__(_reuse=reuse, name=name, 
                                       dtype=dtype, **kwargs)
    # 只接受浮点型或复数浮点型输入
    _check_supported_dtypes(self.dtype)
    if context.executing_eagerly() and context.num_gpus() > 0:
        logging.warn("建议使用tf.contrib.cudnn_rnn.CudnnRNNTanh以达到更好性能")

    # 要求输入是2维。由本文前面介绍,会在调用__call__时检查input spec
    self.input_spec = input_spec.InputSpec(ndim=2)

    self._num_units = num_units
    if activation:
        self._activation = activations.get(activation)
    else:
        self._activation = math_ops.tanh

build只是通过最基类tf.Layers提供的add_weight方法来注册变量。add_weight在调用时会访问初始化时提供的信息,例如参数使用何种初始化方法、是否可训练等,并将参数加入图中(静态图模式下)。build的具体实现为

def build(self, inputs_shape):
    if inputs_shape[-1] is None:
        raise ValueError
    _check_supported_dtypes(self.dtype)

    # 与词向量连接的那一层,input_depth就是词向量维度
    input_depth = input_shape[-1]
    # add_variable就是add_weight的别名
    self._kernel = self.add_variable(
        _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + self._num_units, self._num_units])
    self._bias = self.add_variable(
        _BIAS_VARIABLE_NAME,
        shape=[self._num_units],
        initializer=init_ops.zeros_initializer(dtype=self.dtype))
    self.built = True

call则是实现计算逻辑
o u t p u t = s t a t e = a c t i v a t e ( x ( t ) U + h ( t − 1 ) W + b ) {output} = {state} = {\rm activate}\left(\boldsymbol{x}^{(t)}\boldsymbol{U} + \boldsymbol{h}^{(t-1)}\boldsymbol{W}+\boldsymbol{b}\right) output=state=activate(x(t)U+h(t1)W+b)
其中

  • x t \boldsymbol{x}_t xt形状为 b a t c h _ s i z e × i n p u t _ d e p t h batch\_size \times input\_depth batch_size×input_depth,对应代码中的inputs(这里使用小写,是取batch size为1时的情况。此时输入是一个行向量)
  • h t − 1 \boldsymbol{h}_{t-1} ht1形状为 b a t c h _ s i z e × n u m _ u n i t s batch\_size \times num\_units batch_size×num_units,对应代码中的state
  • U \boldsymbol{U} U形状为 i n p u t _ d e p t h × n u m _ u n i t s input\_depth \times num\_units input_depth×num_units
  • W \boldsymbol{W} W形状为 n u m _ u n i t s × n u m _ u n i t s num\_units \times num\_units num_units×num_units

(这里矩阵乘法顺序与理论部分有区别,原因是理论部分将输入向量看做列向量,而这里使用行向量实现。下同)

真正实现时,将 U \boldsymbol{U} U W \boldsymbol{W} W“纵向”连接在一起(这个矩阵的整体为self._kernel), x ( t ) \boldsymbol{x}^{(t)} x(t) h ( t − 1 ) \boldsymbol{h}^{(t-1)} h(t1)按“横向”连接在一起。这样仅通过一次矩阵乘法运算,就能算出结果。即
[ x ( t ) h ( t − 1 ) ] ⋅ [ U W ] = x ( t ) U + h ( t − 1 ) W \left[\begin{matrix}\boldsymbol{x}^{(t)} & \boldsymbol{h}^{(t-1)}\end{matrix}\right] \cdot \left[\begin{matrix}\boldsymbol{U} \\ \boldsymbol{W} \end{matrix}\right] = \boldsymbol{x}^{(t)}\boldsymbol{U} + \boldsymbol{h}^{(t-1)}\boldsymbol{W} [x(t)h(t1)][UW]=x(t)U+h(t1)W
这也是为什么在buildself._kernel的形状为[input_depth + self._num_units, self._num_units]的原因

具体代码如下

def call(self, inputs, state):
    _check_rnn_cell_input_dtypes([inputs, state])
    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, state], 1), self._kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
    output = self._activation(gate_inputs)
    # 实际为output, state。但是vanilla RNN不区分
    return output, output

GRUCell

GRUCell和后面要介绍的LSTMCell实现思路与前述BasicRNNCell大同小异,也是按照文献按部就班实现各个门控,同时尽量减少矩阵加法,将若干矩阵拼成一个大矩阵,通过一个矩阵乘法完成任务。因此这里不再附录具体代码

GRUCell内部有四个参数

  • _gate_kernel,形状为 ( i n p u t _ d e p t h + n u m _ u n i t s ) × ( 2 × n u m _ u n i t s ) (input\_depth + num\_units) \times (2 \times num\_units) (input_depth+num_units)×(2×num_units)
  • _gate_bias,一个维度为 2 × n u m _ u n i t s 2 \times num\_units 2×num_units的行向量
  • _candidate_kernel,形状为 ( i n p u t _ d e p t h + n u m _ u n i t s ) × n u m _ u n i t s (input\_depth + num\_units) \times num\_units (input_depth+num_units)×num_units
  • _candidate_bias,一个维度为 n u m _ u n i t s num\_units num_units的行向量

回忆GRU有两个门

  • 重置门: r ( t ) = σ ( h ( t − 1 ) W r + x ( t ) U r + b r ) \boldsymbol{r}^{(t)} = \sigma\left(\boldsymbol{h}^{(t-1)}\boldsymbol{W}_r + \boldsymbol{x}^{(t)}\boldsymbol{U}_r+\boldsymbol{b}_r\right) r(t)=σ(h(t1)Wr+x(t)Ur+br)
  • 更新门: z ( t ) = σ ( h ( t − 1 ) W z + x ( t ) U z + b z ) \boldsymbol{z}^{(t)} = \sigma\left(\boldsymbol{h}^{(t-1)}\boldsymbol{W}_z + \boldsymbol{x}^{(t)}\boldsymbol{U}_z+\boldsymbol{b}_z\right) z(t)=σ(h(t1)Wz+x(t)Uz+bz)

在实现时有如下合并
σ ( [ i n p u t s , s t a t e ] × _ g a t e _ k e r n e l + _ g a t e _ b i a s ) = σ ( [ x ( t ) h ( t − 1 ) ] ⋅ [ U r U z W r W z ] + [ b r b z ] ) = [ r ( t ) z ( t ) ] : = [ r , u ] \begin{aligned} &\sigma\left({\tt [inputs, state]} \times {\tt \_gate\_kernel} + {\tt \_gate\_bias}\right) \\ = &\sigma\left(\left[\begin{matrix}\boldsymbol{x}^{(t)} & \boldsymbol{h}^{(t-1)}\end{matrix}\right] \cdot \left[\begin{matrix}\boldsymbol{U}_r & \boldsymbol{U}_z \\ \boldsymbol{W}_r & \boldsymbol{W}_z \end{matrix}\right] + \left[\begin{matrix}\boldsymbol{b}_r & \boldsymbol{b}_z\end{matrix}\right]\right) \\ =& \left[\begin{matrix}\boldsymbol{r}^{(t)} & \boldsymbol{z}^{(t)}\end{matrix}\right] \\ :=& {\tt [r, u]} \end{aligned} ==:=σ([inputs,state]×_gate_kernel+_gate_bias)σ([x(t)h(t1)][UrWrUzWz]+[brbz])[r(t)z(t)][r,u]

理论部分“新内容” h ~ ( t ) \tilde{\boldsymbol{h}}^{(t)} h~(t)的计算方式实现与BasicRNNCell中的call实现类似,只不过需要先将r与状态state做逐元素相乘,即r * state。最后,实现里最后一步新状态的计算与原始文献不同,前一时间步状态的系数不是1-u,而直接是u

BasicLSTMCell

BasicGRUCell思路类似,使用一个矩阵和一个偏置向量学习三个门 i ( t ) \boldsymbol{i}^{(t)} i(t) f ( t ) \boldsymbol{f}^{(t)} f(t) o ( t ) \boldsymbol{o}^{(t)} o(t)和“新状态” c ~ ( t ) \tilde{\boldsymbol{c}}^{(t)} c~(t)的参数:

  • _kernel,形状为 ( i n p u t _ d e p t h + n u m _ u n i t s ) × ( 4 × n u m _ u n i t s ) (input\_depth + num\_units) \times (4 \times num\_units) (input_depth+num_units)×(4×num_units)
  • _bias,一个维度为 4 × n u m _ u n i t s 4 \times num\_units 4×num_units的行向量

三个门使用的激活函数是sigmoid,而 c ~ ( t ) \tilde{\boldsymbol{c}}^{(t)} c~(t)(对应于代码中为j)使用的与它们不同,是tanh,因此原始实现中将这四个结果的激活计算推迟到了计算 c ( t ) \boldsymbol{c}^{(t)} c(t) h ( t ) \boldsymbol{h}^{(t)} h(t)时调用(代码中记为new_cnew_h)。此外,为了避免在训练开始时网络过多遗忘,tf对遗忘门算出的结果加了一个默认为1的偏置项forget_bias(这是与主流LSTM原理不同的一点)。即整套逻辑为
c , h = s t a t e σ ( [ i n p u t s , h ] × _ k e r n e l + _ b i a s ) = σ ( [ x ( t ) h ( t − 1 ) ] ⋅ [ U i U c U f U o W i W c W f W o ] + [ b i b c b f b o ] ) = [ i ( t ) c ~ ( t ) f ( t ) o ( t ) ] : = [ i , j , f , o ] n e w _ c = c ⊙ σ ( f + f o r g e t _ b i a s ) + σ ( i ) ⊙ tanh ⁡ ( j ) n e w _ h = tanh ⁡ ( n e w _ c ) ⊙ σ ( o ) \begin{aligned} &{\tt c, h = state} \\ &\sigma\left({\tt [inputs, h]} \times {\tt \_kernel} + {\tt \_bias}\right) \\ = &\sigma\left(\left[\begin{matrix}\boldsymbol{x}^{(t)} & \boldsymbol{h}^{(t-1)}\end{matrix}\right] \cdot \left[\begin{matrix}\boldsymbol{U}_i & \boldsymbol{U}_c & \boldsymbol{U}_f & \boldsymbol{U}_o \\ \boldsymbol{W}_i & \boldsymbol{W}_c & \boldsymbol{W}_f & \boldsymbol{W}_o \end{matrix}\right] + \left[\begin{matrix}\boldsymbol{b}_i & \boldsymbol{b}_c & \boldsymbol{b}_f & \boldsymbol{b}_o \end{matrix}\right]\right) \\ =& \left[\begin{matrix}\boldsymbol{i}^{(t)} & \tilde{\boldsymbol{c}}^{(t)} & \boldsymbol{f}^{(t)} & \boldsymbol{o}^{(t)}\end{matrix}\right] \\ :=& {\tt [i, j, f, o]} \\ &{\tt new\_c = c \odot \sigma(f + forget\_bias) + \sigma(i)\odot \tanh(j)} \\ &{\tt new\_h = \tanh(new\_c) \odot \sigma(o)} \end{aligned} ==:=c,h=stateσ([inputs,h]×_kernel+_bias)σ([x(t)h(t1)][UiWiUcWcUfWfUoWo]+[bibcbfbo])[i(t)c~(t)f(t)o(t)][i,j,f,o]new_c=cσ(f+forget_bias)+σ(i)tanh(j)new_h=tanh(new_c)σ(o)
new_h作为LSTM的输出传给下一层,同时new_cnew_h打包成一个特别类型的元组(LSTMStateTuple类型,该类型会检查new_cnew_h的数据类型是否相同)作为LSTM传递给下一个时间步cell的状态。老的代码中允许将new_cnew_h组合成一个数组,但是会减慢速度,不建议这么做

LSTMCell

上一小节所讲述的BasicLSTMCell除了对遗忘门结果加偏置以外,并没有添加更多功能,可以看做是一个“基线模型”。LSTMCellBasicLSTMCell的基础上增加了一些新的功能

  • 输出裁剪。初始化对象时若指定了proj_clip参数,则 c ( t ) \boldsymbol{c}^{(t)} c(t) h ( t ) \boldsymbol{h}^{(t)} h(t)每个维度元素都会被裁剪到[-proj_clip, proj_clip]这个区间(小于-proj_clip则置为-proj_clip,大于proj_clip则置为proj_clip,中间元素不变)
  • 输出投影。初始化对象时若指定了num_proj参数,则会额外学习一个线性变换,将输出从num_units投影到num_proj维上(如果同时指定了裁剪和投影,则先投影再裁剪)
  • Peephole LSTM[Gers2002],是想在前一个时间步输出门接近0时,仍然“偷看”一些该时间步单元的状态。实现中引入了三个权重向量 w i d i a g \boldsymbol{w}_i^{\rm diag} widiag w f d i a g \boldsymbol{w}_f^{\rm diag} wfdiag w o d i a g \boldsymbol{w}_o^{\rm diag} wodiag,将其与 c ( t − 1 ) \boldsymbol{c}^{(t-1)} c(t1)做逐元素相乘(等价于引入三个对角矩阵,将其与 c ( t − 1 ) \boldsymbol{c}^{(t-1)} c(t1)相乘)。因此当初始化LSTMCell传入的参数use_peepholes = True时,LSTM计算 c ( t ) \boldsymbol{c}^{(t)} c(t) h ( t ) \boldsymbol{h}^{(t)} h(t)的逻辑修改如下

c ( t ) = σ ( f ( t ) + b f o r g e t + w f d i a g ⊙ c ( t − 1 ) ) ⊙ c ( t − 1 ) + σ ( i ( t ) + w i d i a g ⊙ c ( t − 1 ) ) ⊙ c ~ ( t ) h ( t ) = σ ( o ( t ) + w o d i a g ⊙ c ( t − 1 ) ) ⊙ tanh ⁡ ( c ( t ) ) \begin{aligned} \boldsymbol{c}^{(t)} &= \sigma\left(\boldsymbol{f}^{(t)} + \boldsymbol{b}_{\rm forget} + \boldsymbol{w}_f^{\rm diag}\odot \boldsymbol{c}^{(t-1)} \right) \odot \boldsymbol{c}^{(t-1)} + \sigma\left(\boldsymbol{i}^{(t)} +\boldsymbol{w}_i^{\rm diag}\odot \boldsymbol{c}^{(t-1)} \right) \odot \tilde{\boldsymbol{c}}^{(t)} \\ \boldsymbol{h}^{(t)} &= \sigma\left(\boldsymbol{o}^{(t)} + \boldsymbol{w}_o^{\rm diag}\odot \boldsymbol{c}^{(t-1)} \right) \odot \tanh\left(\boldsymbol{c}^{(t)}\right) \end{aligned} c(t)h(t)=σ(f(t)+bforget+wfdiagc(t1))c(t1)+σ(i(t)+widiagc(t1))c~(t)=σ(o(t)+wodiagc(t1))tanh(c(t))

MultiRNNCell

MultiRNNCell是多个RNNCell堆叠而成的结果(所谓的“多层RNN”),因此初始化时需要将一个RNNCell类型对象的列表传入。需要注意的是,列表中id相同的cell共享权重,因此正确的初始化方式应该如下

# 注意以下两行构建cells数组的操作是错的,这样会使cells中所有元素共享权重
# cell = tf.nn.rnn_cell.BasicLSTMCell(lstm_num_units)
# cells = [cell for _ in range(rnn_layers)
cells = [tf.nn.rnn_cell.BasicLSTMCell(lstm_num_units) for _ in range(num_rnn_layers)]
stacked_cell = tf.nn.rnn_cell.MultiRNNCell(cells)

其内部对call的实现是前一层cell的输出作为后一层cell的输入,然后每一层向下游传递的状态收集打包成一个元组。这样的封装使得堆叠的RNN cell与单层RNN cell行为相同,都可以看作是接受一个输入和一组上游传递来的状态,给出最终的输出并将新的状态传递给下游

下图给出了MultiRNNCell的示意图(为了简单起见没有画出当前第t时间步向后传递状态的动作)。该图的大致思路来自于Towser: 学会区分RNN的output和state。该文章讲得更加详细,可以参考

神经翻译笔记4扩展a第一部分. RNN在TF1.x中的实现方法略览_第4张图片
对应MultiRNNCellcall函数实现逻辑可大致整理如下

def call(inputs, state):
    current_input, new_states = inputs, []
    # self._cells就是初始化时传入的cells列表
    # 对象初始化时,默认以元组方式接收到上游状态,并以同样方式传递给下游
    # 实现还提供了另一种将所有状态组合成tensor的操作,但是比较复杂,之后也会废弃
    # 因此这里为了简化逻辑突出重点,不考虑此种方式
    for i, cell in enumerate(self._cells):
        current_state = state[i]
        # 再次强调对LSTM,new_state本身也是一个元组(c, h)
        current_input, new_state = cell(current_input, current_state)
        new_states.append(new_state)
    return current_input, tuple(new_states)

RNN的实现

前述所有RNN cell的实现实际上都只是说明在某个时间步RNN单元的计算逻辑,接下来自然要考虑的问题是如何让RNN的信息按照时序流动起来。TensorFlow为此提供了两种机制,一种是静态机制,另一种是动态机制

tensorflow/python/ops/rnn.py里定义的函数中,函数名以static开头的函数都是使用静态机制,函数名包含dynamic的函数都是使用动态机制。TF2对这两者做了合并:以最基础的static_rnndynamic_rnn为例,在TF2中都是keras.layers.RNN,唯一不同的是如果要用static_rnn在创建对象时要指定unroll=True——这就正好点出了静态机制和动态机制最大的不同:静态RNN在建图时直接展开节点,而动态RNN则是根据传入batch的步数,使用循环迭代计算。因此静态RNN必须保证所有batch的最大长度都一样,而动态RNN则只需要保证每个batch内数据长度相等就可以。因此,如果静态RNN设的时间步特别大,以至于大部分样本的实际长度都远小于这个阈值,那么大部分输入就会被补很多0,无论是运算时的性能,还是建图时的开销都会很大。如果要避免这种无谓开销,只能使用分桶的方法,即创建若干个图,每个图(对应一个桶)展开长度不同,图与图之间共享参数——当然动态RNN的每个batch中也存在补齐问题,但是比静态RNN简单得多,可以看做每个batch都是一个分桶。因此,一般都建议使用动态机制

静态机制

有三个使用静态机制的函数,分别为static_rnnstatic_state_saving_rnnstatic_bidirectional_rnn。由于前文所述一般使用动态机制,因此这里只大致介绍这三个函数里最基础的static_rnn。该函数实现比较简单:如果参数传入时没有指定各句长度,直接按照时间步依次调用rnn_cell就可以;否则调用_rnn_loop这个参数做”动态“展开(注意动态机制也会调用这个函数,区别在传的一个参数不同)。大致逻辑如下

def static_rnn(cell,
               inputs,   # 这里是一个长度为T,每个元素形状为[B, D]的数组
               ...):
    if sequence_length is not None:
        min_sequence_length = math_ops.reduce_min(sequence_length)
        max_sequence_length = math_ops.reduce_max(sequence_length)
    state = initial_state if initial_state is not None else cell.zero_state()
    outputs = []
    for t, _input in enumerate(inputs):
        call_cell = lambda: cell(input_, state)
        if sequence_length is not None:
            (output, state) = _rnn_step(time=t, ...)
        else:
            (output, state) = call_cell()
        outputs.append(output)
    return (outputs, state)

def _rnn_step(time,
              sequence_length,
              min_sequence_length,
              max_sequence_length,
              zero_output,
              state,
              call_cell,
              state_size,
              skip_conditionals=False):
    copy_cond = time >= sequence_length
    if skip_conditionals:
        # 动态RNN使用的逻辑,效率更高
        new_output, new_state = call_cell()
        output_and_state = array_ops.where(copy_cond, zero_output, new_output) + \
                           array_ops.where(copy_cond, state, new_state)
    else:
        # 静态RNN使用的逻辑。实际使用control_flow_ops.cond实现条件选择
        if time >= max_sequence_length:
            output_and_state = zero_output + state
        else:
            new_output, new_state = call_cell()
            if time < min_sequence_length:
                output_and_state = new_output + new_state
            else:
                output_and_state = array_ops.where(copy_cond, zero_output, new_output) + \
                                   array_ops.where(copy_cond, state, new_state)
    output, state = split(output_and_state)
    return output, state

动态机制

动态机制的基石函数dynamic_rnn,其逻辑实际是蕴含在了_dynamic_rnn_loop这个私有函数中,暴露的公开接口只是再额外做一些数据转换和检查。尽管在默认情况下输入是batch major的,即输入inputs的第一维是batch维(输入形状是B x T x D),但是内部处理实际上是time major的,即要转置成T x B x D的形状(不过返回时会再转置回来)

_dynamic_rnn_loop的逻辑大致为

def _dynamic_rnn_loop(cell, inputs, initial_state, 
                      sequence_length=None, dtype=None):
    state = initial_state
    time_steps = shape(inputs)[0]
    if sequence_length is not None:
        min_sequence_length = math_ops.reduce_min(sequence_length)
        max_sequence_length = math_ops.reduce_max(sequence_length)
    else:
        max_sequence_length = time_steps
    time = 0
    loop_bound = min(time_steps, max(1, max_sequence_length))
    final_outputs = []
    # 实际使用了tf的control_flow_ops.while_loop来控制
    while time < loop_bound:
        # 下一条“语句”实际上涉及了比较复杂的TensorArray对象操作
        # 但是由于tf1.x调试比较难,因此我也没太明白具体细节,只是推测
        input_t = inputs[time]
        call_cell = lambda: cell(input_t, state)
        if sequence_length is not None:
            (output, new_state) = _rnn_step(
                    time=time,
                    sequence_length=sequence_length,
                    min_sequence_length=min_sequence_length,
                    max_sequence_length=max_sequence_length,
                    zero_output=zero_output,
                    state=state,
                    call_cell=call_cell,
                    state_size=state_size,
                    skip_conditionals=True)
        else:
            (output, new_state) = call_cell()
        state = new_state
        # 这条“语句”也涉及比较复杂的TensorArray对象操作
        final_outputs.append(output)
        time += 1
    return final_outputs, state

dynamic_rnn实现的是RNN的单向传播。但是在很多问题中,对每个独立的时间步,不仅需要看上游状态,也可以看到下游状态,将两者结合,效果会更好。因此TF还封装实现了bidirectional_dynamic_rnn函数,内部对两次调用dynamic_rnn的结果(即两个方向的结果)做了组合。其逻辑为

def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, ...)
    outputs_fw, state_fw = dynamic_rnn(cell_fw, inputs, ...)
    outputs_bw, state_bw = dynamic_rnn(cell_bw, reverse(inputs), ...)
    outputs = (outputs_fw, reverse(outputs_bw))
    states = (state_fw, state_bw)
    return (outputs, states)

这里要注意两点

  • 最终的输出实际上是两个RNN输出的拼接,所以维度比指定的大了一倍
  • 在实现多层RNN的时候,cell_fwcell_bw都是MultiRNNCell的实例。但是前面提到,这个类封装了各层之间数据传递的细节,因此每一层前向RNN和后向RNN都没有交互,实际上只是训练了两个网络,把他们拼起来而已

tf.contrib.rnn.stack_bidirectional_dynamic_rnn实现了前向RNN和后向RNN各层交互的功能,具体逻辑为

def stack_bidirectional_dynamic_rnn(cells_fw, cells_bw, inputs, ...):
    states_fw, states_bw = [], []
    prev_layer = inputs
    for i, (cell_fw, cell_bw) in enumerate(zip(cells_fw, cells_bw)):
        outputs, (state_fw, state_bw) = bidirectional_dynamic_rnn(cell_fw, cell_bw, prev_layer, ...)
        prev_layer = array_ops.concat(outputs, 2)
        states_fw.append(state_fw)
        states_bw.append(state_bw)
    return prev_layer, tuple(states_fw), tuple(states_bw)

最后,TF暴露了一个更底层的API raw_rnn,允许开发人员以更大自由度处理输入和沿时间轴展开的逻辑,例如使第T个时间步的输入为第T-1个时间步的输出。这里直接贴上文档中给出的伪代码

def raw_rnn(cell, loop_fn, ...)
    time = tf.constant(0, dtype=tf.int32)
    (finished, next_input, initial_state, emit_structure, loop_state) = loop_fn(
            time=time, cell_output=None, cell_state=None, loop_state=None)
    emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype)
    state = initial_state
    while not all(finished):
        (output, cell_state) = cell(next_input, state)
        (next_finished, next_input, next_state, emit, loop_state) = loop_fn(
                time=time + 1, cell_output=output, cell_state=cell_state,
                loop_state=loop_state)
        # Emit zeros and copy forward state for minibatch entries that are finished.
        state = tf.where(finished, state, next_state)
        emit = tf.where(finished, tf.zeros_like(emit_structure), emit)
        emit_ta = emit_ta.write(time, emit)
        # If any new minibatch entries are marked as finished, mark these.
        finished = tf.logical_or(finished, next_finished)
        time += 1
    return (emit_ta, state, loop_state)

参考文献

[Gers2002]: Gers, F. A., Schraudolph, N. N., & Schmidhuber, J. (2002). Learning precise timing with LSTM recurrent networks. Journal of machine learning research, 3(Aug), 115-143.

你可能感兴趣的:(神经翻译笔记)