transformer之Attention机制及代码实现

目录

    • 异同点总结
    • 代码实现
      • Self-Attention
      • Cross-Attention
      • Gated Self-Attention
      • Generalized Query Attention
    • PagedAttention

Self-Attention:一种Attention机制,用于处理单个输入序列中的依赖关系。
Cross-Attention:一种Attention机制,用于处理两个或多个输入序列之间的依赖关系。
Gated Self-Attention:一种改进的Self-Attention机制,引入了门控机制来控制Attention输出。
Generalized Query Attention:一种扩展的Self-Attention机制,支持多个Query和多个Key-Value对。
这些Attention机制都可以用于自然语言处理、计算机视觉等领域,用于捕获输入数据中的依赖关系和语义信息。

以下是Attention机制的异同点表格,输出为Markdown格式:

Attention机制 Self-Attention Cross-Attention Gated Self-Attention Generalized Query Attention
输入 单个输入序列 两个或多个输入序列 单个输入序列 多个Query和多个Key-Value对
输出 Attention输出 Attention输出 Attention输出 Attention输出
依赖关系 单个输入序列中的依赖关系 两个或多个输入序列之间的依赖关系 单个输入序列中的依赖关系 多个Query和多个Key-Value对之间的依赖关系
门控机制
支持多个Query
支持多个Key-Value对

异同点总结

  • Self-Attention和Gated Self-Attention都用于处理单个输入序列中的依赖关系,但Gated Self-Attention引入了门控机制来控制Attention输出。
  • Cross-Attention用于处理两个或多个输入序列之间的依赖关系。
  • Generalized Query Attention支持多个Query和多个Key-Value对,用于处理更复杂的依赖关系。

代码实现

Self-Attention

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, hidden_size, attention_heads):
        super(SelfAttention, self).__init__()
        self.hidden_size = hidden_size
        self.attention_heads = attention_heads
        self.query_linear = nn.Linear(hidden_size, hidden_size)
        self.key_linear = nn.Linear(hidden_size, hidden_size)
        self.value_linear = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # x: [batch_size, sequence_length, hidden_size]
        batch_size, sequence_length, _ = x.size()

        # Linear transformations
        query = self.query_linear(x)
        key = self.key_linear(x)
        value = self.value_linear(x)

        # Attention weights
        attention_weights = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.hidden_size)
        attention_weights = F.softmax(attention_weights, dim=-1)

        # Attention output
        attention_output 

你可能感兴趣的:(NLP,AIGC,transformer,attention,LLM)