Author | godweiyang
Source | WeChat official account: 算法码上来 (ID: GodNLP)
- BEGIN -
Attention is the most important building block in the Transformer, but as the sequence length grows, its compute cost grows quadratically, and both memory and speed become a problem. Many attention acceleration algorithms have therefore been proposed, such as flash attention, xformers, and so on.
Then on July 17, flash attention 2 was open-sourced, with the authors claiming it is roughly 2x faster than the first version, so I couldn't wait to install it and see how big the improvement really is.
https://crfm.stanford.edu/2023/07/17/flash2.html
There are four contenders in this test: a hand-written PyTorch attention, the _scaled_dot_product_attention operator provided by torch.nn.functional, the official flash attention 2 implementation, and the official xformers implementation.
To state the conclusion up front: in most cases, for both speed and memory, the ranking is "flash attention 2 > xformers > PyTorch function > hand-written PyTorch".
Test environment:
GPU: A100-SXM4-80GB (flash attention 2 only supports A- and H-series GPUs)
PyTorch 1.13.1
CUDA 11.7
Install the dependencies:
pip install ninja triton
# flash attention
pip install flash-attn --no-build-isolation
# xformers
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
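Before running the benchmark, it's worth a quick sanity check that both CUDA extensions actually import. This is a minimal sketch of my own; it assumes both packages expose a __version__ attribute, which may differ across releases:

# Sanity check: confirm the GPU is visible and both libraries load.
import torch
import flash_attn
import xformers

print("torch:", torch.__version__, "CUDA available:", torch.cuda.is_available())
print("flash_attn:", flash_attn.__version__)  # assumed attribute
print("xformers:", xformers.__version__)      # assumed attribute

The full benchmark script follows.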
import math
import random
import time
from einops import rearrange
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func
from xformers.ops import memory_efficient_attention, LowerTriangularMask
# Causal (lower-triangular) attention bias used by xformers.
xformers_attn_bias = LowerTriangularMask()
def custom_attention(q, k, v, causal=False):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if causal:
        # Additive mask: future positions get the most negative representable value.
        mask = torch.triu(torch.ones(score.shape[-2], score.shape[-1]), diagonal=1)
        mask = mask.masked_fill(mask == 1, torch.finfo(q.dtype).min)
        mask = mask.to(q.device, q.dtype)
        score = score + mask
    attn = F.softmax(score, dim=-1)
    o = torch.matmul(attn, v)
    return o
def pytorch_func(q, k, v, causal=False):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    # PyTorch 1.13's private fused attention op returns a tuple; [0] is the output.
    o = F._scaled_dot_product_attention(q, k, v, is_causal=causal)[0]
    return o

def flash_attention(q, k, v, causal=False):
    # q, k, v: (batch, seq_len, num_heads, head_dim)
    o = flash_attn_func(q, k, v, causal=causal)
    return o

def xformers_attention(q, k, v, causal=False):
    # q, k, v: (batch, seq_len, num_heads, head_dim)
    attn_bias = xformers_attn_bias if causal else None
    o = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
    return o
def test(func_name, q, k, v, *args, **kwargs):
    # The two PyTorch implementations expect (batch, num_heads, seq_len, head_dim),
    # while flash attention and xformers expect (batch, seq_len, num_heads, head_dim).
    if func_name in ["custom_attention", "pytorch_func"]:
        q = rearrange(q, "a b c d -> a c b d")
        k = rearrange(k, "a b c d -> a c b d")
        v = rearrange(v, "a b c d -> a c b d")
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    # Warm-up runs so the timing excludes one-off compilation/autotuning cost.
    for _ in range(5):
        o = globals()[func_name](q, k, v, *args, **kwargs)
    torch.cuda.synchronize()
    st = time.time()
    o = globals()[func_name](q, k, v, *args, **kwargs)
    torch.cuda.synchronize()
    tt = time.time() - st
    max_memory = torch.cuda.max_memory_allocated() // 2**20  # MB
    torch.cuda.empty_cache()
    # Rearrange back so every implementation returns the same output layout.
    if func_name in ["custom_attention", "pytorch_func"]:
        o = rearrange(o, "a c b d -> a b c d")
    return o, tt, max_memory
if __name__ == "__main__":
    test_num = 10
    for idx in range(test_num):
        print(f"test {idx} >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
        # Random shape: (batch_size, seq_len, num_heads, head_dim), random dtype and mask.
        bsz = random.randint(1, 64)
        sql = random.randint(1, 4096)
        nh = random.choice([8, 12, 16])
        hd = random.choice([64, 128])
        dtype = random.choice([torch.float16, torch.bfloat16])
        causal = random.choice([False, True])
        print(f"shape: ({bsz}, {sql}, {nh}, {hd}), dtype: {dtype}, causal: {causal}")
        q = torch.randn((bsz, sql, nh, hd)).to("cuda:0", dtype)
        k = torch.rand_like(q)
        v = torch.rand_like(q)

        o, t, m = test("custom_attention", q, k, v, causal=causal)
        print(f"custom pytorch time: {t:.6f}, peak memory: {m} MB")

        pf_o, pf_t, pf_m = test("pytorch_func", q, k, v, causal=causal)
        print(f"pytorch func time: {pf_t:.6f}, speedup: {t/pf_t:.2f}; peak memory: {pf_m} MB, save: {int((m-pf_m)/m*100)}%")
        assert torch.allclose(o, pf_o, rtol=1e-2, atol=1e-2)

        fa_o, fa_t, fa_m = test("flash_attention", q, k, v, causal=causal)
        print(f"flash attention time: {fa_t:.6f}, speedup: {t/fa_t:.2f}; peak memory: {fa_m} MB, save: {int((m-fa_m)/m*100)}%")
        assert torch.allclose(o, fa_o, rtol=1e-2, atol=1e-2)

        xf_o, xf_t, xf_m = test("xformers_attention", q, k, v, causal=causal)
        print(f"xformers time: {xf_t:.6f}, speedup: {t/xf_t:.2f}; peak memory: {xf_m} MB, save: {int((m-xf_m)/m*100)}%")
        assert torch.allclose(o, xf_o, rtol=1e-2, atol=1e-2)
I ran 10 rounds with random input shapes (batch_size, seq_len, num_head, head_dim), randomly enabling the causal mask and randomly picking fp16 or bf16. The results:
test 0 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000734, peak memory: 105 MB
pytorch func time: 0.000104, speedup: 7.06; peak memory: 49 MB, save: 53%
flash attention time: 0.000055, speedup: 13.45; peak memory: 43 MB, save: 59%
xformers time: 0.000152, speedup: 4.82; peak memory: 61 MB, save: 41%
test 1 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (57, 1235, 8, 64), dtype: torch.float16, causal: True
custom pytorch time: 0.015195, peak memory: 3093 MB
pytorch func time: 0.001930, speedup: 7.87; peak memory: 571 MB, save: 81%
flash attention time: 0.000635, speedup: 23.94; peak memory: 496 MB, save: 83%
xformers time: 0.001383, speedup: 10.99; peak memory: 696 MB, save: 77%
test 2 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (61, 2045, 16, 128), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.101898, peak memory: 18782 MB
pytorch func time: 0.031511, speedup: 3.23; peak memory: 4115 MB, save: 78%
flash attention time: 0.005292, speedup: 19.25; peak memory: 3560 MB, save: 81%
xformers time: 0.009730, speedup: 10.47; peak memory: 3972 MB, save: 78%
test 3 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (15, 1526, 12, 64), dtype: torch.float16, causal: True
custom pytorch time: 0.010720, peak memory: 3756 MB
pytorch func time: 0.001101, speedup: 9.74; peak memory: 1732 MB, save: 53%
flash attention time: 0.000380, speedup: 28.24; peak memory: 1211 MB, save: 67%
xformers time: 0.000862, speedup: 12.43; peak memory: 824 MB, save: 78%
test 4 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (28, 3227, 12, 128), dtype: torch.float16, causal: True
custom pytorch time: 0.091987, peak memory: 15090 MB
pytorch func time: 0.029867, speedup: 3.08; peak memory: 2223 MB, save: 85%
flash attention time: 0.004636, speedup: 19.84; peak memory: 1924 MB, save: 87%
xformers time: 0.008405, speedup: 10.94; peak memory: 2151 MB, save: 85%
test 5 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (37, 2047, 8, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.026797, peak memory: 6242 MB
pytorch func time: 0.003424, speedup: 7.83; peak memory: 1388 MB, save: 77%
flash attention time: 0.000947, speedup: 28.29; peak memory: 1049 MB, save: 83%
xformers time: 0.002072, speedup: 12.93; peak memory: 1006 MB, save: 83%
test 6 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (24, 2637, 16, 128), dtype: torch.bfloat16, causal: False
custom pytorch time: 0.053066, peak memory: 11970 MB
pytorch func time: 0.047200, speedup: 1.12; peak memory: 2205 MB, save: 81%
flash attention time: 0.006308, speedup: 8.41; peak memory: 1885 MB, save: 84%
xformers time: 0.011971, speedup: 4.43; peak memory: 2055 MB, save: 82%
test 7 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (37, 3214, 12, 64), dtype: torch.float16, causal: True
custom pytorch time: 0.097363, peak memory: 19552 MB
pytorch func time: 0.012316, speedup: 7.91; peak memory: 2142 MB, save: 89%
flash attention time: 0.003399, speedup: 28.65; peak memory: 1720 MB, save: 91%
xformers time: 0.007016, speedup: 13.88; peak memory: 1995 MB, save: 89%
test 8 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (40, 2126, 16, 64), dtype: torch.float16, causal: True
custom pytorch time: 0.066542, peak memory: 12737 MB
pytorch func time: 0.006069, speedup: 10.96; peak memory: 1856 MB, save: 85%
flash attention time: 0.002226, speedup: 29.89; peak memory: 1516 MB, save: 88%
xformers time: 0.004234, speedup: 15.71; peak memory: 1840 MB, save: 85%
test 9 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (47, 3355, 12, 64), dtype: torch.bfloat16, causal: False
custom pytorch time: 0.100385, peak memory: 26267 MB
pytorch func time: 0.024839, speedup: 4.04; peak memory: 2353 MB, save: 91%
flash attention time: 0.008755, speedup: 11.47; peak memory: 1956 MB, save: 92%
xformers time: 0.016346, speedup: 6.14; peak memory: 2483 MB, save: 90%
As you can see, in most cases the ranking for both speed and memory is "flash attention 2 > xformers > PyTorch function > hand-written PyTorch".
All of these APIs are also very easy to use and can basically serve as drop-in replacements for the attention module in your own model. However, flash attention 2 does not appear to support passing an arbitrary attention mask; you can only enable a causal mask, which is somewhat limiting, though it is enough for GPT-style models.
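As an illustration of how drop-in the replacement can be, here is a minimal sketch (my own example, not taken from any library) of a GPT-style causal self-attention module built on flash_attn_func; the module name and projection layout are assumptions, and the inputs must be fp16/bf16 on a supported GPU:

import torch
import torch.nn as nn
from flash_attn import flash_attn_func

class FlashCausalSelfAttention(nn.Module):
    """Hypothetical GPT-style attention block using flash attention 2."""
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        # x: (batch, seq_len, hidden_size), fp16 or bf16, on an A/H-series GPU.
        bsz, seq_len, _ = x.shape
        qkv = self.qkv_proj(x).view(bsz, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)  # each: (batch, seq_len, num_heads, head_dim)
        o = flash_attn_func(q, k, v, causal=True)  # causal-only mask
        return self.out_proj(o.reshape(bsz, seq_len, -1))

If you do need an arbitrary attention mask rather than a plain causal one, xformers' memory_efficient_attention accepts an additive attn_bias tensor, so it may be the easier swap in that case.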
- END -
I'm godweiyang, an NLP algorithm engineer at ByteDance AI Lab. I ranked first in my computer science program, both bachelor's and master's, at 华师 (East China Normal University), and I work on algorithms, model optimization, and machine translation.