稀疏Attention

1. 模型

Self Attention是 O ( n 2 ) O(n^2) O(n2)的,那是因为它要对序列中的任意两个向量都要计算相关度,得到一个 n 2 n^2 n2大小的相关度矩阵:

稀疏Attention_第1张图片

左边显示了注意力矩阵,右变显示了关联性,这表明每个元素都跟序列内所有元素有关联。

所以,如果要节省显存,加快计算速度,那么一个基本的思路就是减少关联性的计算,也就是认为每个元素只跟序列内的一部分元素相关,这就是稀疏Attention的基本原理。

2. 稀疏Attention

Atrous Self Attention

膨胀注意力:
Atrous Self Attention就是启发于“膨胀卷积(Atrous Convolution)”,如下右图所示,它对相关性进行了约束,强行要求每个元素只跟它相对距离为k,2k,3k,…的元素关联,其中k>1是预先设定的超参数。

稀疏Attention_第2张图片

由于现在计算注意力是“跳着”来了,所以实际上每个元素只跟大约 n / k n/k n/k个元素算相关性,这样一来理想情况下运行效率和显存占用都变成了 O ( n 2 / k ) O(n^2/k) O(n2/k),也就是说能直接降低到原来的 1 / k 1/k 1/k

Local Self Attention
另一个要引入的过渡概念是Local Self Attention,中文可称之为“局部自注意力”。其实自注意力机制在CV领域统称为“Non Local”,而显然Local Self Attention则要放弃全局关联,重新引入局部关联。具体来说也很简单,就是约束每个元素只与前后 k k k个元素以及自身有关联,如下图所示:
稀疏Attention_第3张图片

Sparse Self Attention
Atrous Self Attention是带有一些洞的,而Local Self Attention正好填补了这些洞,所以一个简单的方式就是将Local Self Attention和Atrous Self Attention交替使用,两者累积起来,理论上也可以学习到全局关联性,也省了显存。

稀疏Attention_第4张图片


参考:
苏blog

你可能感兴趣的:(算法)