attention 理解 根据pytorch教程seq2seq源码

https://blog.csdn.net/wuzqchom/article/details/75792501

http://baijiahao.baidu.com/s?id=1587926245504773589&wfr=spider&for=pc

pytorch源码

这是李宏毅老师的ppt。右侧对应pytorch seq2seq源码。

我们的问题是,左边的数学符号,右侧的代码是如何对应的?

attention 理解 根据pytorch教程seq2seq源码_第1张图片attention 理解 根据pytorch教程seq2seq源码_第2张图片

 

 

 

1,不是词嵌入,而是编码器的输出,如   源码中的output。attention 理解 根据pytorch教程seq2seq源码_第3张图片

为什么是输出而不是隐藏呢?这要从之后的函数中看出。

attention 理解 根据pytorch教程seq2seq源码_第4张图片

训练函数中设置了一个大的,全是零的encoder_outputs的矩阵,红线部分将encoder_output存储起来,而隐藏只是在不断的循环。从PPT可以看出来,每次是需要全部的H1,H2,H3, h4 ........,那么肯定使用了encoder_outputs这个大大的矩阵。故是输出对应,而不是隐藏。

其次注意,这里的GRU,SEQ长度只是1.它的序列的扩展是通过for函数的用于循环,依次遍历每个单词,来进行序列方向上的扩展。

 

 

2,李宏毅老师匹配函数,在源码中是怎么实现的回答:是通过定义的一层神经网络来实现的。

attention 理解 根据pytorch教程seq2seq源码_第5张图片

attention 理解 根据pytorch教程seq2seq源码_第6张图片

可以看出来,解码器有个self.attn的线性层,这个线性层就是我们要找的匹配函数。为什么呢?看attendecoderRNN的前进中,拼接两个向量,再进行线性层,且函数名是attn_weights 。正好对应的上面绿色箭头的* 2 

所以,这里的attn_weights就是

那么,为什么要使用神经网络进行匹配呢?

但是实际上,有多种匹配的方法,例如:直接相乘(dot),再多一个中间W矩阵(general),最后也是本教程使用的方法,使用神经网络实现。你也可以使用其他的方法!!!并不需要局限

attention 理解 根据pytorch教程seq2seq源码_第7张图片

3,又对应什么呢?

答,对应代码是

torch.bmm是batch的乘法操作,即1 * 1 * 10与1 * 10 * 256的矩阵会变成1 * 1 * 256

 

 

4,Z0是什么呢?答

Z0是对应的嵌入向量。当然,如果嵌入是随机初始化的话,那么Z0确实是随机的,因此哔哩哔哩弹幕上说Z0是随机的也是正确的〜

之后的Z1,Z2实际上都是对应的,译码器的隐层状态。我们是使用Ž这个隐藏层的状态,来对encoder_outputs矩阵进行匹配打分的。

attention 理解 根据pytorch教程seq2seq源码_第8张图片

为什么呢?依旧从源码看出来

attention 理解 根据pytorch教程seq2seq源码_第9张图片

在用于循环第一遍输入的时候,就将decoder_hidden送入其中。对应解码器的输入参数

而decoder_hidden又是编码器最后一个状态输出。所以李宏毅老师说的initial_memory,我认为就是编码器最后一个隐藏状态。

 

 

5,Z1是怎么计算得到呢??

回答是attn_weight与输入的德文单词的词向量相乘后的结果。

attention 理解 根据pytorch教程seq2seq源码_第10张图片

比如,C0是attn_applied[0],是每次“注意”的向量。

Z0是刚开始的符号,当然Z1就是第一个输出 machine 对应的隐藏层。。。

注,转换成嵌入向量与attn_weight进行操作对应的代码是这一行:

 

 

6,那么PPT上的输出翻译后的单词对应代码哪一块呢?

这个箭头,对应的

attention 理解 根据pytorch教程seq2seq源码_第11张图片这个箭头,对应的

因为使用了GRU。:)

以上只是个人理解,请指出错误

你可能感兴趣的:(人工智能)