FlashAttention

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

https://paperswithcode.com/paper/flashattention-fast-and-memory-efficient

https://github.com/HazyResearch/flash-attention

https://arxiv.org/abs/2205.14135                

(才开源了4天就收获了406星)

27 May 2022 · Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré ·  Edit social preview

Transformers are slow and memory-hungry on long sequences, since the time and memory complexity of self-attention are quadratic in sequence length. Approximate attention methods have attempted to address this problem by trading off model quality to reduce the compute complexity, but often do not achieve wall-clock speedup. We argue that a missing principle is making attention algorithms IO-aware -- accounting for reads and writes between levels of GPU memory. We propose FlashAttention, an IO-aware exact attention algorithm that uses tiling to reduce the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM. We analyze the IO complexity of FlashAttention, showing that it requires fewer HBM accesses than standard attention, and is optimal for a range of SRAM sizes. We also extend FlashAttention to block-sparse attention, yielding an approximate attention algorithm that is faster than any existing approximate attention method. FlashAttention trains Transformers faster than existing baselines: 15% end-to-end wall-clock speedup on BERT-large (seq. length 512) compared to the MLPerf 1.1 training speed record, 3× speedup on GPT-2 (seq. length 1K), and 2.4× speedup on long-range arena (seq. length 1K-4K). FlashAttention and block-sparse FlashAttention enable longer context in Transformers, yielding higher quality models (0.7 better perplexity on GPT-2 and 6.4 points of lift on long-document classification) and entirely new capabilities: the first Transformers to achieve better-than-chance performance on the Path-X challenge (seq. length 16K, 61.4% accuracy) and Path-256 (seq. length 64K, 63.1% accuracy).

摘要:由于自注意的时间和内存复杂性在序列长度上是二次的,所以Transformer在长序列上速度慢且内存不足。近似注意方法试图通过权衡模型质量来解决此问题,以降低计算复杂性,但通常无法实现挂钟加速。我们认为,一个缺失的原则是让注意力算法具有IO意识——考虑GPU内存级别之间的读写。我们提出了FlashAttention,这是一种IO感知的精确注意算法,它使用平铺来减少GPU高带宽内存(HBM)和GPU片上SRAM之间的内存读写次数。我们分析了FlashAttention的IO复杂性,表明它比标准attention需要更少的HBM访问,并且对于各种SRAM大小都是最优的。我们还将FlashAttention扩展到分块稀疏注意,得到了一种比现有近似注意方法更快的近似注意算法。FlashAttention训练Transformer的速度比现有基线快:与MLPerf1.1训练速度记录相比,BERTlarge(序列长度512)的端到端挂钟加速比为15%,GPT-2(序列长度1K)的加速比为3倍,长程arena(序列长度1K-4K)的加速比为2.4倍。FlashAttention和block稀疏FlashAttention可以在Transformers中提供更长的上下文,产生了更高质量的模型(GPT-2的复杂度提高了0.7个点,长文档分类提高了6.4个点)和全新的功能:第一个在Path-X挑战(序列长度16K,准确率61.4%)和Path-256(序列长度64K,准确率63.1%)上实现了优于偶然性能的Transformer。

1简介

Transformer模型[79]已成为自然语言处理和图像分类等应用中使用最广泛的架构。Transformer已经变得越来越大,越来越深,但给它们配备更长的上下文仍然很困难,因为它们的核心自注意模块的时间和内存复杂性是序列长度的二次方。一个重要的问题是,让注意力更快、内存效率更高是否有助于Transformer模型解决长序列的运行时和内存挑战。

许多近似注意方法旨在减少注意的计算和内存需求。这些方法的范围从稀疏近似(49、71)到低阶近似(11、48、81)及其组合(3、8、89)。虽然这些方法将计算要求降低到线性或接近线性的序列长度,但其中许多方法并没有显示相对于标准注意的挂钟加速,也没有得到广泛采用。一个主要原因是,他们专注于降低FLOPs(这可能与挂钟速度无关),并倾向于忽略内存访问(IO)的开销。

在本文中,我们认为一个缺失的原则是使注意力算法具有IO意识[1],即仔细考虑对不同级别的快速和慢速内存的读写(例如,在快速GPU片上SRAM和相对较慢的GPU高带宽内存之间,或HBM[43],图1左)。在现代GPU上,计算速度超过了内存速度[58、59、60],Transformer中的大多数操作都受到内存访问的限制[41]。IO感知算法对于类似的内存限制操作至关重要,因为读写数据可能会占用运行时的很大一部分,如数据库连接[68]、图像处理[67]、数值线性代数[4]等等[38、82]。然而,PyTorch和Tensorflow等深入学习的常见Python接口不允许对内存访问进行细粒度控制。

我们提出了FlashAttention,这是一种新的注意算法,可以用更少的内存访问来计算精确的注意。我们的主要目标是避免读写HBM的注意力矩阵。这需要(i)在不访问整个输入的情况下计算softmax约简(ii)不存储向后传递的大型中间注意矩阵。我们采用两种成熟的技术来应对这些挑战。(i) 我们重新构造注意力计算,将输入分割成块,并在输入块上进行多次传递,从而以增量方式执行softmax缩减(也称为平铺)。(ii)我们存储前向传递的softmax归一化因子,以便在后向传递中快速重新计算芯片上的注意,这比从HBM读取中间注意矩阵的标准方法更快。我们在CUDA中实现FlashAttention,以实现对内存访问的细粒度控制,并将所有注意操作融合到一个GPU核中。即使由于重新计算导致的FLOPs增加,我们的算法运行速度更快(在GPT-2上高达7.6倍[64],图1右),并且由于HBM访问量的大幅减少,在序列长度上使用的内存线性比标准注意更少。

我们分析了FlashAttention的IO复杂性,证明它需要( 2. 2.−1)HBM访问,其中 是头部尺寸和 是SRAM的大小,与Ω相比( + 2)标准注意事项。对于 和, 与标准注意相比,FlashAttention需要的HBM访问量要少很多倍(最多少9倍,如图2所示)。此外,我们还提供了一个下限,表明在所有SRAM大小的情况下,没有精确的注意算法能够渐进地改善HBM访问的数量。

