InterLM代码解析

interLM的Transformer架构,重要模块的实现解析

Decoder架构


class InternLMDecoderLayer(nn.Module):
    def __init__(self, config: InternLMXComposerConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        if hasattr(config,
                   'intern_converted_llm') and config.intern_converted_llm:
            self.self_attn = InternConvertedInternLMAttention(config=config)
        else:
            self.self_attn = InternLMAttention(config=config)
        self.mlp = InternLMMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            config=config,
        )
        self.input_layernorm = InternLMRMSNorm(config.hidden_size,
                                               eps=config.rms_norm_eps)
        self.post_attention_layernorm = InternLMRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
                                                 torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, )

        if output_attentions:
            outputs += (self_attn_weights, )

        if use_cache:
            outputs += (present_key_value, )

        return outputs

MLP

  • 两个MLP层+一个门控激活函数
class InternLMMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int,
                 hidden_act: str, config: InternLMXComposerConfig):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        if config.lora_cfg is not None and 'ffn' in config.lora_cfg[
                'learn_param']:
            lora_cfg = config.lora_cfg
            self.down_proj = LoRALinear(intermediate_size,
                                        hidden_size,
                                        bias=False,
                                        **lora_cfg)
            self.up_proj = LoRALinear(hidden_size,
                                      intermediate_size,
                                      bias=False,
                                      **lora_cfg)
        else:
            self.down_proj = nn.Linear(intermediate_size,
                                       hidden_size,
                                       bias=False)
            self.up_proj = nn.Linear(hidden_size,
                                     intermediate_size,
                                     bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

CausalAttention Mask

# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size,
                      dtype: torch.dtype,
                      device: torch.device,
                      past_key_values_length: int = 0):
    """
    Make causal mask used for bi-directional self-attention.
    """
    # 获取输入的形状,包括批量大小和目标长度
    bsz, tgt_len = input_ids_shape
    # 初始化一个形状为(目标长度, 目标长度)的tensor,用极小值填充. 即mask矩阵
    mask = torch.full((tgt_len, tgt_len),
                      torch.tensor(torch.finfo(dtype).min, device=device),
                      device=device)
    # 创建一个mask_cond张量,其范围是[0, tgt_len-1]
    mask_cond = torch.arange(mask.size(-1), device=device)
    # 根据条件进行填充,下三角为0,上三角为1
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    # 转换mask的数据类型为dtype
    mask = mask.to(dtype)

    # 如果过去键值的长度大于0,则将其拼接到mask的前面
    if past_key_values_length > 0:
        mask = torch.cat([
            torch.zeros(
                tgt_len, past_key_values_length, dtype=dtype, device=device),
            mask
        ],
                         dim=-1)
    # 返回形状为[bsz, 1, tgt_len, tgt_len + past_key_values_length]的mask
    return mask[None, None, :, :].expand(bsz, 1, tgt_len,
                                         tgt_len + past_key_values_length)

past_key_values_length

在Transformer中,past_key_values_length是指用于存储前一次计算的注意力键值对(key-value pairs)的长度。Transformer模型在处理较长的序列时,为了提高效率会使用存储,以避免重复计算。

  • 当输入序列长度增加时,前一次的键值对会被缓存以供后续的注意力计算使用。这样可以节省计算时间,特别是在生成式任务中,如机器翻译或文本生成。

  • 为什么用zeros?
    如果past_key_values_length大于0,即存在过去的键值对需要存储,我们需要将这些过去的键值对所对应的掩码(mask)拼接到当前的掩码中。

在这里,我们首先创建了一个与当前mask形状相同的全零张量,用于表示过去的掩码。然后,通过使用torch.cat函数将这个全零张量和当前的mask进行拼接,以便将过去的信息与当前的信息合并在一起,形成一个更大的掩码张量。

详细解释一下如何创建CasualMask矩阵

当调用masked_fill_函数时,我们传入了一个条件(mask_cond < (mask_cond + 1).view(mask.size(-1), 1))和一个填充值(0)。

这个条件 mask_cond < (mask_cond + 1).view(mask.size(-1), 1) 创建了一个下三角为True,上三角为False的条件掩码。

当我们执行 (mask_cond + 1).view(mask.size(-1), 1) 时,我们将 mask_cond 中的每个元素增加 1,并且重新塑造成一个列向量。假设 mask_cond 最初是一个长度为 4 的向量 [0, 1, 2, 3],那么在执行 +1view 操作后得到的列向量就是:

[1]
[2]
[3]
[4]

