Implementation logic of scaled_dot_product_attention

torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False) -> Tensor

Parameters:

  • query (Tensor) – Query tensor; shape (batch_size, ..., num_heads, target_len, embed_size)

  • key (Tensor) – Key tensor; shape (batch_size, ..., num_heads, source_len, embed_size)

  • value (Tensor) – Value tensor; shape (batch_size, ..., num_heads, source_len, value_embed_size)

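  • attn_mask (optional Tensor) – Attention mask; must be broadcastable to the shape of the attention weights, (batch_size, ..., target_len, source_len). A boolean mask marks with True the positions that take part in attention; a float mask is added to the attention scores.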

  • dropout_p (float) – Dropout probability; defaults to 0.0
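
The logic behind these parameters can be sketched in a few lines of plain PyTorch. The block below is a minimal, unfused approximation, not the library's actual kernel: the name `sdpa_reference` is hypothetical, dropout and `enable_gqa` are left out, and the default scale is assumed to be 1 / sqrt(embed_size) when `scale` is None.

```python
import math
import torch
import torch.nn.functional as F


def sdpa_reference(query, key, value, attn_mask=None, is_causal=False, scale=None):
    """Unfused sketch of scaled dot-product attention (hypothetical helper, no dropout)."""
    target_len, source_len = query.size(-2), key.size(-2)
    # Assumed default scale: 1 / sqrt(embed_size), taken from the last dim of query.
    scale_factor = 1.0 / math.sqrt(query.size(-1)) if scale is None else scale

    # Additive bias that broadcasts against the (..., target_len, source_len) scores.
    attn_bias = torch.zeros(target_len, source_len, dtype=query.dtype, device=query.device)
    if is_causal:
        # Lower-triangular mask: position i may only attend to positions <= i.
        keep = torch.ones(target_len, source_len, dtype=torch.bool, device=query.device).tril()
        attn_bias = attn_bias.masked_fill(~keep, float("-inf"))
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Boolean mask: True means "attend", False is turned into -inf.
            mask_bias = torch.zeros(attn_mask.shape, dtype=query.dtype, device=query.device)
            attn_bias = attn_bias + mask_bias.masked_fill(~attn_mask, float("-inf"))
        else:
            # Float mask: added to the scores as-is.
            attn_bias = attn_bias + attn_mask

    scores = (query @ key.transpose(-2, -1)) * scale_factor + attn_bias
    weights = torch.softmax(scores, dim=-1)
    return weights @ value


# Example shapes: batch_size=2, num_heads=4, target_len=5, source_len=7
q = torch.randn(2, 4, 5, 8)
k = torch.randn(2, 4, 7, 8)
v = torch.randn(2, 4, 7, 16)
out = sdpa_reference(q, k, v)
```

Without a mask, the output keeps the leading dimensions and target length of `query` and takes its last dimension from `value`. On small CPU inputs the result should agree with `F.scaled_dot_product_attention(q, k, v)` up to floating-point tolerance.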