我们还表明,FlashAttention可以作为一个有用的原语,通过克服近似注意算法在内存访问开销方面的问题来实现其潜力。作为概念证明,我们实现了块稀疏FlashAttention,这是一种比FlashAttention快2-4倍的稀疏注意算法,可扩展到64k的序列长度。我们证明了块稀疏FlashAttention比FlashAttention具有更好的IO复杂性,其IO复杂性与稀疏率成正比。我们在第5节中讨论了对其他操作的进一步扩展(注意多GPU、核回归、块稀疏矩阵乘法)。我们开放了源代码FlashAttention,以便更容易在这个原语的基础上进行构建。1.

我们实证验证了FlashAttention通过建模更长的上下文来加速模型训练并提高模型质量。我们还对FlashAttention的运行时和内存占用进行了基准测试,并将稀疏FlashAttention与之前的注意实现进行了比较。

•更快的模型训练。FlashAttention以更快的时间训练Transformer型号。我们训练的BERT large(序列长度512)比MLPerf 1.1[56]中的训练速度记录快15%,GPT2(序列长度1K)比HuggingFace[84]和Megatron LM[74]的基线实现快3倍,长程arena(序列长度1K-4K)比基线快2.4倍。

•更高质量的型号。FlashAttention将Transformer扩展到更长的序列,从而提高其质量并实现新功能。我们观察到,在GPT-2上的困惑度提高了0.7,在长文档分类上对长序列建模提高了6.4点[12]。FlashAttention使第一台Transformer能够在Path-X挑战中获得比chance更好的性能[77],仅通过使用更长的序列长度(16K)。块稀疏FlashAttention使Transformer能够扩展到更长的序列(64K),从而产生第一个在Path-256上性能优于chance的模型。

•注意基准。在基准测试中,FlashAttention在128到2K的常见序列长度上比标准注意实现快3倍,扩展到64K。当序列长度达到512时,FlashAttention比任何现有的注意方法都更快、内存效率更高,而当序列长度超过1K时,一些近似的注意方法(例如Linformer)开始变得更快。另一方面,块稀疏FlashAttention比我们已知的所有现有近似注意方法都要快。

2背景

我们提供了一些关于现代硬件(GPU)上常见深度学习操作的性能特征的背景知识。我们还描述了注意力的标准实现。

2.1硬件性能

我们在这里注意GPU。在其他硬件加速器上的性能类似[44,46]。

GPU内存层次结构。

GPU内存层次结构(图1左)包括不同大小和速度的多种形式的内存,内存越小速度越快。例如,A100 GPU有40-80GB的高带宽内存(HBM),带宽为1.5-2.0TB/s,每108个流式多处理器有192KB的片上SRAM,估计带宽约为19TB/s[42,43]。片上SRAM比HBM快一个数量级,但尺寸小很多数量级。随着计算速度相对于内存速度的提高[58、59、60],操作越来越受到内存(HBM)访问的限制。因此,开发快速SRAM变得更加重要。

执行模型。

GPU有大量线程来执行操作(称为核)。每个核将HBM的输入加载到寄存器和SRAM,计算输出,然后将输出写入HBM。

性能特征。

根据计算和内存访问的平衡,操作可以分为计算绑定或内存绑定。这通常由算术强度来衡量,即每字节内存访问的算术运算数。

1、计算界限:运算所花费的时间取决于有多少算术运算,而访问HBM的时间要小得多。典型的例子是具有大内维数的矩阵乘法和具有大量通道的卷积。

2、内存限制:操作所花费的时间由内存访问次数决定,而计算所花费的时间要小得多。示例包括大多数其他操作:elementwise(例如激活、退出)和reduction(例如sum、softmax、batch norm、layer norm)。核融合。加速内存受限操作的最常用方法是核融合:如果对同一输入应用了多个操作,则可以从HBM加载一次输入,而不是每次操作加载多次。编译器可以自动融合许多元素操作[51、62、72]。

然而,在模型训练的背景下,仍然需要将中间值写入HBM以保存向后传递,从而降低了原始核融合的有效性。

2.2标准注意执行

给定输入序列Q、K、V∈ R × 哪里 是序列长度和 是头部尺寸,我们要计算注意力输出O∈ R ×

其中,按行应用softmax。

标准注意实现将矩阵S和P具体化为HBM,这需要( 2)内存。经常 >> (例如,对于GPT2, = 1024和 = 64). 我们描述了算法0中的标准注意实现。由于部分或大部分操作是内存受限的(例如softmax),大量内存访问会导致挂钟时间变慢。

这一问题因应用于注意矩阵的其他元素操作而加剧,例如应用于S的掩蔽或应用于P的dropout。因此,有许多尝试融合多个元素操作,例如将掩蔽与softmax融合[74]。

在第3.2节中,我们将展示标准注意实现在序列长度上执行二次HBM访问. 我们还比较了标准注意和我们的方法(FlashAttention)的FLOPs和HBM访问次数。

3 FlashAttention:算法、分析和扩展

我们展示了如何使用较少的HBM读/写操作计算精确注意,并且不存储用于向后传递的大型中间矩阵。这就产生了一种注意力算法,该算法既节省内存,又在挂钟时间内更快。我们分析了它的IO复杂性,表明与标准注意相比,我们的方法需要更少的HBM访问。我们进一步表明,FlashAttention可以作为一个有用的原语,通过扩展它来处理块稀疏注意。

我们在这里重点介绍向前传球,以便于展示;附录B包含了向后的详细信息。

3.1具有平铺和重新计算功能的高效注意算法

给定输入Q、K、V∈ R × 在HBM中,我们的目标是计算注意力输出O∈ R × 并将其写入HBM。我们的目标是减少HBM访问量(在).

我们应用两种已建立的技术(平铺、重新计算)来克服在次二次HBM访问中计算精确注意的技术挑战。我们在算法1中描述了这一点。其主要思想是,我们将输入Q、K、V分成块,将它们从慢速HBM加载到快速SRAM,然后计算这些块的注意输出。在将每个块的输出相加之前,通过正确的归一化因子缩放每个块的输出,我们最终得到了正确的结果。

