Seq2Seq增加attention机制的原理说明

以中文翻译为英文为例讲解seq2seq的原理,以及增加attention机制之后的seq2seq优化版本。

文本参考:

Pytorch实现Seq2Seq(Attention)字符级机器翻译_pytorch seq2seq_孤独腹地的博客-CSDN博客

https://github.com/datawhalechina/learn-nlp-with-transformers/blob/main/docs/%E7%AF%87%E7%AB%A02-Transformer%E7%9B%B8%E5%85%B3%E5%8E%9F%E7%90%86/2.1-%E5%9B%BE%E8%A7%A3attention.md

一、seq2seq

Seq2Seq增加attention机制的原理说明_第1张图片

步骤:

1、将中文“我是学生”进行分词,分别得到“我”,“是”,“学生”

2、每个单词通过word2vec转化为向量

3、初始化RNN的隐藏层向量为h_init

4、Encode阶段:将h_init和input(“我”的词向量)输入RNN族神经网络,得到h1。再将h1和input(“是”的词向量)输入RNN得到h2。直到得到最后的h3,将最后一个hidden即h3作为context

5、Decode阶段:将context作为decode阶段RNN的初始隐藏层向量,与input(开始的词向量)输入RNN族神经网络,得到h1’。h1’再输入FC网络,再经过softmax得到数据分布,取argmax对应的单词即为预测的第一个单词。然后再将h1’和input(“I”的词向量)输入RNN得到h2’后再得到预测的第二个单词。直到最后预测的单词为结束单词即“”。

存在的问题:一个单词向量很难包含所有文本序列的信息。比如RNN处理到第500个单词的时候,很难再包含1-499个单词中的所有信息了。

于是我们通过增加attention机制解决上述问题。

二、seq2seq增加attention机制

Seq2Seq增加attention机制的原理说明_第2张图片

步骤:

1、将中文“我是学生”进行分词,分别得到“我”,“是”,“学生”

2、每个单词通过word2vec转化为向量

3、初始化RNN的隐藏层向量为h_init

4、Encode阶段:将h_init和input(“我”的词向量)输入RNN族神经网络,得到h1。再将h1和input(“是”的词向量)输入RNN得到h2。直到得到最后的h3。将h1+h2+h3一起作为encode的输出

5、Decode阶段:

(1)将h3作为RNN的hidden_init,与input(的词向量)输入RNN得到输出h1‘。

(2)h1’与encode的输出(h1、h2、h3)分别进行点积得到3个weight值

(3)将weight值再乘以encode的输出(h1、h2、h3)得到最终的context

(4)将context与h1’进行concat作为FC的输入

(5)输入FC网络,再经过softmax得到数据分布,取argmax对应的单词即为预测的第一个单词。

(6)重复以上1~5步直到预测单词为”

以上标红部分为与未加attention的区别部分。

你可能感兴趣的:(神经网络,深度学习,人工智能)