Transformer的linear和softmax

线性层(Linear Layer)

场景

假设我们现在有一个包含许多特征的向量,比如描述一本书的内容、风格、作者、逻辑等信息。你想要根据这些特征预测这本书属于哪个类别,如小说、科幻、历史等。线性层的作用就是帮助你将这些特征转换成一个更简单的形式,使得你可以更容易地做出分类决策。

解释
  • 特征组合:线性层接收来自解码器最后一层的输出,这个输出是一个高维向量,包含了关于输入序列的丰富信息。

  • 权重矩阵:线性层内部有一组可学习的权重,它会乘以输入向量,并加上一个偏置项。这就像你在计算每个特征的重要性,然后给它们打分。

  • 简化输出:通过这种方式,线性层可以将高维向量压缩成一个较低维度的向量,通常与目标词汇表大小相同。这样,每个位置上的值就代表了对应词汇的可能性得分。

class LinearLayer(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearLayer, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.fc(x)

# 使用线性层将解码器的输出映射到词汇表大小
linear_layer = LinearLayer(d_model, trg_vocab_size)
logits = linear_layer(final_decoder_output)
类比理解

线性层就像是一个智能评分系统,它根据一系列特征来评估每个可能的结果(在这里是词汇)。通过调整权重,线性层能够学习哪些特征对最终结果更重要,从而做出更准确的预测。

  • 评分规则:权重矩阵就像是评分规则,决定了每个特征的重要性。比如,某个特征可能对预测动词特别有用,那么它的权重就会比较高。

  • 基础分:偏置向量则像是基础分,即使没有其他特征的支持,某些词汇也有一定的初始得分。

  • 综合评分:最后,线性层通过综合考虑所有特征及其对应的权重,给出每个词汇的得分。这些得分反映了每个词汇被选中的可能性。

线性层通过以下公式将输入特征转换成每个词汇的得分:

其中:

  • XX 是输入特征向量(解码器输出)。
  • WW 是权重矩阵。
  • bb 是偏置向量。

具体来说,线性层会做两件事:

  • 特征加权:根据权重矩阵 WW,线性层会评估每个输入特征的重要性。如果某个特征对于预测某个特定词汇非常重要,那么相应的权重就会较高;反之,则较低。
  • 得分汇总:然后,线性层会将所有加权后的特征相加,并加上偏置项 bb,得到每个词汇的原始得分(logits)。这些得分反映了每个词汇被选中的可能性。

Softmax 函数

场景设定

现在你有了一个评分列表,但是这些分数并不直观,因为你不知道哪个分数是最高的,也不知道它们之间的相对概率。Softmax函数的作用就是将这些分数转换成一个概率分布,使得每个分数都变成了一个介于0和1之间的概率值,并且所有概率加起来等于1。

解释
  • 概率转换:Softmax函数接收线性层输出的原始得分(logits),并将其转换为概率值。这使得我们可以明确知道每个词汇被选中的可能性。

  • 指数化:首先,Softmax会对每个得分进行指数化处理,确保所有的值都是正数,并放大差距较大的得分之间的差异。

  • 归一化:然后,它会将所有指数化的得分相加,得到一个总和。最后,每个指数化的得分除以这个总和,得到的概率值就会落在0到1之间,并且所有概率加起来等于1。

import torch.nn.functional as F

# 使用Softmax函数将线性层的输出转换为概率分布
probabilities = F.softmax(logits, dim=-1)
类比理解

Softmax函数就像是一个投票系统,它根据线性层给出的得分来决定每个候选词的当选概率。通过这种方式,我们不仅知道了哪个词最有可能被选择,还了解了其他词的选择概率。

总结

通过线性和Softmax这两个步骤,Transformer模型能够有效地将解码器的复杂输出转换成一个直观的概率分布,从而为生成新句子提供依据:

  1. 线性层:将解码器的高维输出简化为一个与词汇表大小相同的低维向量,每个位置上的值代表对应词汇的可能性得分。
  2. Softmax函数:将这些得分转换成概率值,使得我们可以明确知道每个词汇被选中的可能性。

这样一来,模型不仅能确定最有可能的下一个词,还能理解其他候选词的选择概率,从而更加灵活和智能地生成文本。

你可能感兴趣的:(AI,transformer,机器学习,人工智能)