平铺。

我们按块计算注意力。Softmax耦合K列,因此我们使用缩放分解大型Softmax[49,63]。对于数值稳定性,向量的softmax ∈ R 计算如下:

对于向量 (1) , (2) ∈ R, 我们可以分解级联的softmax =

(1) (2)

∈ R 2 作为:

因此,如果我们跟踪一些额外的统计数据((), ℓ()), 我们可以一次计算一个块的softmax。2因此,我们将输入Q、K、V拆分为块(算法1第3行),计算softmax值以及额外统计信息(算法1第10行),并合并结果(算法1第12行)。

重新计算。

我们的目标之一是不存储( 2)向反向传播的中间值。向后传递通常需要矩阵S,P∈ R × 计算关于Q、K、V的梯度。然而,通过存储输出O和softmax归一化因子ℓ, 在SRAM中,从Q、K、V块向后传递时,我们可以很容易地重新计算注意矩阵S和P。这可以看作是选择性梯度检查点的一种形式[9,32]。虽然有人建议使用梯度检查点来减少所需的最大内存量[63],但所有实现(我们都知道)都必须以速度换取内存。相反,即使有更多的FLOPs,由于HBM访问减少,我们的重新计算也会加快向后传递(图2)。完整的后向传递描述见附录B。

实现细节:核融合。

平铺使我们能够在一个CUDA核中实现我们的算法,从HBM加载输入,执行所有计算步骤(矩阵乘法、softmax、可选屏蔽和退出、矩阵乘法),然后将结果写回HBM(附录B中的屏蔽和退出)。这避免了重复读取和写入HBM的输入和输出

我们展示了FlashAttention的正确性、运行时和内存需求(证据见附录C)。

定理1。

算法1返回O=softmax(QK>)V( 2.) FLOPs和要求() 输入和输出之外的额外内存

3.2分析:FlashAttention的IO复杂性

我们分析了FlashAttention的IO复杂性,与标准注意相比,HBM访问量显著减少。我们还提供了一个下限,证明了在所有SRAM大小的情况下,没有精确的注意算法能够渐进地改善HBM访问。证据见附录C。

定理2。

允许 为序列长度, 为封头尺寸,以及 具有SRAM的大小 ≤ ≤ . 标准注意(算法0)需要Θ( + 2)HBM访问,而FlashAttention(算法1)需要Θ( 2. 2.−1)HBM访问。

对于 (64-128)和 (约100KB), 2比, 因此,FlashAttention需要的HBM访问比标准实现少很多倍。这将导致更快的执行和更低的内存占用,我们在第4.3节中对此进行了验证。

证明的主要思想是,给定, 我们可以加载大小为Θ的K、V块() 每个(算法1第6行)。对于K和V的每个块,我们迭代Q的所有块(算法1第8行)以计算中间值,得到Θ( −1)通过Q。每次通过加载Θ( ) 元素,总计为Θ( 2. 2.−1)HBM访问。我们同样证明,标准注意力的向后传递需要Θ( + 2)当FlashAttention的向后传递需要Θ时,HBM访问( 2. 2.−1)HBM通道(附录B)。

我们证明了一个下界:对于所有的 (SRAM大小)计算精确注意力时。

提案3。

允许 为序列长度, 为封头尺寸,以及 是快速片上存储器的大小。不存在计算精确注意力的算法( 2. 2.−1)所有HBM接入 在范围内[, ].

证据取决于以下事实: = Θ( ) 任何算法都必须执行Ω( 2. 2.−1 ) = Ω( ) HBM访问。子范围上的此类下界 在流媒体算法文献中很常见[85]。我们留下来证明参数化复杂性的下界 作为令人兴奋的未来工作。

我们验证了HBM访问次数是注意力运行时间的主要决定因素。在图2(左)中,我们可以看到,尽管FlashAttention与标准attention相比具有更高的FLOPs计数(由于在后向过程中重新计算),但它具有更少的HBM访问,从而导致运行时更快。在图2中(中间),我们改变了块大小 这会导致不同数量的HBM访问,并测量前向传递的运行时间。随着块大小的增加,HBM访问的数量减少(因为我们对输入进行的传递更少),运行时减少。对于足够大的块大小(超过256),运行时会受到其他因素(例如算术运算)的制约。此外,较大的块大小不适合较小的SRAM大小。

3.3扩展:分块稀疏FlashAttention

我们将FlashAttention扩展到近似注意:我们提出了块稀疏FlashAttention,其IO复杂度比FlashAttention小一个与稀疏度成比例的因子。

给定输入Q、K、V∈ R × 和掩模矩阵M°∈ {0, 1} × , 我们要计算:

给定预定义的块稀疏掩码M∈ {0, 1} /× / 我们可以很容易地调整算法1,只计算注意矩阵的非零块。该算法与算法1相同,只是跳过了零个块。我们复制了附录B算法5中的算法描述。

我们还分析了块稀疏FlashAttention的IO复杂性

提案4。

允许 为序列长度, 为封头尺寸,以及 具有SRAM的大小 ≤ ≤ . 分块稀疏FlashAttention(算法5)需要Θ( + 2. 2.−1.) HBM访问的位置 是块稀疏遮罩中非零块的分数。

我们发现,应用块稀疏性可以直接提高IO复杂性中较大项的稀疏性。对于较大的序列长度, 通常设置为 −1/2[10]或 −1个日志 [3,16,89],导致( √ ) 或Θ( 日志) IO复杂性。对于下游实验,我们使用固定蝴蝶稀疏模式[16],该模式已被证明能够近似任意稀疏度[15]。

在图2(右)中,我们验证了随着稀疏度的增加,块稀疏FlashAttention的运行时会成比例地提高。在LRA基准上,block sparse FlashAttention实现了2.8倍的加速,同时表现与标准注意力相当(第4节)

4个实验

我们评估了使用FlashAttention训练Transformer模型的影响。我们验证了关于训练时间和模型准确性的两种说法,并报告了注意力运行时和内存基准。

•训练速度。

FlashAttention比BERT的MLPerf 1.1[56]速度记录高出15%,GPT-2的速度比标准Transformer高出3倍,比 Megatron高出1.8倍。FlashAttention加速了长程竞技场(LRA)基准2.4×。

