Paddle 版本的多头注意力(MultiHeadAttention)没有 key_padding_mask 参数,
而该参数在 Deformable DETR 中很有用。虽然似乎可以用 attn_mask 实现同样的效果,
但为了与 Torch 对齐,目前的做法是把 Torch 版本的实现复制过来并改写为 Paddle 版。输出的大部分数值与 Torch 对齐,只有一小部分不一致;文章结尾介绍了参数对齐的验证方法。
需要注意:该版本中的 attn_mask 参数并未对齐,相关代码甚至仍在使用 torch 的 API(如 torch.uint8、torch.bool),
原因是 Paddle 自带的 MultiHeadAttention 已经提供了 attn_mask 参数。
from typing import Callable, List, Optional, Tuple
import math
import paddle
import paddle.nn as nn
from paddle.nn import Linear
from paddle.nn.functional import linear, dropout, softmax
import warnings
import numpy as np
from paddle.jit import to_static
from initializer import linear_init_, constant_, xavier_uniform_, normal_, xavier_normal_
# @to_static
def masked_fill(x, mask, value):
    """Out-of-place Paddle counterpart of ``torch.Tensor.masked_fill``.

    Args:
        x: tensor supplying the values kept where ``mask`` is false.
        mask: tensor cast to ``bool``; True / non-zero positions are replaced.
        value: scalar written at the masked positions (in ``x``'s dtype).

    Returns:
        A new tensor built by ``paddle.where``; ``x`` itself is not modified.
        The result shape follows ``paddle.where``'s broadcasting of ``mask``
        against ``x``.
    """
    fill = paddle.full(x.shape, value, x.dtype)
    condition = mask.astype(bool)
    return paddle.where(condition, fill, x)
