使用chatGPT编写的self-attention模块

import torch

# 计算两个向量的注意力分数
def attention_score(query, key):
  return torch.matmul(query, key.transpose(-2, -1))

# 计算注意力权重
def attention_weights(query, key, values):
  score = attention_score(query, key)
  weights = torch.softmax(score, dim=-1)
  return torch.matmul(weights, values)

# 构建自注意力层
class SelfAttention(torch.nn.Module):
  def __init__(self, input_size, hidden_size):
    super(SelfAttention, self).__init__()
    self.query = torch.nn.Linear(input_size, hidden_size)
    self.key = torch.nn.Linear(input_size, hidden_size)
    self.value = torch.nn.Linear(input_size, hidden_size)
    self.output = torch.nn.Linear(hidden_size, input_size)
  
  def forward(self, inputs):
    query = self.query(inputs)
    key = self.key(inputs)
    value = self.value(inputs)
    attention = attention_weights(query, key, value)
    return self.output(attention)

测试用例

# 定义测试用的输入
inputs = torch.randn(4, 5, 8)

# 创建一个自注意力层
attention = SelfAttention(8, 8)

# 计算自注意力层的输出
outputs = attention(inputs)

# 打印输出的形状
print(outputs.shape)  # 输出: (4, 5, 8)

在上面的示例代码中,输入张量的形状为 (4, 5, 8),其中 4、5、8 分别表示的含义如下:

4:批大小,即一次计算时处理的样本数量。在这个例子中,我们一次处理了 4 个样本。
5:序列长度,即每个样本中包含的元素数量。在这个例子中,每个样本由 5 个元素组成。
8:元素的维度,即每个元素的特征数量。在这个例子中,每个元素都由 8 个特征构成。
在实际应用中,这些维度的含义可能会有所不同。例如,在自然语言处理中,序列长度通常表示句子中的单词数量,元素维度表示单词的维度(例如词嵌入)。在图像处理中,序列长度可能表示图像中像素的数量,元素维度表示像素的通道数。总之,各个维度的含义取决于模型的具体应用场景。

你可能感兴趣的:(深度学习,chatgpt)