•质量。

FlashAttention将Transformer扩展到更长的序列,从而产生更高的质量。FlashAttention训练的GPT-2上下文长度为4K,比 Megatron训练的GPT-2上下文长度为1K快,同时获得了0.7更好的困惑度。对较长的序列进行建模,可以在两个长文档分类任务上获得6.4个提升点。最后,FlashAttention生成了第一个在具有挑战性的Path-X任务(序列长度16K)上比随机性能更好的Transformer,而block稀疏FlashAttention生成了我们所知的第一个在Path-256(序列长度64K)上比随机性能更好的序列模型。

•注意基准。

我们测量了FlashAttention和基于序列长度的块稀疏FlashAttention的运行时和内存性能。我们确认FlashAttention的内存占用量与seq呈线性关系。长度,比普通序列的标准注意速度快3倍。长度(最多2K)。我们证实了块稀疏FlashAttention的运行时在seq中呈线性扩展。并且比所有现有的近似注意基线都快。其他实验细节见附录E。

4.1具有FlashAttention的更快型号

BERT。

FlashAttention产生了我们所知的最快的单节点BERT训练速度。我们在维基百科上用FlashAttention训练了一个BERT大模型。表1将我们的训练时间与Nvidia的训练时间进行了比较,Nvidia为MLPerf 1.1创造了训练速度记录[56]。我们的实施速度快了15%。

GPT-2。

与广泛使用的HuggingFace(84)和Megatron LM(74)实现相比,FlashAttention在大型OpenWebtext数据集(30)上为GPT-2(64)提供更快的训练时间。表2显示了与Huggingface相比高达3倍的端到端加速比,与 MegatronLM相比高达1.7倍的加速比。FlashAttention实现了与其他两个实现相同的困惑,因为我们没有更改模型定义。附录E包括了整个训练过程中验证困惑的曲线图,证实了FlashAttention在数字上与基线一样稳定,并产生了相同的训练/验证曲线

长程竞技场。

我们将vanilla Transformer(与标准实现或FlashAttention)在长程arena(LRA[77])基准上进行比较。我们测量所有模型的准确性、吞吐量和训练时间。每个任务的序列长度不同,介于1024到4096之间。我们遵循Tay等人[77]和Xiong等人[87]的实施和实验设置。3表3显示,与标准注意力相比,FlashAttention的速度提高了2.4倍。块稀疏FlashAttention比我们测试过的所有近似注意方法都要快。表3:标准注意力、FlashAttention、block稀疏FlashAttention和近似注意力基线在长期竞技场基准上的表现。

4.2序列较长的更好型号

具有长上下文的语言建模。

FlashAttention的运行时和内存效率使我们能够将GPT-2的上下文长度增加4倍,同时运行速度仍高于 MegatronLM的优化实现。表4显示,具有FlashAttention和上下文长度4K的GPT-2仍然比来自 Megatron的具有上下文长度1K的GPT-2快30%,同时获得了0.7更好的困惑度。

长文档分类。

使用FlashAttention对具有较长序列的Transformer进行训练,可以提高MIMIC-III[45]和ECtHR[6,7]数据集的性能。MIMIC-III包含重症监护室患者出院总结,每个总结都有多个标签。ECtHR包含来自

欧洲人权法院,每个法院都与据称违反《人权公约》的条款相关联。这两个数据集都包含很长的文本文档;MIMIC中的平均令牌数为2395个,最长文档包含14562个令牌,而ECtHR中的平均和最长数字分别为2197和49392。我们通过增加预训练RoBERTa模型的序列长度来评估扬程[54](我们重复位置嵌入,如Beltagy等人[3])。

表5显示,在MIMIC上,序列长度16K比长度512好4.3个点,在ECtHR上,长度8K比长度512好8.5个点。这种差异可能是由于细微的分布变化造成的:MIMIC-III包含专门的医学文本,因此可能更容易受到文件长度分布变化的影响,而ECtHR包含通用语言。

路径X和路径256。

Path-X和Path-256基准测试是设计用于测试长上下文的长程arena基准测试中具有挑战性的任务。任务是对黑白128×128(或256×256)图像中的两个点是否有连接它们的路径进行分类,并将图像一次一个像素地馈送给Transformer。在之前的工作中,所有Transformer模型要么内存不足,要么仅实现随机性能[77]。人们一直在寻找可以对如此长的上下文建模的替代体系结构[35]。在此,我们给出了Transformer模型能够求解Path-X和Path-256的第一个结果(表6)。我们在Path-64上预训练一个变换器,然后通过空间插值位置嵌入将其转移到Path-X。FlashAttention在Path-X上达到61.4的精度。此外,块稀疏FlashAttention使Transformer能够缩放到序列长度64K,在Path-256上达到63.1的精度4。

4.3对标注意事项

我们改变序列长度,测量FlashAttention的运行时和内存使用情况,并在一个具有40 GB HBM的A100 GPU上,根据不同的注意基线分块稀疏FlashAttention,该GPU带有dropout和填充掩码。我们将精确注意、近似注意和稀疏注意与参考实现进行比较。我们在正文中报告了基线的子集;附录E包含更多基线和全部细节。

运行时。

图3(左)以毫秒为单位报告了FlashAttention和block sparse FlashAttention向前+向后传递的运行时间,并与精确、近似和稀疏注意的基线进行了比较(精确数字见附录E)。运行时随序列长度呈二次增长,但FlashAttention的运行速度明显快于精确的注意基线,比PyTorch实现快3倍。许多近似/稀疏注意机制的运行时间随序列长度线性增长,但由于内存访问较少,对于短序列,FlashAttention仍然比近似和稀疏注意运行得更快。在512到1024之间的序列中,近似的注意力运行时开始与FlashAttention交叉。另一方面,在所有序列长度上,块稀疏FlashAttention比我们所知的精确、稀疏和近似注意的所有实现都要快。

内存占用。