# def _pad(input: paddle.Tensor, pad: List[int], mode: str = "constant", value: float = 0.0) -> paddle.Tensor:
# r"""Pads tensor.
# Padding size:
# The padding size by which to pad some dimensions of :attr:`input`
# are described starting from the last dimension and moving forward.
# :math:`\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor` dimensions
# of ``input`` will be padded.
# For example, to pad only the last dimension of the input tensor, then
# :attr:`pad` has the form
# :math:`(\text{padding\_left}, \text{padding\_right})`;
# to pad the last 2 dimensions of the input tensor, then use
# :math:`(\text{padding\_left}, \text{padding\_right},`
# :math:`\text{padding\_top}, \text{padding\_bottom})`;
# to pad the last 3 dimensions, use
# :math:`(\text{padding\_left}, \text{padding\_right},`
# :math:`\text{padding\_top}, \text{padding\_bottom}`
# :math:`\text{padding\_front}, \text{padding\_back})`.
# Padding mode:
# See :class:`torch.nn.ConstantPad2d`, :class:`torch.nn.ReflectionPad2d`, and
# :class:`torch.nn.ReplicationPad2d` for concrete examples on how each of the
# padding modes works. Constant padding is implemented for arbitrary dimensions.
# Replicate and reflection padding is implemented for padding the last 3
# dimensions of 5D input tensor, or the last 2 dimensions of 4D input
# tensor, or the last dimension of 3D input tensor.
# Note:
# When using the CUDA backend, this operation may induce nondeterministic
# behaviour in its backward pass that is not easily switched off.
# Please see the notes on :doc:`/notes/randomness` for background.
# Args:
# input (paddle.Tensor): N-dimensional tensor
# pad (tuple): m-elements tuple, where
# :math:`\frac{m}{2} \leq` input dimensions and :math:`m` is even.
# mode: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
# Default: ``'constant'``
# value: fill value for ``'constant'`` padding. Default: ``0``
# Examples::
# >>> t4d = torch.empty(3, 3, 4, 2)
# >>> p1d = (1, 1) # pad last dim by 1 on each side
# >>> out = F.pad(t4d, p1d, "constant", 0) # effectively zero padding
# >>> print(out.shape)
# torch.Size([3, 3, 4, 4])
# >>> p2d = (1, 1, 2, 2) # pad last dim by (1, 1) and 2nd to last by (2, 2)
# >>> out = F.pad(t4d, p2d, "constant", 0)
# >>> print(out.shape)
# torch.Size([3, 3, 8, 4])
# >>> t4d = torch.empty(3, 3, 4, 2)
# >>> p3d = (0, 1, 2, 1, 3, 3) # pad by (0, 1), (2, 1), and (3, 3)
# >>> out = F.pad(t4d, p3d, "constant", 0)
# >>> print(out.shape)
# torch.Size([3, 9, 7, 3])
# """
# if has_torch_function_unary(input):
# return handle_torch_function(_pad, (input,), input, pad, mode=mode, value=value)
# assert len(pad) % 2 == 0, "Padding length must be divisible by 2"
# assert len(pad) // 2 <= input.dim(), "Padding length too large"
# if mode == "constant":
# return _VF.constant_pad_nd(input, pad, value)
# else:
# assert value == 0.0, 'Padding mode "{}"" doesn\'t take in value argument'.format(mode)
# if len(pad) == 2 and (input.dim() == 2 or input.dim() == 3):
# if mode == "reflect":
# return torch._C._nn.reflection_pad1d(input, pad)
# elif mode == "replicate":
# return torch._C._nn.replication_pad1d(input, pad)
# elif mode == "circular":
# return _pad_circular(input, pad)
# else:
# raise NotImplementedError
# elif len(pad) == 4 and (input.dim() == 3 or input.dim() == 4):
# if mode == "reflect":
# return torch._C._nn.reflection_pad2d(input, pad)
# elif mode == "replicate":
# return torch._C._nn.replication_pad2d(input, pad)
# elif mode == "circular":
# return _pad_circular(input, pad)
# else:
# raise NotImplementedError
# elif len(pad) == 6 and (input.dim() == 4 or input.dim() == 5):
# if mode == "reflect":
# return torch._C._nn.reflection_pad3d(input, pad)
# elif mode == "replicate":
# return torch._C._nn.replication_pad3d(input, pad)
# elif mode == "circular":
# return _pad_circular(input, pad)
# else:
# raise NotImplementedError
# else:
# raise NotImplementedError("Only 2D, 3D, 4D, 5D padding with non-constant padding are supported for now")
# # We define this function as _pad because it takes an argument
# # named pad, which clobbers the recursive reference to the pad
# # function needed for __torch_function__ support
# pad = _pad
def _in_projection(
    q: paddle.Tensor,
    k: paddle.Tensor,
    v: paddle.Tensor,
    w_q: paddle.Tensor,
    w_k: paddle.Tensor,
    w_v: paddle.Tensor,
    b_q: Optional[paddle.Tensor] = None,
    b_k: Optional[paddle.Tensor] = None,
    b_v: Optional[paddle.Tensor] = None,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    r"""
    Performs the in-projection step of the attention operation. This is simply
    a triple of linear projections, with shape constraints on the weights which
    ensure embedding dimension uniformity in the projected outputs.
    Output is a triple containing projection tensors for query, key and value.

    Args:
        q, k, v: query, key and value tensors to be projected.
        w_q, w_k, w_v: weights for q, k and v, respectively, in the torch layout
            ``[out_features, in_features]``.
        b_q, b_k, b_v: optional biases for q, k and v, respectively.

    Shape:
        Inputs:
        - q: :math:`(Qdims..., Eq)` where Eq is the query embedding dimension and Qdims are any
          number of leading dimensions.
        - k: :math:`(Kdims..., Ek)` where Ek is the key embedding dimension and Kdims are any
          number of leading dimensions.
        - v: :math:`(Vdims..., Ev)` where Ev is the value embedding dimension and Vdims are any
          number of leading dimensions.
        - w_q: :math:`(Eq, Eq)`
        - w_k: :math:`(Eq, Ek)`
        - w_v: :math:`(Eq, Ev)`
        - b_q: :math:`(Eq)`
        - b_k: :math:`(Eq)`
        - b_v: :math:`(Eq)`

        Output: in output triple :math:`(q', k', v')`,
        - q': :math:`[Qdims..., Eq]`
        - k': :math:`[Kdims..., Eq]`
        - v': :math:`[Vdims..., Eq]`

    Raises:
        AssertionError: if a weight or bias does not have the expected shape.
    """
    Eq, Ek, Ev = q.shape[-1], k.shape[-1], v.shape[-1]
    # Paddle reports ``Tensor.shape`` as a Python list. Compare as tuples;
    # otherwise ``[Eq, Eq] == (Eq, Eq)`` is always False and every call fails.
    assert tuple(w_q.shape) == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
    assert tuple(w_k.shape) == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
    assert tuple(w_v.shape) == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
    assert b_q is None or tuple(b_q.shape) == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
    assert b_k is None or tuple(b_k.shape) == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
    assert b_v is None or tuple(b_v.shape) == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
    # Paddle's ``linear`` computes ``x @ weight`` with weight [in, out], so the
    # torch-layout [out, in] weights are transposed first.
    return linear(q, w_q.T, b_q), linear(k, w_k.T, b_k), linear(v, w_v.T, b_v)
def _scaled_dot_product_attention(
    q: paddle.Tensor,
    k: paddle.Tensor,
    v: paddle.Tensor,
    attn_mask: Optional[paddle.Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    r"""
    Scaled dot-product attention over batched query/key/value tensors.

    Queries are divided by ``sqrt(E)``, multiplied with the transposed keys, and
    the optional additive ``attn_mask`` is added. A softmax is taken over the
    source axis, optional dropout is applied, and the result is used to weight
    the values.

    Args:
        q, k, v: query, key and value tensors (see Shape).
        attn_mask: optional additive mask. It is added in place to the raw
            scores, so it must broadcast to :math:`(B, Nt, Ns)`.
        dropout_p: dropout probability on the attention weights. Dropout is only
            applied when it is greater than 0.0, so callers pass 0.0 at eval time.

    Shape:
        - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
          and E is embedding dimension.
        - key: :math:`(B, Ns, E)` where Ns is the source sequence length.
        - value: :math:`(B, Ns, E)`.
        - attn_mask: broadcastable to :math:`(B, Nt, Ns)`.
        - Output: attention values :math:`(B, Nt, E)` and attention weights
          :math:`(B, Nt, Ns)`.

    Returns:
        ``(output, attention_weights)``.
    """
    _, _, embed = q.shape
    scaled_q = q / math.sqrt(embed)
    # (B, Nt, E) @ (B, E, Ns) -> (B, Nt, Ns)
    scores = paddle.bmm(scaled_q, k.transpose([0, 2, 1]))
    if attn_mask is not None:
        scores += attn_mask
    probs = softmax(scores, axis=-1)
    if dropout_p > 0.0:
        probs = dropout(probs, p=dropout_p)
    # (B, Nt, Ns) @ (B, Ns, E) -> (B, Nt, E)
    return paddle.bmm(probs, v), probs
def _in_projection_packed(
    q: paddle.Tensor,
    k: paddle.Tensor,
    v: paddle.Tensor,
    w: paddle.Tensor,
    b: Optional[paddle.Tensor] = None,
) -> List[paddle.Tensor]:
    r"""
    Performs the in-projection step of the attention operation, using packed weights.
    Output is a triple containing projection tensors for query, key and value.

    Args:
        q, k, v: query, key and value tensors to be projected. For self-attention,
            these are typically the same tensor; for encoder-decoder attention,
            k and v are typically the same tensor. (Fused projections are used
            when these identities hold.) Regardless, q, k and v must share a
            common embedding dimension; otherwise their shapes may vary.
        w: projection weights for q, k and v, packed into a single tensor in the
            torch layout. Weights are packed along dimension 0, in q, k, v order.
        b: optional projection biases for q, k and v, packed into a single tensor
            in q, k, v order.

    Shape:
        Inputs:
        - q: :math:`(..., E)` where E is the embedding dimension
        - k: :math:`(..., E)` where E is the embedding dimension
        - v: :math:`(..., E)` where E is the embedding dimension
        - w: :math:`(E * 3, E)` where E is the embedding dimension
        - b: :math:`E * 3` where E is the embedding dimension

        Output:
        - the triple :math:`(q', k', v')`; each output tensor has the same
          shape as the corresponding input tensor.
    """
    E = q.shape[-1]
    if k is v:
        if q is k:
            # Self-attention: one fused projection, then split into q, k, v.
            # Paddle's ``chunk`` returns a list, so unpack it explicitly.
            q_out, k_out, v_out = linear(q, w.T, b).chunk(3, axis=-1)
            return q_out, k_out, v_out
        # Encoder-decoder attention: q uses its own slice of the weights;
        # k and v share one fused projection of the same input.
        w_q, w_kv = w.split([E, E * 2])
        if b is None:
            b_q = b_kv = None
        else:
            b_q, b_kv = b.split([E, E * 2])
        k_out, v_out = linear(k, w_kv.T, b_kv).chunk(2, axis=-1)
        return linear(q, w_q.T, b_q), k_out, v_out
    # General case: three independent projections.
    w_q, w_k, w_v = w.chunk(3)
    if b is None:
        b_q = b_k = b_v = None
    else:
        b_q, b_k, b_v = b.chunk(3)
    return linear(q, w_q.T, b_q), linear(k, w_k.T, b_k), linear(v, w_v.T, b_v)
# Torch wraps its output projection in ``NonDynamicallyQuantizableLinear`` only to
# avoid an obscure error when scripting an improperly quantized attention layer
# (https://github.com/pytorch/pytorch/issues/58969). The port keeps the name, so
# the code stays close to the torch source, but binds it to Paddle's plain
# Linear layer.
NonDynamicallyQuantizableLinear = nn.Linear
def _append_zero_column(x: paddle.Tensor) -> paddle.Tensor:
    """Append one zero along the last axis of ``x`` (stand-in for torch ``F.pad(x, (0, 1))``).

    Used to widen masks when an extra key/value position is appended (``bias_k``/
    ``bias_v`` or ``add_zero_attn``). A zero means "not masked" both for bool
    masks (False) and for additive float masks (nothing added).
    """
    zeros = paddle.zeros(list(x.shape[:-1]) + [1], dtype=x.dtype)
    return paddle.concat([x, zeros], axis=-1)


def _check_attn_mask(attn_mask: paddle.Tensor, bsz: int, num_heads: int,
                     tgt_len: int, src_len: int) -> paddle.Tensor:
    """Validate ``attn_mask`` and return it as a 3D tensor.

    uint8 masks are converted to bool with a deprecation warning. A 2D
    ``(tgt_len, src_len)`` mask gets a leading axis of size 1. A 3D mask must be
    ``(bsz * num_heads, tgt_len, src_len)``.

    Raises:
        AssertionError: for dtypes other than float, uint8 and bool.
        RuntimeError: for a wrong shape or rank.
    """
    if attn_mask.dtype == paddle.uint8:
        warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
        attn_mask = attn_mask.astype(paddle.bool)
    else:
        assert paddle.is_floating_point(attn_mask) or attn_mask.dtype == paddle.bool, \
            f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}"
    ndim = len(attn_mask.shape)
    # Paddle shapes are lists; compare as tuples so matching shapes are accepted.
    if ndim == 2:
        correct_2d_size = (tgt_len, src_len)
        if tuple(attn_mask.shape) != correct_2d_size:
            raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
        attn_mask = attn_mask.unsqueeze(0)
    elif ndim == 3:
        correct_3d_size = (bsz * num_heads, tgt_len, src_len)
        if tuple(attn_mask.shape) != correct_3d_size:
            raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
    else:
        raise RuntimeError(f"attn_mask's dimension {ndim} is not supported")
    return attn_mask


def _to_additive_mask(attn_mask: paddle.Tensor, dtype) -> paddle.Tensor:
    """Turn a binary mask into an additive one (0 = attend, -inf = masked).

    Bool masks are always converted, whatever values they contain. For other
    dtypes the port's heuristic is kept: a mask whose values all lie in [0, 1]
    and that has exactly two distinct values is treated as binary (non-zero =
    masked). This lets a float 0/1 ``key_padding_mask`` behave like a bool one.
    Any other mask is returned unchanged and added to the logits as-is.

    NOTE(review): the heuristic is ambiguous for a genuine additive float mask
    that happens to hold exactly two values in [0, 1], and it forces a
    device-to-host sync via ``.item()``; bool masks avoid both issues.
    """
    if attn_mask.dtype != paddle.bool:
        looks_binary = (attn_mask.max().item() < 1 + 10e-5
                        and attn_mask.min().item() > 0 - 10e-5
                        and len(paddle.unique(attn_mask.flatten())) == 2)
        if not looks_binary:
            return attn_mask
    additive = paddle.zeros_like(attn_mask, dtype=dtype)
    return masked_fill(additive, attn_mask, float("-inf"))


def multi_head_attention_forward(
    query: paddle.Tensor,
    key: paddle.Tensor,
    value: paddle.Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: paddle.Tensor,
    in_proj_bias: Optional[paddle.Tensor],
    bias_k: Optional[paddle.Tensor],
    bias_v: Optional[paddle.Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: paddle.Tensor,
    out_proj_bias: Optional[paddle.Tensor],
    training: bool = True,
    key_padding_mask: Optional[paddle.Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[paddle.Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[paddle.Tensor] = None,
    k_proj_weight: Optional[paddle.Tensor] = None,
    v_proj_weight: Optional[paddle.Tensor] = None,
    static_k: Optional[paddle.Tensor] = None,
    static_v: Optional[paddle.Tensor] = None,
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]:
    r"""
    Paddle port of ``torch.nn.functional.multi_head_attention_forward``.

    All projection weights (``in_proj_weight``, ``q/k/v_proj_weight``,
    ``out_proj_weight``) are interpreted in the torch layout
    ``[out_features, in_features]``. They are transposed before Paddle's ``linear``.

    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: packed input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at axis=0.
        add_zero_attn: add a new batch of zeros to the key and
            value sequences at axis=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if is ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. Masked positions get -inf added to their
            attention logits.
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask
            will be broadcast over all the batches, while a 3D mask allows a different
            mask for the entries of each batch.
        use_separate_proj_weight: if True, q_proj_weight, k_proj_weight and
            v_proj_weight are used instead of the packed in_proj_weight (this allows
            key/value embedding sizes different from the query's).
        q_proj_weight, k_proj_weight, v_proj_weight: separate input projection weights.
        static_k, static_v: static key and value used for attention operators.

    Shape:
        Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          Bool: ``True`` positions are ignored. uint8: converted to bool (deprecated).
          Other dtypes: if all values lie in [0, 1] and exactly two distinct values
          occur, non-zero positions are ignored; otherwise the mask is added to the
          logits as-is.
        - attn_mask: 2D mask :math:`(L, S)` or 3D mask :math:`(N*num_heads, L, S)`.
          Bool/uint8: non-zero (``True``) positions are not allowed to attend. Float:
          added to the attention logits, subject to the same two-value [0, 1]
          heuristic as key_padding_mask.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.

        Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)`, averaged over heads (only when
          ``need_weights`` is True; otherwise ``None``).
    """
    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape
    assert embed_dim == embed_dim_to_check, \
        f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    # ``//`` works for ints and Paddle tensors alike (torch's
    # ``div(..., rounding_mode='trunc')`` does not exist in Paddle).
    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert key.shape[:2] == value.shape[:2], \
            f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
    else:
        assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"

    # compute in-projection
    if not use_separate_proj_weight:
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
    else:
        assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
        assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
        assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = in_proj_bias.chunk(3)
        q = linear(query, q_proj_weight.T, b_q)
        k = linear(key, k_proj_weight.T, b_k)
        v = linear(value, v_proj_weight.T, b_v)

    # prep attention mask
    if attn_mask is not None:
        attn_mask = _check_attn_mask(attn_mask, bsz, num_heads, tgt_len, src_len)

    # prep key padding mask
    if key_padding_mask is not None and key_padding_mask.dtype == paddle.uint8:
        warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
        key_padding_mask = key_padding_mask.astype(paddle.bool)

    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = paddle.concat([k, paddle.tile(bias_k, [1, bsz, 1])])
        v = paddle.concat([v, paddle.tile(bias_v, [1, bsz, 1])])
        if attn_mask is not None:
            attn_mask = _append_zero_column(attn_mask)
        if key_padding_mask is not None:
            key_padding_mask = _append_zero_column(key_padding_mask)
    else:
        assert bias_k is None
        assert bias_v is None

    # reshape q, k, v for multihead attention and make them batch first
    q = q.reshape([tgt_len, bsz * num_heads, head_dim]).transpose([1, 0, 2])
    if static_k is None:
        k = k.reshape([k.shape[0], bsz * num_heads, head_dim]).transpose([1, 0, 2])
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_k.shape[0] == bsz * num_heads, \
            f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.shape[0]}"
        assert static_k.shape[2] == head_dim, \
            f"expecting static_k.size(2) of {head_dim}, but got {static_k.shape[2]}"
        k = static_k
    if static_v is None:
        v = v.reshape([v.shape[0], bsz * num_heads, head_dim]).transpose([1, 0, 2])
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_v.shape[0] == bsz * num_heads, \
            f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.shape[0]}"
        assert static_v.shape[2] == head_dim, \
            f"expecting static_v.size(2) of {head_dim}, but got {static_v.shape[2]}"
        v = static_v

    # add zero attention along batch dimension (now first)
    if add_zero_attn:
        zero_attn_shape = [bsz * num_heads, 1, head_dim]
        k = paddle.concat([k, paddle.zeros(zero_attn_shape, dtype=k.dtype)], axis=1)
        v = paddle.concat([v, paddle.zeros(zero_attn_shape, dtype=v.dtype)], axis=1)
        if attn_mask is not None:
            attn_mask = _append_zero_column(attn_mask)
        if key_padding_mask is not None:
            key_padding_mask = _append_zero_column(key_padding_mask)

    # update source sequence length after adjustments
    src_len = k.shape[1]

    # merge key padding and attention masks
    if key_padding_mask is not None:
        assert list(key_padding_mask.shape) == [bsz, src_len], \
            f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        # Repeat the per-batch mask for every head: (N, S) -> (N * num_heads, 1, S).
        key_padding_mask = key_padding_mask.reshape([bsz, 1, 1, src_len]) \
            .expand([-1, num_heads, -1, -1]).reshape([bsz * num_heads, 1, src_len])
        if attn_mask is None:
            attn_mask = key_padding_mask
        elif attn_mask.dtype == paddle.bool:
            attn_mask = attn_mask.logical_or(key_padding_mask.astype(paddle.bool))
        else:
            # Paddle tensors have no ``masked_fill`` method; use the module helper.
            attn_mask = masked_fill(attn_mask, key_padding_mask, float("-inf"))

    # convert binary masks to additive float masks in the activations' dtype
    if attn_mask is not None:
        attn_mask = _to_additive_mask(attn_mask, q.dtype)

    # adjust dropout probability
    if not training:
        dropout_p = 0.0

    # calculate attention and out projection
    attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
    attn_output = attn_output.transpose([1, 0, 2]).reshape([tgt_len, bsz, embed_dim])
    attn_output = linear(attn_output, out_proj_weight.T, out_proj_bias)
    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.reshape([bsz, num_heads, tgt_len, src_len])
        return attn_output, attn_output_weights.sum(axis=1) / num_heads
    return attn_output, None
class MultiheadAttention(nn.Layer):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Paddle port of ``torch.nn.MultiheadAttention``. Unlike Paddle's built-in
    ``paddle.nn.MultiHeadAttention``, it accepts a ``key_padding_mask``.

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
            across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
        dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
        bias: If specified, adds bias to input / output projection layers. Default: ``True``.
        add_bias_kv: If specified, adds bias to the key and value sequences at axis=0. Default: ``False``.
        add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at axis=1.
            Default: ``False``.
        kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
        vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        device: Accepted only for signature compatibility with torch; not used.
        dtype: Parameter dtype. Default: ``None`` (``paddle.float32``).

    Projection weights are passed to ``multi_head_attention_forward``, which
    treats them in the torch layout ``[out_features, in_features]``.

    Examples::
        >>> multihead_attn = MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    __constants__ = ['batch_first']
    bias_k: Optional[paddle.Tensor]
    bias_v: Optional[paddle.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
                 kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
        super(MultiheadAttention, self).__init__()
        # Resolve dtype up front: every parameter below needs it, whatever the branch.
        if dtype is None:
            dtype = paddle.float32
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if not self._qkv_same_embed_dim:
            # Key/value widths differ from embed_dim: separate projection weights.
            self.q_proj_weight = paddle.create_parameter([embed_dim, embed_dim], dtype)
            self.k_proj_weight = paddle.create_parameter([embed_dim, self.kdim], dtype)
            self.v_proj_weight = paddle.create_parameter([embed_dim, self.vdim], dtype)
            self.in_proj_weight = None
        else:
            # One packed [3 * embed_dim, embed_dim] weight for q, k and v.
            self.in_proj_weight = paddle.create_parameter([3 * embed_dim, embed_dim], dtype)
            self.q_proj_weight = None
            self.k_proj_weight = None
            self.v_proj_weight = None

        if bias:
            self.in_proj_bias = paddle.create_parameter([3 * embed_dim], dtype)
        else:
            self.in_proj_bias = None
        self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias_attr=bias)

        if add_bias_kv:
            self.bias_k = paddle.create_parameter([1, 1, embed_dim], dtype)
            self.bias_v = paddle.create_parameter([1, 1, embed_dim], dtype)
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform projection weights, ``constant_`` biases, xavier-normal bias_k/bias_v."""
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)
        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias)
        # out_proj.bias is None when the layer was built with bias=False.
        if self.out_proj.bias is not None:
            constant_(self.out_proj.bias)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)
        # For a numerical comparison with torch, the seeded-random parameter
        # overwrite described at the end of this document can be pasted here
        # temporarily. Do not leave debug constants in place: they would erase
        # the initialization above.

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True
        super(MultiheadAttention, self).__setstate__(state)

    def forward(self,
                query: paddle.Tensor,
                key: paddle.Tensor,
                value: paddle.Tensor,
                key_padding_mask: Optional[paddle.Tensor] = None,
                need_weights: bool = True,
                attn_mask: Optional[paddle.Tensor] = None) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]:
        r"""
        Args:
            query: Query embeddings of shape :math:`(L, N, E_q)` when ``batch_first=False`` or :math:`(N, L, E_q)`
                when ``batch_first=True``, where :math:`L` is the target sequence length, :math:`N` is the batch size,
                and :math:`E_q` is the query embedding dimension ``embed_dim``.
            key: Key embeddings of shape :math:`(S, N, E_k)` when ``batch_first=False`` or :math:`(N, S, E_k)` when
                ``batch_first=True``, where :math:`S` is the source sequence length and :math:`E_k` is ``kdim``.
            value: Value embeddings of shape :math:`(S, N, E_v)` when ``batch_first=False`` or :math:`(N, S, E_v)`
                when ``batch_first=True``, where :math:`E_v` is ``vdim``.
            key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
                to ignore for the purpose of attention (i.e. treat as "padding"). For a bool mask, ``True`` marks
                ignored positions; for a byte mask, non-zero values do.
            need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
                Default: ``True``.
            attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
                :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`. A 2D mask is broadcast across the batch.
                Bool/byte masks mark positions that may not attend; float masks are added to the attention logits.

        Outputs:
            - **attn_output** - shape :math:`(L, N, E)` when ``batch_first=False`` or
              :math:`(N, L, E)` when ``batch_first=True``.
            - **attn_output_weights** - shape :math:`(N, L, S)`, averaged over heads. Only returned
              when ``need_weights=True`` (otherwise ``None``).
        """
        if self.batch_first:
            # Paddle's transpose takes a full permutation list, not two axes.
            query, key, value = [x.transpose([1, 0, 2]) for x in (query, key, value)]
        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            attn_output, attn_output_weights = multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
        if self.batch_first:
            return attn_output.transpose([1, 0, 2]), attn_output_weights
        return attn_output, attn_output_weights
if __name__ == "__main__":
    # Smoke test on DETR-like shapes: 80 queries attend over 19672 keys with batch
    # size 2 and a random float 0/1 key padding mask. The seed and the order of
    # the RNG calls (q, k, v, msk) match the Torch reference script, so both
    # frameworks receive identical inputs. numpy and paddle are imported at module top.
    np.random.seed(1107)
    q = paddle.to_tensor(np.random.randn(80, 2, 256).astype("float32"))
    k = paddle.to_tensor(np.random.randn(19672, 2, 256).astype("float32"))
    v = paddle.to_tensor(np.random.randn(19672, 2, 256).astype("float32"))
    msk = paddle.to_tensor(np.random.randint(0, 2, size=[2, 19672]).astype("float32"))
    model = MultiheadAttention(256, 8, 0.0)
    out, attn = model(q, k, v, key_padding_mask=msk)
    print("attn_output:", out.shape, "attn_output_weights:", attn.shape)
以下是用于对齐验证的 Torch 代码:
# Torch reference run for checking numerical alignment with the Paddle
# __main__ block: same seed, same RNG call order and same shapes, so q, k, v
# and msk receive identical values in both frameworks.
import torch
import numpy as np
np.random.seed(1107)
q = np.random.randn(*[80, 2, 256]).astype("float32")
k = np.random.randn(*[19672, 2, 256]).astype("float32")
v = np.random.randn(*[19672, 2, 256]).astype("float32")
# NOTE(review): msk is a float32 0/1 mask. The Paddle port converts such a mask
# into a 0/-inf padding mask; how torch treats a float key_padding_mask depends
# on the torch version. Confirm both sides interpret it the same way.
msk = np.random.randint(0, 2, size=[2, 19672]).astype("float32")
q = torch.tensor(q)
k = torch.tensor(k)
v = torch.tensor(v)
msk = torch.tensor(msk)
# Outputs are only comparable once the parameters are aligned as well
# (seeded overwrite inside _reset_parameters).
model = torch.nn.MultiheadAttention(256, 8, 0.0)
out, attn = model(q, k, v, key_padding_mask=msk)
在 Torch 的 MultiheadAttention 参数初始化函数 _reset_parameters
中添加以下几行:
# Paste at the end of torch.nn.MultiheadAttention._reset_parameters (`self` is
# the module). It overwrites every parameter with seeded uniform [0, 1) values.
# The np.random.rand calls must run in the same order as in the Paddle
# counterpart so both frameworks end up with identical parameter values.
import numpy as np
np.random.seed(1107)
in_proj_weight = torch.Tensor(np.random.rand(*list(self.in_proj_weight.shape)).astype("float32"))
self.in_proj_weight = torch.nn.Parameter(in_proj_weight)
in_proj_bias = torch.Tensor(np.random.rand(*list(self.in_proj_bias.shape)).astype("float32"))
self.in_proj_bias = torch.nn.Parameter(in_proj_bias)
out_proj_weight = torch.Tensor(np.random.rand(*list(self.out_proj.weight.shape)).astype("float32"))
self.out_proj.weight = torch.nn.Parameter(out_proj_weight)
out_proj_bias = torch.Tensor(np.random.rand(*list(self.out_proj.bias.shape)).astype("float32"))
self.out_proj.bias = torch.nn.Parameter(out_proj_bias)
以保证两边的模型参数一致;Paddle 版本则在其 _reset_parameters 中添加以下几行:
# Paste at the end of the Paddle MultiheadAttention._reset_parameters (`self`
# is the layer). It overwrites the parameters in place via set_value with seeded
# uniform [0, 1) values, in the same order as the Torch patch, so both
# frameworks get identical parameters. in_proj_weight and out_proj.weight are
# square or torch-shaped here, so the same shapes are drawn on both sides.
import numpy as np
np.random.seed(1107)
in_proj_weight = paddle.to_tensor(np.random.rand(*list(self.in_proj_weight.shape)).astype("float32"))
self.in_proj_weight.set_value(in_proj_weight)
in_proj_bias = paddle.to_tensor(np.random.rand(*list(self.in_proj_bias.shape)).astype("float32"))
self.in_proj_bias.set_value(in_proj_bias)
out_proj_weight = paddle.to_tensor(np.random.rand(*list(self.out_proj.weight.shape)).astype("float32"))
self.out_proj.weight.set_value(out_proj_weight)
out_proj_bias = paddle.to_tensor(np.random.rand(*list(self.out_proj.bias.shape)).astype("float32"))
self.out_proj.bias.set_value(out_proj_bias)