Tensorflow中注意力机制的实现:AttentionCellWrapper

文章目录

        • 背景知识
        • AttentionCellWrapper理论基础
        • AttentionCellWrapper源码解析

背景知识

注意力机制最早被用于机器翻译领域,其本质类似于人类在认知事物时的注意力,后因其有效性被广泛用于计算机视觉、语音识别、序列预测等领域。
常见的注意力机制通常是基于Encoder-Decoder的,模型在Decoder阶段进行解码时会考虑编码阶段Encoder的所有隐藏状态。

AttentionCellWrapper理论基础

在Tensorflow中也有现成的注意力API可以使用,即AttentionCellWrapper,具体的实现代码是在tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py文件中。

值得注意的是,Tensoflow中AttentionCellWrapper的实现并不是基于Encoder-Decoder形式的,而是受启发于https://magenta.tensorflow.org/2016/07/15/lookback-rnn-attention-rnn这篇文章中的AttentionRNN。

这篇文章提出了一种单向RNN就能使用的Attention结构(这里我们称为AttentionRNN),在处理每一步的输入时,考虑前面N步的输出,经过映射加权后把这些历史信息加到本次输入的预测中。

In our version, where we don’t have an encoder-decoder, we just always look at the outputs from the last n steps when generating the output for the current step. The way we “look at” these steps is with an attention mechanism.

具体公式如下:
Tensorflow中注意力机制的实现:AttentionCellWrapper_第1张图片
其中:

  • 矩阵W1、W2和向量v均为可学习的参数
  • hi为前面第i步输出的隐藏状态 ct为当前时刻的细胞状态
  • ui为长度为n的相关系数向量,对于前n个step每个step对应一个相关系数。
  • ai为注意力得分,可通过对相关系数进行softmax操作得到,文章中称ai为attention mask。
  • h’t为当前时刻经过attention后的输出,可通过对前n个step的隐藏状态以及对应的注意力的分加权求和得到。

Tensorflow中注意力机制的实现:AttentionCellWrapper_第2张图片

AttentionCellWrapper源码解析

class AttentionWrapper(rnn_cell_impl.RNNCell):
  def __init__(self,
               cell,
               attention_mechanism,
               attention_layer_size=None,
               alignment_history=False,
               cell_input_fn=None,
               output_attention=True,
               initial_cell_state=None,
               name=None,
               attention_layer=None):
  • cell: 被包裹的RNNCell实例;
  • attention_mechanism: attention机制实例,例如BahdanauAttention,也可以是多个attention实例组成的列表;
  • attention_layer_size: 是数字或者数字做成的列表,如果是 None(默认),直接使用加权求和得到的上下文向量 [公式] 作为输出(详见本小节最后的_compute_attention代码),如果不是None,那么将 [公式] 和cell的输出 cell_output进行concat并做线性变换(输出维度为attention_layer_size)再输出。
    这里所说的"输出"在代码里是用"attention"表示的,见本小节最后的_compute_attention函数代码。
  • alignment_history: 即是否将之前的alignments存储到 state 中,以便于后期进行可视化展示,默认False,一般设置为True。
  • cell_input_fn: 怎样处理输入。默认会将上一步的得到的输出与的实际输入进行concat操作作为输入。

参考资料:
Tensorflow中的AttentionCellWrapper:一种更通用的Attention机制
TensorFlow AttentionWrapper源码超详细图解
AttentionCellWrapper

你可能感兴趣的:(Tensorflow)