图3(右)显示了FlashAttention和block sparse FlashAttention的内存占用,与各种精确、近似和稀疏注意基线进行了比较。FlashAttention和block sparse FlashAttention具有相同的内存占用,内存占用随序列长度线性增长。FlashAttention的内存效率比精确注意基线高20倍,并且比近似注意基线的内存效率更高。除Linformer之外的所有其他算法在64K之前都会在A100 GPU上耗尽内存,FlashAttention的效率仍然是Linformer的2倍。

5限制和未来方向

我们讨论了我们的方法的局限性和未来的方向。相关工作见附录A。

编译到CUDA。

我们当前构建注意的IO感知实现的方法需要为每个新的注意实现编写一个新的CUDA核。这需要用比PyTorch低得多的语言编写注意算法,并且需要大量的工程工作。实现也可能无法跨GPU体系结构进行传输。这些局限性表明,需要一种方法来支持用高级语言(如PyTorch)编写注意算法,并在CUDA中编译为IO感知的实现,类似于图像处理中的Halide[67]。

IO感知深度学习。

我们相信,IO感知方法可以超越注意。注意力是Transformer中内存最密集的计算,但深层网络中的每一层都涉及GPU HBM。我们希望我们的工作能够激励其他模块的IO感知实现。我们在附录D中讨论了这些潜在的扩展。

多GPU IO感知方法。

对于在单个GPU上计算注意力,我们的IO感知注意力实现在常量内是最佳的。然而,注意力计算可以跨多个GPU并行进行[69]。使用多个GPU为IO分析添加了一个额外的层,用于处理GPU之间的数据传输。我们希望我们的工作能够激励今后在这方面的工作。

确认书

我们的实现使用Apex的FMHA代码(https://github.com/NVIDIA/apex/tree/master/apex/作为起点。我们感谢Young Jun Ko对其FMHA实施的深入解释,以及他对我们有关CUDA问题的深思熟虑的回答。我们感谢Sabri Eyuboglu、Megan Leszczynski、Laurel Orr、Yuhuai Wu、Beidi Chen和Xun Huang对本文早期草稿的有益讨论和反馈。

我们衷心感谢NIH(编号:U54EB020405(Movement))、NSF(编号:CCF1763315(Beyond Sparsity))、CCF1563078(Volume to Velocity)和1937301(RTML)的支持;ARL,编号W911NF-21-2-0251(交互式人工智能团队);第N000141712266号ONR(统一薄弱监管);ONR N00014-20-1-2480:在机器学习中理解和应用非欧几里德几何;N000142012275(海王星);NXP、Xilinx、LETI-CEA、Intel、IBM、Microsoft、NEC、东芝、台积电、ARM、日立、巴斯夫、埃森哲、爱立信、高通、模拟设备、谷歌云、Salesforce、Total、HAI-GCP和HAI Azure云研究学分计划、斯坦福数据科学倡议(SDSI)、国防部(DoD)通过国防科学与工程研究生奖学金(NDSEG)计划,斯坦福黎明项目的成员:Facebook、谷歌和VMWare。美国政府有权出于政府目的复制和分发再版,尽管上面有任何版权标记。本材料中表达的任何观点、调查结果和结论或建议均为作者的观点、调查结果和结论或建议,不一定反映NIH、ONR或美国政府的观点、政策或明示或暗示的支持。Atri Rudra和Jessica Grogan的研究得到了NSF资助CCF-1763481的支持

A相关工作

IO感知运行时优化。

优化读写快/慢存储器的广义概念在计算机科学中有着悠久的历史,并被许多人所熟知。在这项工作中,我们与分析I/O复杂性的文献[1]有着最直接的联系,但内存层次结构的概念是基本的,并以多种形式出现,从工作集模型[20],到数据局部性[83],到算术强度的屋顶线模型[82],到可伸缩性分析[57],再到计算机体系结构的标准教科书处理[38]。我们希望这项工作能够鼓励社区在深度学习的更多部分采纳这些想法。

具有结构化矩阵的高效ML模型。

矩阵乘法是大多数机器学习模型的核心计算瓶颈。为了降低计算复杂性,有许多方法可以学习更有效的矩阵集。这些矩阵称为结构矩阵,具有次二次(( 2)对于尺寸 × ) 参数和运行时的数量。结构化矩阵最常见的例子是稀疏矩阵和低秩矩阵,以及信号处理中常见的快速变换(傅立叶、切比雪夫、正弦/余弦、正交多项式)。机器学习中提出了几种更一般的结构化矩阵:Toeplitz-like[75]、低位移秩[47]、拟可分[24])。我们用于块稀疏注意的蝴蝶模式的动机是,蝴蝶矩阵[14,61]及其产品已被证明能够表达任何具有几乎最佳运行时间和参数数量的结构化矩阵[15,19]。然而,尽管结构化矩阵在理论上是有效的,但它们并没有被广泛采用,因为很难将它们的效率转化为挂钟加速,因为密集无约束矩阵乘法具有非常优化的实现,这种现象被称为硬件彩票[39]。蝶形矩阵的扩展[16,17]旨在使蝶形矩阵更加硬件友好。

稀疏训练。

我们的block稀疏FlashAttention可以看作是使稀疏模型训练更有效的一步。稀疏模型通过对权重矩阵进行稀疏化,在压缩模型进行推理(修剪)方面取得了成功[22、36、37、53、73]。对于模型训练,彩票[26、27、28]表明,有一组小的子网络,它们来自一个更大的密集网络,其性能与原始的密集网络一样好。外块稀疏Flash注意力也可以被视为注意力上下文中的固定彩票:我们通过训练将稀疏模式固定为蝴蝶模式,并观察到它在长程竞技场任务中的表现几乎与(密集)Flash注意力一样好。

高效Transformer。

基于Transformer的模型正在成为自然语言处理[21]和计算视觉[23,88]中使用最广泛的架构。然而,它们的计算瓶颈之一是,它们的时间和内存在序列长度上是二次的。有许多方法可以克服这一瓶颈,包括使用哈希(即稀疏)等近似方法(如Reformer[49]和Smyrf[18])和低阶近似方法(如Performer[11,52])。人们甚至可以将稀疏近似和低秩近似结合起来以获得更好的精度(例如,Longferer[3]、BigBird[89]、Scatterbrain[8]、Long-short transformer[91]、Combiner[70])。其他方法包括沿序列维度压缩以同时处理多个令牌[50、55、76、86]。人们还可以注意之前序列中的状态,以帮助延长上下文(例如,Transformer XL[13]和Compression Transformer[66])。我们建议通过调查了解更多细节。