现在,我们比较 mask_cond(mask_cond+1).view(mask.size(-1), 1)。我们发现,如果 mask_cond 中的元素小于列向量中对应位置的元素,这意味着该位置处于下三角区域。例如,在这个例子中,当我们比较原始向量和列向量时:

[0, 1, 2, 3]   <   [1]
[1, 2, 3, 4]   <   [2]
[2, 3, 4, 5]   <   [3]
[3, 4, 5, 6]   <   [4]

这将生成一个下三角为 True,上三角为 False 的布尔掩码,可以用于创建Mask。

masked_fill_函数用条件掩码来填充张量。在这里,如果条件为True,对应位置将被填充为0。这样就实现了对角线以下的元素被填充为0,对角线以上的元素保持不变。


Attention Mask

def _expand_mask(mask: torch.Tensor,
                 dtype: torch.dtype,
                 tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    # 如果未提供目标序列长度,默认使用源序列的长度
    tgt_len = tgt_len if tgt_len is not None else src_len

    # 对输入的掩码进行扩展
    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    # 创建一个反转的掩码
    inverted_mask = 1.0 - expanded_mask

    # 使用反转的掩码来填充掩码张量中的元素
    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
  • 使用反转的掩码来填充掩码张量中的元素的目的是将掩码中原本为0的位置填充为负无穷小。

在注意力计算中,当掩码中某个位置的元素为负无穷小时,经过softmax计算后,该位置对应的注意力权重会趋近于0,即忽略该位置的信息。这样做的目的是,在计算注意力时,我们希望掩码的位置能够有效地抑制相关位置的注意力权重,从而确保模型在处理序列时不会受到未来信息的影响,比如在解码阶段不会看到未来时刻的标记。因此,使用反转的掩码来填充掩码张量中的元素是为了在注意力计算中实现对未来信息的屏蔽。


RoPE

class InternLMRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        # 计算频率,根据RoPE公式 1.0 / (base **(2 * (i // 2) / dim))
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)  # 将频率注册为缓冲张量

        # 构建sin和cos缓存
        self.max_seq_len_cached = max_position_embeddings
        # t是位置索引
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  # 通过张量乘法计算频率
        emb = torch.cat((freqs, freqs), dim=-1)  # 按照最后一个维度拼接sin和cos
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)  # 将cos缓存注册为缓冲张量
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)  # 将sin缓存注册为缓冲张量

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # 这个if块不太可能在构建sin/cos后运行。保持逻辑在这里以防万一。
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)  # 通过张量乘法计算频率
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)  # 按照最后一个维度拼接sin和cos
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)  # 更新注册cos缓存
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)  # 更新注册sin缓存
        # 返回缓存中的sin和cos张量,截取到指定的序列长度
        return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype))
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    # 将输入张量沿最后一个维度分成两部分,执行旋转操作
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    # 拼接结果返回
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """
    Applies rotary positional embeddings to input queries and keys.

    Args:
    q: 输入的查询张量
    k: 输入的键张量
    cos: cos缓存张量
    sin: sin缓存张量
    position_ids: 位置编码张量

    Returns:
    q_embed: 应用了旋转位置嵌入后的查询张量
    k_embed: 应用了旋转位置嵌入后的键张量
    """
    # 根据position_ids创建索引张量
    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
    # 通过gather_indices选择对应的cos和sin张量
    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    # 应用旋转位置嵌入公式得到新的查询张量和键张量
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

torch.gather函数的参数包括:

  1. input:这是输入张量,从这个张量中收集值。
  2. dim:这是一个整数值,表示在input张量中收集数据的维度。
  3. index:这是包含了索引的张量。根据这些索引,函数将从input张量中收集对应的值。

基本语法为:torch.gather(input, dim, index)

cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)

可以分解为以下几个步骤:

  1. cos.repeat(gather_indices.shape[0], 1, 1, 1): 这一步是将cos张量沿着每个维度进行复制以匹配gather_indices的形状。repeat函数会根据指定的次数沿着各个维度对原始张量进行复制。在这里,它会根据gather_indices.shape[0]的值在第一个维度上进行复制,而不在其他维度进行复制。

  2. torch.gather(repeated_cos, 2, gather_indices): 紧接着,我们使用torch.gather函数根据gather_indices中指定的索引从repeated_cos中收集对应的值。对于序列中的每个位置,gather_indices指定了从repeated_cos张量中选择哪个值。

torch.gather操作主要用于根据索引张量从源张量中收集对应的值。通过上述操作,我们能够根据gather_indices为序列中的每个位置选择正确的cos值,并将其应用于后续的计算中。这是PyTorch中的常见技术,用于根据索引张量从张量中提取值。

