if is_flash_attn_available(): # 检查flashattention的可用性
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
FlashAttention是Tranformer模型中用于改进注意力机制的技术,主要目的是减少计算复杂度和内存占用。
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlamaConfig"
创建了名为logger的日志记录器对象,__name__用于保存模块的名称,确保每个模块都有自己的日志记录器。
_CONFIG_FOR_DOC前面带有下划线,因此可以看出其代表一个模块的内部变量。
def _get_unpad_data(padding_mask):
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
该模块的作用是padding_mask提取非填充的数据,分为以下几步:
torch.nonzero(padding_mask.flatten(), as_tuple=True)[0]
这样比较麻烦,不如直接返回二维数组再展平。
3. max_seqlen_in_batch获取了在seqlens_in_batch中的最大值并返回(即长度最长的那一个),然后 item()函数的作用是将一个元素的张量转换为python对应的标量,即一个数。
4. cu_seqlens计算累计长度并进行填充。cumsum()函数用于计算指定维度的累计和,(1,0)意味着只在左边添加一个元素,右边不添加。F.pad()是为张量进行填充的函数。这对于处理变长序列非常有用,因为即获得了每个序列的开始索引,容易确定起始和结束位置。
5. 最终返回的包括:非零元素的索引,左边填充过了的累计长度,最长序列的长度。
从而,达到了提取非填充数据的目的。
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
该模块用于生成因果掩码,通常用于双向自注意力机制。具体来说,该模块保证在计算注意力时,只能看到当前时间步之前的信息,而看不到未来的,来保持因果关系。
input_ids_shape: torch.Size:输入张量的形状,通常为(batch_size, target_length)。
dtype: torch.dtype:用于生成掩码的张量类型。
device: torch.device:指定设备是GPU还是CPU。
past_key_values_length: int = 0:过去的键值对长度,用于增量计算。