Transformer decoder中masked attention的理解

前前后后看了挺久的Transformer,本以为自己理解了,可实现起来总觉得差点意思。

Transformer decoder中masked attention的理解_第1张图片

encoder比较简单,不多介绍。记录一下decoder的使用。

masked attention是要使用一个上三角矩阵torch.triu来实现对未来信息的掩盖。为什么就掩盖未来信息了?看了这篇博客,明白了但没完全明白,说是decoder在训练时用的groundtruth,防止误差累积,取得比较好的训练结果,但像下图中这样也没发现decoder的输入中有未来信息啊。

Transformer decoder中masked attention的理解_第2张图片

 在实践中,发现在训练时,如果要实现这个任务,其实做的是在decoder输入BOS,11,12,13,21,22,而groundtruth使用11,12,13,21,22,EOS,就是说输出要比输入往右错开一位训练即可,其实这也是很多NLP中seq2seq模型的训练方法。这样的话,在训练时,就要考虑到未来信息的泄露问题了。加入masked attention后,transformer的decoder在功能上其实相当于rnn了,当前输出只与当前和过去输入有关,而与未来信息无关,二者区别在于,rnn的历史信息,只能一级一级的传递到当前时间步,而decoder直接使用attention,可以直接实现信息传递,比如,t-4时刻的信息,rnn只能传播4次才能到t时刻,attention只需要传播一次即可。

你可能感兴趣的:(torch,API,transformer,深度学习,人工智能)