FlashAttention

一、 论文题目(发表处-时间)

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

二、主要方向

新型注意力机制

三、细化任务

一种具有 IO 感知,且兼具快速、内存高效的新型注意力算法

四、论文动机

一般对transformer中关键模块 self-attention进行速度优化,一般使用稀疏近似(通过镂空,其他越远越空的attention方法)低秩分解(将attn 矩阵)矩阵分解两个矩阵后计算。

但都没有节省GPU对内存访问,主要在内存保存注意力矩阵的部分进行改善。

五、论文中的主要贡献点

  1. softmax 处减少HBM

FlashAttention如何实现在不访问整个输入的情况计算softmax大的缩减,标准Attention算法由于要计算softmax,而softmax都是按行来计算的,即在和V做矩阵乘之前,需要让 Q、K 的各个分块完成整一行分块的计算得到Softmax的结果后,再和矩阵V分块做矩阵乘。而在Flash Attention中,将输入分割成块,并在输入块上进行多次传递,从而以增量方式执行softmax缩减

  1. 反向传播 减少访问HBM

在后向传播中不存储中间注意力矩阵,以Flash Attention所提供的算法为例,通过对比标准Attention算法在实现过程中,标准Attention算法的实现需要将计算过程中的S、P写入到HBM中,而这些中间矩阵的大小与输入的序列长度有关且为二次型,因此Flash Attention就提出了不使用中间注意力矩阵,通过存储归一化因子来减少HBM内存的消耗。

在Flash Attention的前向计算算法中我们可以看出,Flash Attention算法并没有将S、P写入HBM中去,而是通过分块写入到HBM中去,存储前向传递的 softmax 归一化因子,在后向传播中快速重新计算片上注意力,这比从HBM中读取中间注意力矩阵的标准方法更快。即使由于重新计算导致 FLOPS 增加,但其运行速度更快并且使用更少的内存(序列长度线性),主要是因为大大减少了 HBM 访问量。

六、模型图/模型主框架描述

FlashAttention_第1张图片

七、除了主贡献之外的亮点

  1. 计算softmax时候不需要全量input数据,可以分段计算

  2. 反向传播的时候,不存储attention matrix (N^2的矩阵),而是只存储softmax归一化的系数。

八、数据集及效果

2.效果:

FlashAttention_第2张图片

九、开源代码,开源数据集

十、论文/模型的缺陷、可以改进的点

  1. 应该和稀疏近似,低秩分解并不冲突,将他们融合进一步实现attention加速,以及长度外推。

  2. 在其他位置GPU io处进行一些调整。

十一、论文附件

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

你可能感兴趣的:(人工智能)