7_Attention(注意力机制)

文章目录

  • 一、Seq2Seq Model
  • 二、Seq2Seq Model with Attention
    • 2.1 SimpleRNN + Attention
      • 2.1.1 权重计算α~i~
      • 2.1.2 Context vector C~i~
    • 2.2 Time Complexity(时间复杂度)
  • 三、Summary(总结)

一、Seq2Seq Model

Shortcoming: The final state is incapable of remembering a long sequence.

7_Attention(注意力机制)_第1张图片

二、Seq2Seq Model with Attention

  • Attention tremendously improves Seq2Seq model.(Attention极大地改善Seq2Seq 模型)
  • With attention, Seq2Seq model does not forget source input.
  • With attention, the decoder knows where to focus.(decoder 更新状态的时候,都会再看一遍encoder所有状态,这样就不会遗忘;attention还可以告诉decoder应该关注encoder哪个状态)
  • Downside: much more computation.(缺点:更多的计算)
  • https://distill.pub/2016/augmented-rnns/

2.1 SimpleRNN + Attention

在Encoder结束工作之后,Decoder 和Attention同时开始工作。7_Attention(注意力机制)_第2张图片

  • encoder的所有状态都要保留下来,这里需要计算S0 与每一个h的相关性。
  • α1,α2,α3,···αm都是介于(0,1)之间的实数,所有α相加等于1。

7_Attention(注意力机制)_第3张图片

2.1.1 权重计算αi

方法一: (used in the original paper):在原始论文中使用

7_Attention(注意力机制)_第4张图片

方法二:(more popular; the same to Transformer)

7_Attention(注意力机制)_第5张图片

2.1.2 Context vector Ci

7_Attention(注意力机制)_第6张图片

7_Attention(注意力机制)_第7张图片

2.2 Time Complexity(时间复杂度)

问题:How many weights ai have been computed? (我们共计算了多少权重ai

  • To compute one vector Ci ,we compute m weights: α1,α2,α3,···αm 。(想要计算出一个 Ci ,我们需要计算m个权重α1,α2,α3,···αm
  • The decoder has t states, so there are totally mt weights.(假设Decoder运行了t步,那么一共计算了 mt 个权重,因此时间复杂度mt)

7_Attention(注意力机制)_第8张图片

三、Summary(总结)

  • Standard Seq2Seq model: the decoder looks at only its current state.(标准的Seq2Seq模型:decoder基于当前状态来产生下一个状态,这样产生的新状态可能忘记了encoder的部分输入)

  • Attention: decoder additionally looks at all the states of the encoder.(注意力机制:decoder 在产生下一个状态之前,会先看一遍encoder的所有状态,于是decoder会知道encoder的完整信息,并不会遗忘)

  • Attention: decoder knows where to focus.(注意力机制:除了解决遗忘的信息,attention还会告诉decoder应该关注encoder的哪一个状态)

  • Downside: higher time complexity.(缺点:计算量太大)

    • m: source sequence length (假设输入encoder的序列长度为m)
    • t: target sequence length (decoder输出的序列长度为t)
    • Standard Seq2Seq: O(m + t ) time complexity (标准的Seq2Seq:只需要让encoder读一遍输入序列,之后不再看encoder的输入,然后让decoder依次生成输出序列)
    • Seq2Seq + attention: O(mt) time complexity(decoder每一次更新状态,都要把encoder的m个状态看一遍,所以每次的时间复杂度为m,decoder自己有t个状态,因此总时间复杂度是mt。

    使用attention可以提高准确率,要付出更多的计算。

你可能感兴趣的:(NLP,深度学习,机器学习,pytorch)