LoRA

  • 有意思的是,对LoRA做了改动
  • 有点残差连接和RoPE的思想糅合到一起的操作
    • x += res
    • 中间断开,奇偶分开

class ConvertedLoRALinear(nn.Linear):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 device=None,
                 dtype=None,
                 lora_r=8,
                 lora_alpha=16,
                 lora_dropout=0.05,
                 **kwargs) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        self.lora_scaling = self.lora_alpha / self.lora_r

        self.lora_A = nn.Linear(in_features,
                                self.lora_r,
                                bias=False,
                                device=device,
                                dtype=dtype)
        self.lora_B = nn.Linear(self.lora_r,
                                out_features,
                                bias=False,
                                device=device,
                                dtype=dtype)

        self.reset_parameters()

    def reset_parameters(self):
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B.weight)
            # print ("lora weight init {} {}".format(torch.mean(self.lora_A.weight), torch.mean(self.lora_B.weight)))

    def forward(self, x):
        orig_type = x.dtype
        res = super().forward(x)

        dim = int(res.shape[-1] // 2)

        r1 = res[..., :dim]
        r2 = res[..., dim:]

        r1 = r1.float()
        r2 = r2.float()
        x_ = x.float()

        tmp = self.lora_B(self.lora_A(
            self.lora_dropout(x_))) * self.lora_scaling
        tmp1 = tmp[..., ::2]
        tmp2 = tmp[..., 1::2]

        r1 += tmp1
        r2 += tmp2

        r1 = r1.to(orig_type)
        r2 = r2.to(orig_type)

        res = torch.cat([r1, r2], -1)

        # res += self.lora_B(self.lora_A(
        #     self.lora_dropout(x))) * self.lora_scaling
        return res

关于生成是模型的Loss计算

outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            query_embeds=query_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens

            loss_fct = CrossEntropyLoss(reduce=False)
            loss_reduce = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            ###
            if self.sp_id >= 0:
                ori_mask = (shift_labels != self.sp_id).float()
                ori_mask = ori_mask * (shift_labels >= 0).float()
                local_mask = (shift_labels == self.sp_id).float()
            else:
                ori_mask = (shift_labels <
                            self.config.vocab_size - self.ex_size).float()
                ori_mask = ori_mask * (shift_labels >= 0).float()
                local_mask = (shift_labels >=
                              self.config.vocab_size - self.ex_size).float()

            # Enable model parallelism

            loss = loss_reduce(shift_logits, shift_labels)

            loss_all = loss_fct(shift_logits, shift_labels)
            loss_o = (loss_all * ori_mask).sum() / ori_mask.sum()
            if torch.sum(local_mask) == 0:
                loss_l = loss_o * 0
            else:
                loss_l = (loss_all * local_mask).sum() / local_mask.sum()

代码中loss计算的逐步解释:

1. 首先检查是否有标签(labels),如果有则继续计算loss,否则将loss保持为None。

2. 在标签存在的情况下,对logits进行了一个向左的位移,这是因为模型中的输入数据和输出标签之间需要进行一定的位移。即把logits中的每个位置的预测,对应到相应位置期待的标签。

3. 之后对logits和labels进行view操作,将其形状改变为2D的张量,以便进行交叉熵损失的计算。

4. 根据self.sp_id的不同取值,计算了ori_mask和local_mask。ori_mask为了确保不计算特殊token(sp_id)的loss,local_mask则是用于计算特殊token(sp_id)的loss。

5. 调用`CrossEntropyLoss`设置了两个不同的loss,loss_reduce用于在整个批次上计算损失,loss_fct则是用于对每个位置的损失值进行计算。

6. 最后,计算了不同的部分的损失。loss_o计算了非特殊token的损失,而loss_l计算了特殊token的损失。如果local_mask全为0,则loss_l为0.

总结:该段代码进行了交叉熵损失的计算,但根据输入token是否为特殊token(sp_id),它分别计算了不同的loss值,即ori_mask用于过滤掉特殊token本身的loss,local_mask用于计算特殊token的loss。

  • 这个loss的计算实际上是基于给定的vocabulary的多分类交叉熵损失。

在语言模型中,通常需要将模型的输出与词汇表中的token进行比较,以根据模型的预测计算损失。因此,将模型输出的logits与标签进行比较,并计算交叉熵损失,这通常用于语言模型中的训练过程。

你可能感兴趣的:(人工智能,深度学习)