在开发其他模块方面有几条工作路线,而不是注意建模更长的上下文。HiPPO[33]及其扩展,尤其是S4[29,34,35]以多项式为基础预测历史,允许通过状态空间模型精确重建历史。它们结合了CNN(有效训练)、RNN(有效推理)和连续模型(对采样率变化具有鲁棒性)的优点。LambdaNetworks(2)、AFT(90)和FLASH(40)是在图像分类和语言建模的背景下替代注意力的其他尝试。

B算法详细信息

我们首先推导了注意力的向前和向后传递,并表明它们可以以一种高效的方式计算(需要额外的内存,序列长度是线性的,而不是二次的)。虽然它们减少了所需的额外内存量,但它们仍然会导致二次HBM访问,导致执行速度较慢。我们描述了FlashAttention算法,该算法可以在GPU上实现向前和向后传递,从而减少HBM访问,从而提高运行速度并减少内存占用。

B、 1个内存有效的前向传递

提高注意力内存效率的主要挑战是将K列(和V列)耦合在一起的softmax。我们的方法是分别计算softmax归一化常数以解耦列。文献[49,63]中使用了这项技术,以表明注意力计算不需要二次额外内存(尽管HBM访问的数量仍然是二次的,导致运行时缓慢)。

为简单起见,我们在此省略softmax期间的最大移位步骤。附录B.3中的完整算法包含所有步骤。

回想一下给定的输入序列Q、K、V∈ R × , 我们要计算注意力输出O∈ R × :

我们有 = 哪里 和 是-th和-Q和K的第th列。定义softmax的规范化常数:Let 成为-V的第列,然后是-输出的第th列为

我们见过一次 我们可以计算 无需重复求和的额外内存 . 因此,可以使用() 额外内存:

1、计算 对于所有人 根据式(1),其中() 额外的内存。

2、计算 对于所有人 根据式(2),其中() 额外的内存。

B、 2内存高效向后传递

我们推导了注意力的后向传递,并证明它也可以用线性内存计算。Rabe和Staats[63]认为,通过对内存高效的前向传球应用梯度检查点,可以在没有二次额外内存的情况下完成后向传球。相反,我们显式地导出向后传递,并展示如何以内存高效的方式计算向后传递。

假设存在标量损失函数, 让输出梯度为∈ R× (其中dO表示 O)。我们要计算输入梯度dQ,dK,dV∈ R× (其中dQ、dK、dV表示 Q K V)。

渐变dV很容易看到。通过手动应用反向模式自动微分(又名链式规则),我们得到(以矩阵表示法)dV=P 做因此:

因为我们已经计算过 , 可以通过重复求和在没有额外内存的情况下计算。梯度dQ和dK稍微复杂一些。我们先通过梯度dP和dS。根据公式(2),我们得到dP=dOV , 因此:

回想一下: = softmax软件(:). 利用 = softmax软件() is诊断() − , 我们有

哪里◦ 表示逐点乘法

现在我们可以得到梯度dQ和dK。回想一下 = , 所以

因此,也可以使用() 额外内存:

1、计算 对于所有人 根据式(3),其中() 额外的内存。

2、计算 对于所有人 根据式(4),其中() 额外的内存。

3、计算 对于所有人 根据式(5),其中() 额外的内存。

4、计算 对于所有人 根据式(6),其中() 额外的内存。

B、 3Flash注意:向前传球

我们描述了闪现注意力向前传球的全部细节。给定输入序列Q、K、V∈ R × , 我们要计算注意力输出O∈ R × :

哪里 ∈ R是一些softmax缩放(通常为1√ ), mask是一种屏蔽函数,用于将输入的某些条目设置为−∞ 并保持其他条目相同(例如,当批中的序列不具有相同的长度并且被填充时,密钥填充掩码),然后删除(, ) 将退出应用于 (即输出 1.− 概率为1− 并以概率输出0).

完整算法在算法2中。我们保存输出O,softmax统计信息ℓ 和, 以及反向传递的伪随机数生成器状态R。

B、 4Flash注意:向反向传播

我们描述了FlashAttention向反向传播的全部细节。给定输入序列Q、K、V∈ R × , 输出O∈ R × , 输出梯度dO,我们要计算输入梯度dQ,dK,dV∈ R × .

为了完整性,我们首先描述了标准的注意力向后传递算法3。

现在,我们对Flash注意力向反向传播进行两个观察:

1、我们不需要存储(2)向前传球。相反,我们可以保存前向过程中的伪随机数生成器状态,并在后向过程中重新生成丢失掩码。这只允许我们使用() 额外的内存。

2、计算softmax梯度时,我们使用公式(4)计算 = > : : 不减少超过: 和: 大小 (它们可能不适合SRAM)。相反,我们可以重写 = > 计算大小向量之间的点积.

算法4中提供了全Flash注意力反向传递算法。从概念上讲,它只是附录B.2中推导的块版本。

我们看到,与向前传球类似,向反向传播的表现( 2)FLOPs,仅需要() 输入、输出、输出梯度和输入梯度之外的额外内存。

我们分析了向后传递的IO复杂性,类似于向前传递(定理2)。定理5。允许 为序列长度, 为封头尺寸,以及 具有SRAM的大小 ≤ ≤ . 标准注意(算法0)向反向传播需要Θ( + 2)HBM访问,而FlashAttention反向传递(算法4)需要Θ( 2. 2.−1)HBM访问。

证明见附录C。

C证明

定理1的证明。

我们首先计算所需的FLOPs和额外内存的数量。

主要的FLOPs来自矩阵乘法。在内环中(算法1第9行),我们计算QK> ∈ R× 对于Q ∈ R× 和K ∈ R× , 这需要( ) FLOPs。我们还计算(算法1第12行)P° 五、 ∈ R× 对于P ∈ R× 和V ∈ R× , 这需要( ) FLOPs。我们执行内部循环 = l m l公司 m次。因此,FLOPs的总数为

