chatGLM2中的Multi Query Attention

目录

原理简介

代码实现和耗时比较

总结分析

          近期一直在玩大模型,对中文支持比较好的就是清华的chatGLM,目前chatGLM由v1升级到了chatGLM2。在gihub上介绍信息如下:

chatGLM2中的Multi Query Attention_第1张图片

 试用了一下,效果和速度确实有所提升。

chatGLM2中的Multi Query Attention_第2张图片

 这个得益于chatGLM2应用了许多优化的技术,介绍中有提到过的FlashAttention技术、Multi Query Attention(MQA)技术和int4量化等等。其中MQA技术是对Multi head  Attention(MHA)的一种优化实现,加快了技术速度的同时也保证了效果下降的不厉害。

原理简介

       MQA最早是出现在2019年谷歌的一篇论文Fast Transformer Decoding: One Write-Head is All You Need,之所以没有关注到,是因为之前很少做文本生成,解码序列长度也没有现阶段大模型的要求那么高。MQA的思想其实比较简单(如果对MHA比较熟悉的话),论文中给出的描述如下:

论文的意思是:MQA和MHA除了不同的attention head共享一份keys和values权重之外,其他的都是一样的。现有4个head的attention,每个head分别进行softmax(QK)V注意力计算,那么这样设置的MHA和MQA示意图如下所示:

chatGLM2中的Multi Query Attention_第3张图片

 chatGLM2中的Multi Query Attention_第4张图片

 可以看到MHQ和MQA的不同之处仅仅在于每个头共享相同的K、V权重而Q不同享。

模型效果论文对比如下:

chatGLM2中的Multi Query Attention_第5张图片

 推理速度上生成一个token时MHA和MQA的encoder分别耗时1.7us和1.5us,而decoder分别46us和3.8us,说明decoder上MQA比MHA快很多。另外在效果上MQA的PPL(越小越好)有所上升,BLEU(越大越好)有所下降,换句话说就是效果有所下降。

代码实现和耗时比较

参考了huggingface的transformers包中的bertselfattention源码实现了一版MHA和MQA,代码如下:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
import math
import torch.nn as nn
import torch
from tqdm import tqdm
import time
class MiltiHeadSelfAttention(nn.Module):
    def __init__(self, num_attention_heads, hidden_size):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(0.1)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self,hidden_states):
        mixed_query_layer = self.query(hidden_states)
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return context_layer


class MultiQuerySelfAttention(nn.Module):
    def __init__(self, num_attention_heads, hidden_size):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.attention_head_size)
        self.value = nn.Linear(hidden_size, self.attention_head_size)

        self.dropout = nn.Dropout(0.1)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self,hidden_states):
        # hidden_states (B, L, D)
        mixed_query_layer = self.query(hidden_states)
        # query_layer  (B, h, L, d)
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # 每个key、value head参数都是一样的,只计算一次
        key = self.key(hidden_states)
        #key_layer  (B, 1, L, d)
        key_layer = key.unsqueeze(1)
        value = self.value(hidden_states)
        # value_layer  (B, 1, L, d)
        value_layer = value.unsqueeze(1)

        # key_layer  (B, 1, d, L)
        key_layer = key_layer.transpose(-1, -2)
        #广播算法 (B, h, L, d) * (B, 1, d, L) => (B, h, L, d) * (B, h, d, L) = (B, h, L, L)
        attention_scores = torch.matmul(query_layer, key_layer)
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        #广播算法 (B, h, L, L) * (B, 1, L, d) =>(B, h, L, L) * (B, h, L, d)= (B, h, L, d)
        context_layer = torch.matmul(attention_probs, value_layer)
        #(B, h, L, d) => (B, L, h, d)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # (B,L, h*d) => (B,L,D)
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        # (B,L, h*d) => (B,L,D)
        context_layer = context_layer.view(new_context_layer_shape)
        return context_layer




if __name__ == '__main__':
    seed = 100
    num_attention_heads, hidden_size = 32, 4096
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    device = "cuda:0"

    embeddings = torch.randn(5, 128, hidden_size).to(device)

    multiquery = MultiQuerySelfAttention(num_attention_heads, hidden_size).to(device)
    print(multiquery)
    total = 0
    for name, param in multiquery.named_parameters():
        if len(param.size()) == 2:
            total += param.shape[0] * param.shape[1]
        else:
            total += param.shape[0]
    print(f"multiquery parameters {total}")
    count = 100
    start = time.time()
    for _ in tqdm(range(count),ncols=50):
        input = embeddings.clone()
        for _ in range(100):
            for i in range(24):
                ouput = multiquery(input)
            input = torch.cat([input,ouput[:,-1:,:]],dim=1)
    end = time.time()
    print(f"multiquery time total cost {round(end - start, 8)} mean cost {round((end - start) / count, 8)}")


    multihead = MiltiHeadSelfAttention(num_attention_heads, hidden_size).to(device)
    print(multihead)
    total = 0
    for name, param in multihead.named_parameters():
        if len(param.size()) == 2:
            total += param.shape[0] * param.shape[1]
        else:
            total += param.shape[0]
    print(f"multihead parameters {total}")
    count = 100
    start = time.time()
    for _ in tqdm(range(count) ,ncols=50):
        input = embeddings.clone()
        for _ in range(100):
            for i in range(24):
                ouput = multihead(input)
            input = torch.cat([input, ouput[:, -1:, :]], dim=1)
    end = time.time()
    print(f"multihead time total cost {round(end-start,8)} mean cost {round((end-start)/count,8)}")

实现中主要借助矩阵计算的broadcast机制(自动广播机制)并行计算、就不用自己来实现每个头单独计算然后进行cat操作,效率比较高。模拟chatGLM2的设置,hidden_size = 4096、num_heads =32,num_layers=24输入一个维度为(5,128,4096)的向量进行文本解码,生成100个token,耗时对比如下:

chatGLM2中的Multi Query Attention_第6张图片

 生成100个token时,MQA解码平均耗时2.7826秒,MHA解码平均耗时6.4796秒,简单来看MQA在decoder解码加速了一倍。从模型结构来看原始的MHA一层5034W参数,而MQA只有1783W参数,还是通过压缩参数量来实现显存占用的减少以及推理时间的减少。

总结分析

显存占用和推理耗时减小是显而易见的,因为参数量减少了。至于效果变化得很小,只能说多头attention机制中的多头其实并不是一定,之前的bert模型有人探索了改变head头数目,也会保持效果变化不大。在大模型这,可能只需要有不同的head采用不同的query向量,kv一样来保证每个头提取到不同的特征就够了。

什么时候使用MQA有效呢?

1、采用attention的模型,模型规模越大,那么收益就约明显。

2、decoder生成任务相比较encoder任务收益明显大很大,其实decoder生成任务的收益来源于每一次softmax(QK)V注意力计算微小耗时差异的累积,一次生成任务要生成许多个token,一个token需要经历模型结构层数次的softmax(QK)V注意力的计算。

参考文章

Fast Transformer Decoding: One Write-Head is All You Need

ChatGLM2-6B

 huggingface / transformers

你可能感兴趣的:(自然语言处理,人工智能,大模型)