Google FLASH-QUAD Transformer模型的设计雷点

这个模型用来做无序弱监督分类,效果好,特别是收敛速度比标准多头Attention层快多了,完全没得比。

问题1

但这模型我用来做自回归生成,非常垃圾。
同时尝试了 GPT 和 T5 这两种模型结构的设计,明明Loss正常下降,可是自回归生成性能非常的烂,不知原因为何。

不服输,最近再来尝试FLASH,毕竟性能太过于吸引人。碰巧单步调试了一下自回归生成的过程。
卧槽,意外发现cause掩码失效,前一个时间步的输出会被后一个时间步的输入影响,

一步步排查,排查到注意力矩阵的生成
注意到这个 1/n 的 n 是可变的。直接把 n 去掉,使注意力矩阵的值不再受序列长度的缩放。
下图来自苏神的博客
在这里插入图片描述
对应到代码,在 lucidrains 的代码里面 https://github.com/lucidrains/FLASH-pytorch/blob/main/flash_pytorch/flash_pytorch.py#L190

sim = einsum('b i d, b j d -> b i j', q, k) / seq_len

我将其改为一个定值

sim = einsum('b i d, b j d -> b i j', q, k) / q.shape[-1]

改为,现在 前一个时间步的输出不再 被后一个时间步的输入影响了。

问题1.1

改为定值后,尚未实验,但预计超出训练长度后(例如最大训练文本长度为512,测试文本长度为768),性能会有显著下降。

问题2

修改完,初步的训练后,自回归生成能力有了大幅的提升了。
但仍然存在问题,这个注意力方法的局部关注能力似乎很弱,意思为经常见到连续生成同义的词
例如(空格代表分词)
标签为

树叶 静静地 燃烧 起来

自回归生成(使用sample策略)多见这样的生成范式(不是必定出现)

树叶 静静地 安静地 燃烧 起来

相近意思的词会有时多生成一次,一般的多头注意力出现这样的情况非常少见,推测该设计的局部关注能力较弱。

类似的讨论

https://github.com/JunnYu/FLASHQuad_pytorch/issues/1

你可能感兴趣的:(深度学习的经验,transformer,深度学习,pytorch,FLASH)