flash attention 2论文学习

flash attention作者Tri Dao发布了flash attention 2,性能为flash attention的2倍。
优化点主要如下:

一、减少 non-matmul FLOPs

A00中由于tensor core的存在,使得gpu对于浮点矩阵运算吞吐很高,如FP16/BF16可以达到312 TFLOPs/s,而对于非矩阵乘的浮点运算吞吐较低,如FP32只有19.5 TFLOPs/s。因此作者调整算法以减少非矩阵乘的浮点运算。
如图1-1,基线算法计算O2的时候会对O1进行放缩,先乘上之前的sum L1,再除以新的sum L2。

在这里插入图片描述

图 1-1
但是这个其实没有必要,可以在最后一次计算只放缩一次,如图1-2。

在这里插入图片描述

图 1-2

二、并行模式

基线对于CTA的分块逻辑为启动batch_size * num_head个CTA,每个CTA执行一个batch里的一个head,那么当seq_len很长的场景,batch_size一般会比较小,这个时候无法充分利用所有的SM,所以作者调整了并行模型,一个batch里的一个head也会被多个CTA执行。
基线算法中外层循环是对K,内层循环对Q,作者交换了这个循环,对外层循环进行并发。
综合一,二之后的算法流程如图2-1

flash attention 2论文学习_第1张图片

图 2-1

三、warp分块

基线warp分块如图3-1,一个CTA所有warp都load Q,但是对K分块,这个时候计算S和P并没有啥问题,但是对计算O的时候,会导致warp之间对O执行一次reduce sum。

flash attention 2论文学习_第2张图片

图 3-1
因此作者调整了warp分块逻辑,如图3-2所示,对Q进行分块,每个warp都load K和V,以避免最后对O的reduce。

flash attention 2论文学习_第3张图片

图 3-2

你可能感兴趣的:(cuda,gpu,cuda)