目录
前言
第一部分 相比第一代的改进点:FlashAttention与Multi-Query Attention
第二部分 FlashAttention:减少内存访问提升计算速度——更长上下文的关键
2.1 FlashAttention相关的背景知识
2.1.1 Transformer计算复杂度:编辑——Self-Attention层与MLP层
2.1.1.1 Self-Attention层的计算复杂度:
2.1.1.2 MLP层的计算复杂度:
2.1.2 Transformer的空间复杂度:编辑——Self-Attention层与MLP层
2.1.2.1 Self-Attention块的中间激活:
2.1.2.2 MLP块的中间激活:
2.1.2.3 两个layer norm需要保存的中间激活:
2.1.3 分析GPU的内存分析图:计算的瓶颈是显存访问
2.1.4 safe softmax
2.2 前向传递:Standard Attention/Memory-efficient Attention/Flash Attention
2.2.1 Standard Attention
2.2.2 Memory-efficient Attention:把显存复杂度从平方降低到线性,但HBM访问次数仍是平方
2.2.3 Flash Attention:避免频繁地从HBM中读写数据
第三部分 多查询注意力(Muti Query Attention):各自Query矩阵,但共享Key 和 Value 矩阵
3.1 Multi-Head Attention、Grouped-Query Attention、Muti Query Attention的区别
3.2 MHA 和 MQA在代码实现上的差异
第四部分 模型的使用/部署、微调
4.1 模型的使用/部署
4.2 基于 P-Tuning v2 的微调(官方
本文最初和第一代ChatGLM-6B的内容汇总在一块,但为了阐述清楚FlashAttention、Multi-Query Attention等相关的原理,以及GLM2的微调、源码解读等内容,导致之前那篇文章越写越长,故特把ChatGLM2相关的内容独立抽取出来成本文
且本文会和本博客内其他大模型相关的文章一样,极其注重可读性,比如为了不断提高可读性,本文近期会不断反复修改,细抠标题的层级、措辞,甚至排版、标点符号,如果不通俗易懂,宁愿不写
ChatGLM2-6B(GitHub项目地址、HuggingFace地址)是开源中英双语对话模型 ChatGLM-6B 的第二代版本,相比第一代,第二点引入了如下新特性:
context_layer 这个函数实现了attention机制的计算,入参 is_causal=True 表示遮后看前的mask(这种类型的注意力通常用在transformer的decoder部分,以确保当前位置只能关注到之前的位置,俗称“看不见未来”,从而使模型可以进行自回归预测 )
FlashAttention是斯坦福联合纽约州立大学在22年6月份提出的一种具有 IO 感知,且兼具快速、内存高效的新型注意力算法「对应论文为:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness,这是其GitHub地址,这是其解读之一,该解读也是本第二部分的重要参考之一」
它要解决一个什么样的问题呢?
首先,GPT3、LLaMA、ChatGLM、BLOOM等大语言模型输入输出的最大序列长度只有2048或4096,扩展到更长序列的难度在哪里呢?本质原因是,transformer模型的计算复杂度和空间复杂度都是 的,其中
为序列长度
当输入批次大小为 ,序列长度为
时,
层transformer模型的计算量为
,
是隐藏层维度通常等于词向量维度,可能不少同学都会疑问这个计算量是怎么一步一步计算得来的,下面详细拆解下这个计算过程
首先,我们知道,transformer模型由 个相同的层组成,每个层分为两部分:self-attention块和MLP块
self-attention层的模型参数有两部分,一部分是、
、
的权重矩阵
、
、
和偏置,另一部分是输出权重矩阵
和偏置,最终为
具体怎么计算得来的呢?
MLP块由2个线性层组成,最终是
怎么计算得来的呢?
- 一般地,第一个线性层是,第二个线性层再将维度从
映射到
第一个线性层的权重矩阵 的形状为
,相当于先将维度从
映射到
,矩阵乘法的输入和输出形状为
,计算量为
第二个线性层的权重矩阵 的形状为
,相当于再将维度从
映射到
,矩阵乘法的输入和输出形状为
,计算量为
- 将上述所有表粗所示的计算量相加,得到每个transformer层的计算量大约为
- 此外,另一个计算量的大头是logits的计算(毕竟词嵌入矩阵的参数量也较多),将隐藏向量映射为词表大小,说白了,词向量维度通常等于隐藏层维度
,词嵌入矩阵的参数量为
,最后的输出层的权重矩阵通常与词嵌入矩阵是参数共享的「解释一下,如七月杜老师所说,这个是transformer中一个重要的点,参数共享可以减小参数量,词嵌入矩阵是[vocab_size,hidden_size],输出层矩阵是 [hidden_size,vocab_size],是可以共享的」
其矩阵乘法的输入和输出形状为,计算量为
- 因此,对于一个
层的transformer模型,输入数据形状为
的情况下,一次训练迭代的计算量为
中间激活的显存大小为 ,其中
为注意力头数
大模型在训练过程中通常采用混合精度训练,中间激活值一般是float16或者bfloat16数据类型的。在分析中间激活的显存占用时,假设中间激活值是以float16或bfloat16数据格式来保存的,每个元素占了2个bytes。唯一例外的是,dropout操作的mask矩阵,每个元素只占1个bytes。在下面的分析中,单位是bytes,而不是元素个数。
每个transformer层包含了一个self-attention块和MLP块,并分别对应了一个layer normalization连接。
self-attention块的计算公式如下:
最终,self-attention块的中间激活占用显存大小为:
具体怎么计算得来的呢?
- 对于
,需要保存它们共同的输入
,这就是中间激活。输入
的形状为
,元素个数为
,占用显存大小为
- 对于
矩阵乘法,需要保存中间激活
,两个张量的形状都是
,占用显存大小合计为
- 对于
函数,需要保存函数的输入
,占用显存大小为
,这里的
表示注意力头数
其中的形状为:
的形状为:
的形状为:
,元素个数为
,占用显存大小为
- 计算完
函数后,会进行dropout操作。需要保存一个mask矩阵,mask矩阵的形状与
相同,占用显存大小为
- 计算在
上的attention,即
,需要保存
,大小为
;以及
,大小为
,二者占用显存大小合计为
- 计算输出映射以及一个dropout操作。输入映射需要保存其输入,大小为
;dropout需要保存mask矩阵,大小为
,二者占用显存大小合计为
因此,将上述中间激活相加得到,self-attention块的中间激活占用显存大小为
MLP块的计算公式如下:,最终对于MLP块,需要保存的中间激活值为
具体怎么计算得来的呢?
- 第一个线性层需要保存其输入,占用显存大小为
- 激活函数需要保存其输入,占用显存大小为
- 第二个线性层需要保存其输入,占用显存大小为
- 最后有一个dropout操作,需要保存mask矩阵,占用显存大小为
另外,self-attention块和MLP块分别对应了一个layer normalization。每个layer norm需要保存其输入,大小为,2个layer norm需要保存的中间激活为
综上,每个transformer层需要保存的中间激活占用显存大小为
对于
层transformer模型,还有embedding层、最后的输出层。embedding层不需要中间激活。总的而言,当隐藏维度
比较大,层数
较深时,这部分的中间激活是很少的,可以忽略
因此,对于
层transformer模型,中间激活占用的显存大小可以近似为
「更多分析见此文《分析transformer模型的参数量、计算量、中间激活、KV cache》」
通过上面两小节的内容,可以看到,transformer模型的计算量和储存复杂度随着序列长度 呈二次方增长。这限制了大语言模型的最大序列长度
的大小
其次,GPT4将最大序列长度 扩大到了32K,Claude更是将最大序列长度
扩大到了100K,这些工作一定采用了一些优化方法来降低原生transformer的复杂度,那具体怎么优化呢?
我们知道,每个transformer层分为两部分:self-attention块和MLP块,但上面计算量中的 项和中间激活中的
项都是self-attention块产生的,与MLP块无关
如此,FlashAttention提出了一种加速计算、节省显存和IO感知的精确注意力,可以有效地缓解上述问题
Meta推出的开源大模型LLaMA,阿联酋推出的开源大模型Falcon都使用了Flash Attention来加速计算和节省显存。目前,Flash Attention已经集成到了pytorch2.0中,另外triton、xformer等开源框架也进行了整合实现
通过上文可知,transformer的核心组件self-attention块的计算复杂度和空间复杂度是序列长度 的二次方
GPU的内存由多个不同大小和不同读写速度的内存组成。内存越小,读写速度越快。对于A100-40GB来说,内存分级图如下所示
所以,上面讲到计算注意力的主要瓶颈是显存访问,因此减少对HBM的读写次数,有效利用更高速的SRAM来进行计算是非常重要的,而GPU有大量的线程来执行某个操作,称为kernel。GPU执行操作的典型方式分为三步:
而对于性能受限于内存带宽的操作,进行加速的常用方式就是kernel融合。kernel融合的基本思想是:避免反复执行“从HBM中读取输入数据,SRAM执行计算,最后将计算结果写入到HBM中”,将多个操作融合成一个操作,减少读写HBM的次数(需要注意的是,模型训练通常会影响到算子融合的效果,因为为了后向传递计算梯度,通常需要将某些中间结果写入到HBM中)
继续行文之前,先补充两个背景知识,一个是safe softmax,一个是Standard Attention
对于第一个背景知识:safe softmax而言
在注意力计算过程中,节省显存的主要挑战是softmax与的列是耦合的。其方法是单独计算softmax的归一化因子,来实现解耦
如此,节省内存(memory-efficient)的注意力机制,改变了计算顺序,相比于Standard Attention,节省显存的注意力机制将显存复杂度从 降低到了
这种方法在《Online normalizer calculation for softmax》和《Self-attention Does Not Need Memory》中已经使用过,称其为“lazy softmax”,这种方法避免了实例化完整的注意力矩阵
,从而达到了节省显存的目的。然而HBM访问次数仍然是
的,因此运行时间并没有减少
// 待更..
多查询注意力(Muti Query Attention)是 19 年Google一研究者提出的一种新的 Attention 机制(对应论文为:Fast Transformer Decoding: One Write-Head is All You Need、这是其解读之一),其能够在保证模型效果的同时加快 decoder 生成 token 的速度
那其与17年 Google提出的transformer中多头注意力机制(简称MHA)有啥本质区别呢?有意思的是,区别在于:
下图对比了多头注意力(Multi-Head Attention)、LLaMA2中分组查询注意力(Grouped-Query Attention)、多查询注意力(Muti Query Attention)的差别
总之,MHA 和 MQA 之间的区别只在于建立 Wqkv Layer 上
# Multi Head Attention
self.Wqkv = nn.Linear( # 【关键】Multi-Head Attention 的创建方法
self.d_model,
3 * self.d_model, # 有 query, key, value 3 个矩阵, 所以是 3 * d_model
device=device
)
query, key, value = qkv.chunk( # 【关键】每个 tensor 都是 (1, 512, 768)
3,
dim=2
)
# Multi Query Attention
self.Wqkv = nn.Linear( # 【关键】Multi-Query Attention 的创建方法
d_model,
d_model + 2 * self.head_dim, # 只创建 query 的 head 向量,所以只有 1 个 d_model
device=device, # 而 key 和 value 不再具备单独的头向量
)
query, key, value = qkv.split( # query -> (1, 512, 768)
[self.d_model, self.head_dim, self.head_dim], # key -> (1, 512, 96)
dim=2 # value -> (1, 512, 96)
)
对比上面的代码,你可以发现
因此,可以确认:在 MQA 中,除了 query 向量还保存着 8 个头,key 和 value 向量都只剩 1 个「公共头」了,这也正好印证了论文中所说的「所有 head 之间共享一份 key 和 value 的参数」
剩下的问题就是如何将这 1 份参数同时让 8 个头都使用,代码里使用矩阵乘法 matmul 来广播,使得每个头都乘以这同一个 tensor,以此来实现参数共享:
def scaled_multihead_dot_product_attention(
query,
key,
value,
n_heads,
multiquery=False,
):
q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads) # (1, 512, 768) -> (1, 8, 512, 96)
kv_n_heads = 1 if multiquery else n_heads
k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads) # (1, 512, 768) -> (1, 8, 96, 512) if not multiquery
# (1, 512, 96) -> (1, 1, 96, 512) if multiquery
v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads) # (1, 512, 768) -> (1, 8, 512, 96) if not multiquery
# (1, 512, 96) -> (1, 1, 512, 96) if multiquery
attn_weight = q.matmul(k) * softmax_scale # (1, 8, 512, 512)
attn_weight = torch.softmax(attn_weight, dim=-1) # (1, 8, 512, 512)
out = attn_weight.matmul(v) # (1, 8, 512, 512) * (1, 1, 512, 96) = (1, 8, 512, 96)
out = rearrange(out, 'b h s d -> b s (h d)') # (1, 512, 768)
return out, attn_weight, past_key_value
git clone https://github.com/THUDM/ChatGLM2-6B
cd ChatGLM2-6B
pip install -r requirements.txt
4.30.2
,torch
推荐使用 2.0 及以上的版本,以获得最佳的推理性能>>> from transformers import AutoTokenizer, AutoModel
>>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
>>> model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True, device='cuda')
>>> model = model.eval()
>>> response, history = model.chat(tokenizer, "你好", history=[])
>>> print(response)
>>> response, history = model.chat(tokenizer, "晚上睡不着应该怎么办", history=history)
>>> print(response)
从本地加载模型
以上代码会由 transformers 自动下载模型实现和参数
完整的模型实现在 Hugging Face Hub。如果你的网络环境较差,下载模型参数可能会花费较长时间甚至失败。此时可以先将模型下载到本地,然后从本地加载。
从 Hugging Face Hub 下载模型需要先安装Git LFS,然后运行
git clone https://huggingface.co/THUDM/chatglm2-6b
如果你从 Hugging Face Hub 上下载 checkpoint 的速度较慢,可以只下载模型实现
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/THUDM/chatglm2-6b
然后从这里手动下载模型参数文件,并将下载的文件替换到本地的
chatglm2-6b
目录下将模型下载到本地之后,将以上代码中的
THUDM/chatglm2-6b
替换为你本地的chatglm2-6b
文件夹的路径,即可从本地加载模型。模型的实现仍然处在变动中。如果希望固定使用的模型实现以保证兼容性,可以在
from_pretrained
的调用中增加revision="v1.0"
参数。v1.0
是当前最新的版本号,完整的版本列表参见 Change Log
最后,可以通过以下命令启动基于 Gradio 的网页版 demo:
python web_demo.py
P-Tuning v2 将需要微调的参数量减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,最低只需要 7GB 显存即可运行(当然,我司杜老师也会在七月类ChatGPT微调实战课上录一个ChatGLM2-6B的微调视频)
{ “content”: “类型#上衣版型#宽松版型#显瘦图案#线条衣样式#衬衫衣袖型#泡泡袖衣款式#抽绳”, “summary”:
“这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。”
}
bash train.sh
微调过程显存使用情况如下:
微调完成后,在./output/adgen-chatglm2-6b-pt-128-2e-2 下回生成微调好的模型文件。
我们可以对比下微调前后的效果
以命令行 Demo为例,只需修改ptuning路径下web_demo.sh中的模型路径为/data/sim_chatgpt/chatglm2-6b,运行 web_demo.py即可:
bash web_demo.sh
Input:
类型#上衣材质#牛仔布颜色#白色风格#简约图案#刺绣衣样式#外套衣款式#破洞
Label:
简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。
Output[微调前]:
Output[微调后]:
// 待更