In multi-head attention, the inputs are the queries (Q), keys (K) and values (V). The dimensions of these tensors, and of the weight matrices, determine how the module is wired:
Input data (queries, keys, values): q_dim is the dimension of each query vector, key_dim the dimension of each key vector, and value_dim the dimension of each value vector.
Weight matrices: num_heads is the number of attention heads and head_dim is the per-head dimension; in the code below the per-head sizes are obtained by dividing key_dim and value_dim by num_head.
The code assumes the following imports and the module-level constant that the masking step relies on:
import haiku as hk
import jax
import jax.numpy as jnp
_SOFTMAX_MASK = -1e9  # any sufficiently large negative value works; masked logits softmax to ~0
def glorot_uniform():
return hk.initializers.VarianceScaling(scale=1.0,
mode='fan_avg',
distribution='uniform')
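Before moving on, it is worth pinning down what this initializer is. The sketch below was written for this walkthrough (the helper name _sample_weights and the shapes are made up): with scale=1.0, mode='fan_avg' and a uniform distribution, VarianceScaling samples from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)), i.e. the classic Glorot/Xavier uniform initializer.
import math
import haiku as hk
import jax
import jax.numpy as jnp

def _sample_weights():
    # hypothetical shape, used only for this demonstration
    return hk.get_parameter('w', shape=(128, 256), dtype=jnp.float32,
                            init=glorot_uniform())

params = hk.transform(_sample_weights).init(jax.random.PRNGKey(0))
w = jax.tree_util.tree_leaves(params)[0]
limit = math.sqrt(6.0 / (128 + 256))
print(float(jnp.abs(w).max()), '<=', limit)  # every sample lies within +/- limit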
def stable_softmax(logits: jax.Array) -> jax.Array:
"""Numerically stable softmax for (potential) bfloat 16."""
if logits.dtype == jnp.float32:
output = jax.nn.softmax(logits)
elif logits.dtype == jnp.bfloat16:
# Need to explicitly do softmax in float32 to avoid numerical issues
# with large negatives. Large negatives can occur if trying to mask
# by adding on large negative logits so that things softmax to zero.
output = jax.nn.softmax(logits.astype(jnp.float32)).astype(jnp.bfloat16)
else:
raise ValueError(f'Unexpected input dtype {logits.dtype}')
return output
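A quick illustration (not part of the original module) of why the bfloat16 branch matters: logits that carry a very large negative value (as produced by masking) are softmaxed in float32 and then cast back, so masked positions end up with numerically clean zero weight while the output keeps the bfloat16 dtype.
logits = jnp.array([2.0, 1.0, -1e9], dtype=jnp.bfloat16)  # last position is "masked"
weights = stable_softmax(logits)
print(weights.dtype)  # bfloat16 -- the input dtype is preserved
print(weights)        # roughly [0.73, 0.27, 0.0]; the masked position gets ~zero weight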
class Attention(hk.Module):
"""Multihead attention."""
def __init__(self, config, global_config, output_dim, name='attention'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
self.output_dim = output_dim
def __call__(self, q_data, m_data, mask, nonbatched_bias=None):
"""Builds Attention module.
Arguments:
q_data: A tensor of queries, shape [batch_size, N_queries, q_channels].
m_data: A tensor of memories from which the keys and values are
projected, shape [batch_size, N_keys, m_channels].
mask: A mask for the attention, shape [batch_size, N_queries, N_keys].
nonbatched_bias: Shared bias, shape [N_queries, N_keys].
Returns:
A float32 tensor of shape [batch_size, N_queries, output_dim].
"""
# Sensible default for when the config keys are missing
key_dim = self.config.get('key_dim', int(q_data.shape[-1]))
value_dim = self.config.get('value_dim', int(m_data.shape[-1]))
num_head = self.config.num_head
assert key_dim % num_head == 0
assert value_dim % num_head == 0
key_dim = key_dim // num_head
value_dim = value_dim // num_head
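# From here on key_dim and value_dim hold the *per-head* sizes, e.g. with key_dim=64
# and num_head=4 each head projects queries and keys down to 16 channels.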
# Weight shapes: (size of the input's last dimension, number of attention heads, per-head projection dimension)
q_weights = hk.get_parameter(
'query_w', shape=(q_data.shape[-1], num_head, key_dim),
dtype=q_data.dtype,
init=glorot_uniform())
k_weights = hk.get_parameter(
'key_w', shape=(m_data.shape[-1], num_head, key_dim),
dtype=q_data.dtype,
init=glorot_uniform())
v_weights = hk.get_parameter(
'value_w', shape=(m_data.shape[-1], num_head, value_dim),
dtype=q_data.dtype,
init=glorot_uniform())
# 'bqa' labels the axes of q_data: 'b' is the batch axis, 'q' the query-sequence axis,
# and 'a' the query feature dimension, i.e. (batch_size, seq_length, q_dim).
# 'ahc' labels the query weight matrix: 'a' the query feature dimension, 'h' the number
# of attention heads, 'c' the per-head query dimension.
# key_dim**(-0.5) is the standard attention scaling; it keeps the dot-product logits
# from growing with the per-head dimension.
# jnp.einsum uses Einstein summation notation, a compact way to express tensor
# products, sums and transpositions.
q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)
k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights)
v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights)
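# Resulting shapes: q is [batch, N_queries, num_head, key_dim],
# k is [batch, N_keys, num_head, key_dim], v is [batch, N_keys, num_head, value_dim].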
# Attention logits: the dot product between every query (q) and every key (k).
# The result has shape bhqk (batch_size, num_heads, num_q, num_k), where
# num_q / num_k are the numbers of queries / keys, typically seq_length.
logits = jnp.einsum('bqhc,bkhc->bhqk', q, k)
if nonbatched_bias is not None:
logits += jnp.expand_dims(nonbatched_bias, axis=0)
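# nonbatched_bias is shared across the batch; expand_dims adds a leading axis of
# size 1 so it broadcasts against the [batch, num_head, N_queries, N_keys] logits.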
# Apply the mask: positions where mask is False are replaced by a large negative logit.
logits = jnp.where(mask, logits, _SOFTMAX_MASK)
# Softmax over the key axis turns the logits into attention weights for each query position.
weights = stable_softmax(logits)
# Weighted sum of the values with the attention weights gives the per-head output.
# (The dot product measures similarity: the more similar a query and a key, the larger
# their logit and hence the larger the weight on the corresponding value.)
weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v)
if self.global_config.zero_init:
init = hk.initializers.Constant(0.0)
else:
init = glorot_uniform()
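# 'init' is used for the output projection below. With zero_init it starts at zero,
# so the whole attention block outputs zeros at initialization (convenient when the
# result is added back through a residual connection).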
# Gated attention with a bias term.
if self.config.gating:
gating_weights = hk.get_parameter(
'gating_w',
shape=(q_data.shape[-1], num_head, value_dim),
dtype=q_data.dtype,
init=hk.initializers.Constant(0.0))
gating_bias = hk.get_parameter(
'gating_b',
shape=(num_head, value_dim),
dtype=q_data.dtype,
init=hk.initializers.Constant(1.0))
gate_values = jnp.einsum('bqc, chv->bqhv', q_data,
gating_weights) + gating_bias
gate_values = jax.nn.sigmoid(gate_values)
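# At initialization gating_w is zero and gating_b is one, so every gate starts at
# sigmoid(1) ≈ 0.73 and lets most of the attention output through.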
# ⊙ element-wise multiplication of the gate with the per-head attention output
weighted_avg *= gate_values
o_weights = hk.get_parameter(
'output_w', shape=(num_head, value_dim, self.output_dim),
dtype=q_data.dtype,
init=init)
o_bias = hk.get_parameter(
'output_b', shape=(self.output_dim,),
dtype=q_data.dtype,
init=hk.initializers.Constant(0.0))
# Linear projection to the output dimension.
output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias
return output
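To make the shapes concrete, here is a minimal usage sketch. It is not from the original code base: the ml_collections.ConfigDict wrapper, the config values and the toy shapes are assumptions chosen only to exercise the module; any attribute-style mapping with num_head, key_dim, value_dim and gating entries (plus zero_init on the global config) would do. Because the logits carry a head axis, the mask is given a singleton head axis so it broadcasts.
import haiku as hk
import jax
import jax.numpy as jnp
import ml_collections  # assumed available; only used to build a small config

def forward(q_data, m_data, mask):
    config = ml_collections.ConfigDict(
        {'num_head': 4, 'key_dim': 64, 'value_dim': 64, 'gating': True})
    global_config = ml_collections.ConfigDict({'zero_init': True})
    attn = Attention(config, global_config, output_dim=32)
    return attn(q_data, m_data, mask)

rng = jax.random.PRNGKey(0)
q_data = jnp.ones((2, 10, 64))               # [batch, N_queries, q_channels]
m_data = jnp.ones((2, 12, 64))               # [batch, N_keys, m_channels]
mask = jnp.ones((2, 1, 10, 12), dtype=bool)  # broadcasts against the [b, h, q, k] logits

fn = hk.transform(forward)
params = fn.init(rng, q_data, m_data, mask)
out = fn.apply(params, None, q_data, m_data, mask)
print(out.shape)  # (2, 10, 32); all zeros here because zero_init zeroes output_w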