就所需的额外内存而言,我们看到我们需要() 存储统计信息的内存(ℓ, ).

我们现在通过归纳证明了算法的正确性 对于0≤ ≤ . 让K: ∈ R× 成为第一个 K行和类似的V行: ∈ R× 第一个 V行。让S:,: = QK>: ∈ R × , 和P:,: = softmax(S:,:) ∈ R × (按行应用softmax)。允许 , ℓ( ) , O() 是的值, ℓ, 在HBM中,在-外循环的第次迭代(算法1第5行)。(请注意, ℓ, O在外循环每次迭代后更新。)我们想在-外环的第次迭代,我们在HBM中计算: ( ) = rowmax(S:,:) ∈ R , ℓ( ) = 行和(扩展:,: − ( ) )) ∈ R , O() = P: ,:五: ∈ R × .

根据我们的初始化(算法1第2行),此声明适用于 = 0(即,在执行外部循环的任何迭代之前)。假设该主张适用于 = 0, . . . , − 1、我们想证明该索赔也适用于 + 事实上,当我们在( + 1) -外循环的第次迭代,我们更新 ( +1) =最大值( ( ) , ✓)其中˜ ∈ R 是S的行最大值:,:+1,S从列的切片 至列( + 1) − 1、这意味着

同样,我们更新ℓ ( +1) = ( )−( +1) ℓ ( ) + ˜−( +1) ℓ,˜

哪里ℓ✓=行和(经验:,:+1.− ˜)) ∈ R . 通过第3.1节中相同的代数运算,我们得到:

然后我们看到,该声明也适用于 + 根据归纳法,该主张对所有人都是正确的 = 0, . . . , . 什么时候 = , 我们得出结论,HBM中O的最终值为softmax(S)V=softmax(QK>)V。

定理2的证明。

我们首先分析了标准注意实现的IO复杂性。输入Q、K、V∈ R × 驻留在HBM中,算法末尾的输出为∈ R × 已写入HBM

在计算矩阵乘法S=QK>的第一步中,输入Q、K从HBM读取,输出S∈ R × 写入HBM(算法0第1行)。这将导致( + 2)HBM接入。

在计算P=softmax(S)的第二步中,从HBM读取输入S,并将输出P写入HBM(算法0第2行)。这将导致( 2)HBM接入。

在计算O=PV的最后一步中,从全局内存读取输入P、V,并将输出O写入HBM(算法0第3行)。这将导致( + 2)HBM接入。

总体而言,标准注意实施需要( + 2)全局内存访问。

现在,我们分析流式注意力的IO复杂性。

在算法1之后,我们看到K和V的每个元素都从HBM加载了一次(算法1第6行)。我们制造 通过Q和O,每次通过都将所有Q和O加载到HBM(算法1第8行)。因此,HBM访问的数量为Θ( + ) = Θ( ).

我们推导了块大小的条件 和 . 我们需要块K 和V 大小 × 要安装到片上存储器中,可转换为:

同样,我们需要块Q , O 大小 × 要安装到片上存储器中

最后,我们需要块S 大小 × 要安装到片上存储器中,可转换为:

因此,HBM访问的数量为:

命题3的证明。

为了解决矛盾,假设存在一种计算精确注意的算法,其中所有HBM访问的数量 ∈ [, ] 是

在 = Θ( ), 这将导致HBM访问的数量:

然而,注意的输入(矩阵Q、K、V)和输出O都有大小 而且他们一开始是在HBM中,所以如果算法计算出精确的注意力,它必须引起至少Ω的注意( ) HBM访问。这是一种矛盾。

定理5的证明。

向后注意的IO复杂性与向前注意的IO复杂性非常相似(定理2)。这里我们提供了一个证明的草图。

我们首先分析了标准注意向后传递的IO复杂性。输入Q、K、V、dO∈ R × 驻留在HBM中,并在算法末尾输出dQ、dK、dV∈ R × 写入HBM。

在标准注意力向反向传播的每一步,都需要加载大小的输入 或 2来自HBM,需要写入大小的输出 2或 至HBM。这将导致( + 2)HBM接入。

现在,我们分析FlashAttention反向传递的IO复杂性。与定理2类似,我们看到K和V的每个元素都从HBM加载一次。dK和dV的每个元素只写入HBM一次。我们制造 通过Q、O、dO,每次通过将Q、O、dO全部加载到HBM。我们还制造 通过dQ,每次通过从HBM读取/写入所有dQ。因此,HBM访问的数量为Θ( + ) = Θ( ).

在定理2的证明中,对块大小的限制是:

因此,HBM访问的数量为:

D扩展详细信息

D、 1块稀疏FlashAttention

我们在算法5中描述了全块稀疏FlashAttention算法。该算法与算法2相同,只是跳过了零个块。

我们证明了块稀疏FlashAttention的IO复杂性。

命题4的证明。

这个证明与定理2的证明非常相似。对于块稀疏的情况,请注意,我们只需要加载对应于非零块的块。因此,HBM访问的数量按, 块稀疏遮罩中非零块的分数。但是,对于较小的值, 我们仍然需要将结果写入∈ R × . 因此,HBM访问的数量为

D、 2潜在扩展

我们在此讨论IO感知方法的一些潜在扩展,以加速深度学习训练。

多GPU注意事项。

大型语言模型在数百或数千个GPU上训练,其中一个通常在同一节点上的4-8个GPU之间分割注意力计算[74]。这引入了另一个级别的内存层次:除了GPU SRAM和GPU HBM之外,我们还有其他GPU的HBM。对于很长的序列,同一节点上的不同GPU可以通过考虑不同级别内存层次的不对称性来合作计算注意。

稀疏MLP层。

典型的密集MLP层受计算限制,而非内存限制。为了提高效率,可以使用具有稀疏权重矩阵的MLP层[16]。然而,许多稀疏MLP层是内存受限的,它们的加速比通常与稀疏性不成正比。我们相信IO感知的实现可以缓解这个问题,并实现稀疏性的好处。我们对这一方向的未来工作感到兴奋,以减少大型模型的计算需求并改进其墙块运行时。

