200% Speedup! Flash Attention 2 Rules Them All — Attention Is No Longer the Bottleneck!

Author | godweiyang

Published by | WeChat Official Account: 算法码上来 (ID: GodNLP)

- BEGIN -

Attention is the most important component of the Transformer, but as the sequence length grows, its computational cost grows quadratically (O(n²)), and both GPU memory and speed quickly become a bottleneck. Many attention acceleration algorithms have therefore been proposed, such as Flash Attention, xformers, and so on.
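
To make the quadratic growth concrete, here is a quick back-of-the-envelope illustration (my own addition, not part of the original benchmark): the naive implementation materializes a (batch, heads, seq, seq) score matrix, so doubling the sequence length quadruples the memory it needs.

def score_matrix_mib(batch, heads, seq_len, bytes_per_elem=2):
    # Size of one (batch, heads, seq, seq) score matrix in MiB (fp16/bf16 = 2 bytes/element).
    return batch * heads * seq_len * seq_len * bytes_per_elem / 2**20

for seq_len in (1024, 2048, 4096):
    print(seq_len, f"{score_matrix_mib(8, 16, seq_len):.0f} MiB")  # 256, 1024, 4096 MiB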

Then, on July 17, Flash Attention 2 was open-sourced, with the authors claiming it is roughly 2x faster than version 1, so I couldn't wait to install it and see how big the improvement really is.
https://crfm.stanford.edu/2023/07/17/flash2.html

This test compares four implementations: a hand-written attention in PyTorch, the _scaled_dot_product_attention operator provided by torch.nn.functional, the official Flash Attention 2 implementation, and the official xformers implementation.

To give the conclusion up front: in most cases, for both speed and memory, the ranking is "flash attention 2 > xformers > PyTorch function > hand-written PyTorch".

Test environment

  • A100-SXM4-80GB, since Flash Attention 2 only supports Ampere (A-series) and Hopper (H-series) GPUs (see the quick check after this list).

  • PyTorch 1.13.1

  • CUDA 11.7
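
If you are not sure whether your GPU qualifies, a quick check (my own addition, not from the original post) is to query its compute capability; Flash Attention 2 needs compute capability 8.0 or higher, i.e. Ampere or newer:

import torch

# Flash Attention 2 requires compute capability >= 8.0 (e.g. A100, H100).
major, minor = torch.cuda.get_device_capability(0)
status = "supported" if major >= 8 else "not supported"
print(f"compute capability {major}.{minor}: flash-attn 2 {status}")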

Installation commands

pip install ninja triton

# flash attention
pip install flash-attn --no-build-isolation

# xformers
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
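
After installation, a simple sanity check (my addition, not in the original post) is to import the packages and print their versions:

import torch, flash_attn, xformers

print("torch:", torch.__version__)
print("flash-attn:", flash_attn.__version__)
print("xformers:", xformers.__version__)
print("CUDA available:", torch.cuda.is_available())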

Test code

import math
import random
import time
from einops import rearrange
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func
from xformers.ops import memory_efficient_attention, LowerTriangularMask


xformers_attn_bias = LowerTriangularMask()  # reused for every causal xformers call

def custom_attention(q, k, v, causal=False):
    # Naive attention: materializes the full (seq, seq) score matrix.
    score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if causal:
        # Additive mask: most negative representable value above the diagonal,
        # so those positions get zero weight after softmax.
        mask = torch.triu(torch.ones(score.shape[-2], score.shape[-1]), diagonal=1)
        mask = mask.masked_fill(mask==1, torch.finfo(q.dtype).min)
        mask = mask.to(q.device, q.dtype)
        score = score + mask
    attn = F.softmax(score, dim=-1)
    o = torch.matmul(attn, v)
    return o

def pytorch_func(q, k, v, causal=False):
    # PyTorch 1.13's private fused kernel; it returns (output, attn_weights), hence [0].
    # In PyTorch >= 2.0 this is exposed publicly as F.scaled_dot_product_attention.
    o = F._scaled_dot_product_attention(q, k, v, is_causal=causal)[0]
    return o

def flash_attention(q, k, v, causal=False):
    # flash_attn_func expects (batch, seq, heads, head_dim) fp16/bf16 tensors on the GPU.
    o = flash_attn_func(q, k, v, causal=causal)
    return o

def xformers_attention(q, k, v, causal=False):
    # xformers uses the same layout; causal masking is passed via attn_bias.
    attn_bias = xformers_attn_bias if causal else None
    o = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
    return o

def test(func_name, q, k, v, *args, **kwargs):
    # The hand-written version and F._scaled_dot_product_attention expect
    # (batch, heads, seq, head_dim), so swap the head and sequence dimensions.
    if func_name in ["custom_attention", "pytorch_func"]:
        q = rearrange(q, "a b c d -> a c b d")
        k = rearrange(k, "a b c d -> a c b d")
        v = rearrange(v, "a b c d -> a c b d")

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    # Warm up for 5 iterations, then time a single run.
    for _ in range(5):
        o = globals()[func_name](q, k, v, *args, **kwargs)
    torch.cuda.synchronize()
    st = time.time()
    o = globals()[func_name](q, k, v, *args, **kwargs)
    torch.cuda.synchronize()
    tt = time.time() - st
    max_memory = torch.cuda.max_memory_allocated() // 2**20
    torch.cuda.empty_cache()

    # Transpose the output back to (batch, seq, heads, head_dim) for comparison.
    if func_name in ["custom_attention", "pytorch_func"]:
        o = rearrange(o, "a c b d -> a b c d")

    return o, tt, max_memory

if __name__ == "__main__":
    test_num = 10
    for idx in range(test_num):
        print(f"test {idx} >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
        # Random shape (batch, seq_len, heads, head_dim), dtype and causal flag for each test.
        bsz = random.randint(1, 64)
        sql = random.randint(1, 4096)
        nh = random.choice([8, 12, 16])
        hd = random.choice([64, 128])
        dtype = random.choice([torch.float16, torch.bfloat16])
        causal = random.choice([False, True])
        print(f"shape: ({bsz}, {sql}, {nh}, {hd}), dtype: {dtype}, causal: {causal}")
        q = torch.randn((bsz, sql, nh, hd)).to("cuda:0", dtype)
        k = torch.rand_like(q)
        v = torch.rand_like(q)

        o, t, m = test("custom_attention", q, k, v, causal=causal)
        print(f"custom pytorch time: {t:.6f}, peak memory: {m} MB")

        pf_o, pf_t, pf_m = test("pytorch_func", q, k, v, causal=causal)
        print(f"pytorch func time: {pf_t:.6f}, speedup: {t/pf_t:.2f}; peak memory: {pf_m} MB, save: {int((m-pf_m)/m*100)}%")
        assert torch.allclose(o, pf_o, rtol=1e-2, atol=1e-2)
        
        fa_o, fa_t, fa_m = test("flash_attention", q, k, v, causal=causal)
        print(f"flash attention time: {fa_t:.6f}, speedup: {t/fa_t:.2f}; peak memory: {fa_m} MB, save: {int((m-fa_m)/m*100)}%")
        assert torch.allclose(o, fa_o, rtol=1e-2, atol=1e-2)

        xf_o, xf_t, xf_m = test("xformers_attention", q, k, v, causal=causal)
        print(f"xformers time: {xf_t:.6f}, speedup: {t/xf_t:.2f}; peak memory: {xf_m} MB, save: {int((m-xf_m)/m*100)}%")
        assert torch.allclose(o, xf_o, rtol=1e-2, atol=1e-2)

Test results

I ran 10 tests with random input shapes (batch_size, seq_len, num_head, head_dim), a randomly enabled causal mask, and a random choice of fp16 or bf16. The results are as follows:

test 0 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000734, peak memory: 105 MB
pytorch func time: 0.000104, speedup: 7.06; peak memory: 49 MB, save: 53%
flash attention time: 0.000055, speedup: 13.45; peak memory: 43 MB, save: 59%
xformers time: 0.000152, speedup: 4.82; peak memory: 61 MB, save: 41%
test 1 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (57, 1235, 8, 64), dtype: torch.float16, causal: True
custom pytorch time: 0.015195, peak memory: 3093 MB
pytorch func time: 0.001930, speedup: 7.87; peak memory: 571 MB, save: 81%
flash attention time: 0.000635, speedup: 23.94; peak memory: 496 MB, save: 83%
xformers time: 0.001383, speedup: 10.99; peak memory: 696 MB, save: 77%
test 2 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (61, 2045, 16, 128), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.101898, peak memory: 18782 MB
pytorch func time: 0.031511, speedup: 3.23; peak memory: 4115 MB, save: 78%
flash attention time: 0.005292, speedup: 19.25; peak memory: 3560 MB, save: 81%
xformers time: 0.009730, speedup: 10.47; peak memory: 3972 MB, save: 78%
test 3 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (15, 1526, 12, 64), dtype: torch.float16, causal: True
custom pytorch time: 0.010720, peak memory: 3756 MB
pytorch func time: 0.001101, speedup: 9.74; peak memory: 1732 MB, save: 53%
flash attention time: 0.000380, speedup: 28.24; peak memory: 1211 MB, save: 67%
xformers time: 0.000862, speedup: 12.43; peak memory: 824 MB, save: 78%
test 4 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (28, 3227, 12, 128), dtype: torch.float16, causal: True
custom pytorch time: 0.091987, peak memory: 15090 MB
pytorch func time: 0.029867, speedup: 3.08; peak memory: 2223 MB, save: 85%
flash attention time: 0.004636, speedup: 19.84; peak memory: 1924 MB, save: 87%
xformers time: 0.008405, speedup: 10.94; peak memory: 2151 MB, save: 85%
test 5 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (37, 2047, 8, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.026797, peak memory: 6242 MB
pytorch func time: 0.003424, speedup: 7.83; peak memory: 1388 MB, save: 77%
flash attention time: 0.000947, speedup: 28.29; peak memory: 1049 MB, save: 83%
xformers time: 0.002072, speedup: 12.93; peak memory: 1006 MB, save: 83%
test 6 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (24, 2637, 16, 128), dtype: torch.bfloat16, causal: False
custom pytorch time: 0.053066, peak memory: 11970 MB
pytorch func time: 0.047200, speedup: 1.12; peak memory: 2205 MB, save: 81%
flash attention time: 0.006308, speedup: 8.41; peak memory: 1885 MB, save: 84%
xformers time: 0.011971, speedup: 4.43; peak memory: 2055 MB, save: 82%
test 7 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (37, 3214, 12, 64), dtype: torch.float16, causal: True
custom pytorch time: 0.097363, peak memory: 19552 MB
pytorch func time: 0.012316, speedup: 7.91; peak memory: 2142 MB, save: 89%
flash attention time: 0.003399, speedup: 28.65; peak memory: 1720 MB, save: 91%
xformers time: 0.007016, speedup: 13.88; peak memory: 1995 MB, save: 89%
test 8 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (40, 2126, 16, 64), dtype: torch.float16, causal: True
custom pytorch time: 0.066542, peak memory: 12737 MB
pytorch func time: 0.006069, speedup: 10.96; peak memory: 1856 MB, save: 85%
flash attention time: 0.002226, speedup: 29.89; peak memory: 1516 MB, save: 88%
xformers time: 0.004234, speedup: 15.71; peak memory: 1840 MB, save: 85%
test 9 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (47, 3355, 12, 64), dtype: torch.bfloat16, causal: False
custom pytorch time: 0.100385, peak memory: 26267 MB
pytorch func time: 0.024839, speedup: 4.04; peak memory: 2353 MB, save: 91%
flash attention time: 0.008755, speedup: 11.47; peak memory: 1956 MB, save: 92%
xformers time: 0.016346, speedup: 6.14; peak memory: 2483 MB, save: 90%

As you can see, in most cases, for both speed and memory, the ranking is "flash attention 2 > xformers > PyTorch function > hand-written PyTorch".

On top of that, all of these APIs are very easy to use and can basically serve as drop-in replacements for the attention module in your own model. One caveat: Flash Attention 2 does not seem to support passing in an arbitrary attention mask; you can only toggle a causal mask, which is a real limitation, although it is enough for GPT-style models.
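
To show what "drop-in replacement" looks like in practice, here is a minimal sketch of a toy attention module built around flash_attn_func. The module and variable names are made up for illustration and are not from any library; the real constraints are only that the inputs are (batch, seq, heads, head_dim) fp16/bf16 CUDA tensors and that masking is limited to the causal flag.

import torch
import torch.nn as nn
from flash_attn import flash_attn_func

class ToySelfAttention(nn.Module):
    """Hypothetical attention block; only flash_attn_func is the real API here."""
    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.out = nn.Linear(dim, dim, bias=False)

    def forward(self, x, causal=True):
        b, s, _ = x.shape
        # Project to q/k/v and split into three (batch, seq, heads, head_dim) tensors.
        q, k, v = self.qkv(x).view(b, s, 3, self.num_heads, self.head_dim).unbind(dim=2)
        # This one call replaces the whole softmax(QK^T/sqrt(d))V computation.
        o = flash_attn_func(q, k, v, causal=causal)
        return self.out(o.reshape(b, s, -1))

x = torch.randn(2, 128, 512, device="cuda", dtype=torch.float16)
print(ToySelfAttention(512, 8).to("cuda", torch.float16)(x).shape)  # torch.Size([2, 128, 512])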

- END -

I'm godweiyang, an NLP algorithm engineer at ByteDance AI Lab. I ranked first in my major for both my bachelor's and master's in computer science at East China Normal University, and I specialize in algorithms, model optimization, and machine translation.