核回归。

我们的FlashAttention方法依赖于以下事实: × 注意矩阵是低秩矩阵QK>的函数  ). 因此,我们可以重复加载输入Q、K,并重新计算所需的注意矩阵块,从而显著减少HBM访问。正如核回归中发生的类似场景:每个元素 的 × 核矩阵K是两个大小向量的函数 , 因为它测量两个数据点之间的相似性 和 . 我们希望我们的IO感知方法能够激励方法加速核回归。

E全部实验结果

E、 1BERT

我们按照参考MLPerf 1.1实现的训练程序和超参数对BERT large进行训练。特别是,我们使用学习率为3.75e-3、批量为448、最多可训练7100个步骤的LAMB优化器。一旦验证准确率(对于屏蔽语言建模)达到目标72.0%,并且测量了挂钟运行时间,则停止训练。我们使用Apex AMP(O2优化水平)进行FP16精度的训练。

我们将我们的结果与提交给MLPerf 1.1(表1)的Nvidia报告的训练速度进行了比较。

我们使用MLPerf 1.1参考实现提供的相同序列/验证数据分割。特别是,我们评估了与Nvidia基线相同的10000个验证示例。

我们在8×A100-80GB GPU上训练模型。每次训练跑需要16到19分钟,我们平均10次跑的结果。

E、 2 GPT-2

我们使用Huggingface transformers library和Nvidia MegatronLM repo的GPT-2标准实现。我们遵循 MegatronLM repo的训练配方。

我们使用512的有效批量大小,并使用梯度累积来适应可用的GPU内存。我们使用AdamW优化器,GPT-2 small的学习率为6e-4,GPT-2 medium的学习率为1.5e-4,权重衰减为0.1。所有模型都使用相同的超参数训练400K步。我们使用混合精度训练(PyTorch-AMP)运行所有实现。

我们使用Openwebtext数据集和GPT-2 BPE标记器。我们随机选择0.5%的数据集作为验证集,其余数据集作为训练集。验证集的随机选择只进行一次,所有模型都在同一个验证集上进行评估。

我们在8×A100-40GB GPU上对模型进行了训练,并测量了挂钟训练时间。GPT-2小型训练需要2.7-9.5天,而GPT-2中型训练需要6.9-21.0天(表2)。

在图4中,我们使用HuggingFace实现或FlashAttention实现绘制了整个GPT-2中小型训练过程中的验证困惑。我们看到FlashAttention的行为与基线实现相同,两个实现的验证复杂度曲线几乎位于彼此的顶部。

E、 3长文档分类

对于MIMIC-III和ECtHR,我们遵循Dai等人的超参数[12]。

E、 4 LRA

我们遵循长程竞技场论文[77]、长程竞技场回购(https://github.com/google research/Long-range-arena)和Nyströmformer复制[87]中的超参数。对于基线方法,如果我们无法再现五项任务中任何一项的任何基线性能,我们报告Tay等人[77]或Xiong等人[87]在该任务中的该基线性能更好。

在超参数调整之后,几乎所有的注意力方法在所有五个LRA任务上都达到了相似的精度。

我们使用混合精度训练运行所有方法,但执行者(混合精度不稳定)和局部注意(实现不支持FP16)除外。

为了计算总的挂钟时间加速比,我们取五个任务中每个任务的挂钟时间加速比的几何平均值。

路径-X

对于Path-X和Path-256,我们遵循长程arena论文中PathFinder-32实验的超参数[77]。对于这两种情况,我们首先在Path-64上预训练一个模型。我们在200个历元后获取检查点,对其位置嵌入进行上采样(我们在空间中以网格方式复制位置嵌入),并在下游任务中对其进行微调,其中一个历元为线性预热,学习率为余弦衰减。对于Path-X,我们采用性能最好的检查点(根据val准确度),并以相同的热身和学习速度对其进行200个时代的微调(这为Path-X的FlashAttention增加了大约4个准确度,但之后模型开始过度拟合)。

E、 5全面的基准测试结果

我们报告了完整的基准测试结果和实验细节。

基线

我们将PyTorch/HuggingFace和 Megatron的精确注意、近似注意和稀疏注意与参考实现进行比较。对于近似注意,我们将其与Reformer(49)、Local attention(65)、Linformer(81)、Smyrf(18)和LongShortFormer(LSFormer)(91)的参考实现进行比较。对于稀疏注意,我们将其与来自OpenAI[10]、Longferer[3]和BigBird attention[89]的块稀疏注意的参考实现进行比较。对于近似和稀疏注意,我们使用1/8的压缩比或256的压缩序列长度,以较小者为准。

安装程序

安装程序

我们在一台带有40 GB GPU HBM的A100 GPU的机器上,测量了8个维度为64的头和128个批量大小的注意力计算的运行时和内存使用情况。我们在实验中改变了序列长度。我们计算Q、K和V的随机向量的注意(我们不测量隐藏层的投影)。对于dropout,我们使用dropout0.1;对于掩蔽,我们使用填充掩蔽,其均匀随机的掩蔽长度介于总序列长度和总序列长度减20之间。为了测量运行时,我们对100次注意调用进行平均测量。我们只测量一次内存占用,因为它在运行之间没有变化。

我们报告了向前传球、向反向传播和向前+向后组合传球的计时结果。除了块稀疏、Longformer和BigBird之外,我们测量了每种方法,包括和不包括dropout、masking或两者。由于外部库中存在错误,这些方法没有成功运行带掩蔽的向后传递,因此我们在不带掩蔽的情况下对它们进行了测量,以确保其性能良好。我们对所有测量都使用FP16,但本地注意除外,它的实现只支持FP32。

对于每个基线,我们增加序列长度,直到GPU上的内存耗尽,但以下例外情况除外: Megatron实施不支持超过2048的序列长度。块稀疏(OpenAI)不支持长度超过4096的序列。Longformer和BigBird不支持长度超过8092的序列。

我们测量向前+向后组合过程中的内存使用情况,没有丢失或掩蔽。结果表7总结了所有实验配置,并包含指向结果表的指针

你可能感兴趣的:(FlashAttention)