PyTorch Notes - Attention Is All You Need (3)

Tricky implementation details of the Transformer, six points: Word Embedding, Position Embedding, Encoder Self-Attention Mask, Intra-Attention Mask, Decoder Self-Attention Mask, Multi-Head Self-Attention.

Encoder Self-Attention Mask

Reference: Transformer Architecture: The Positional Encoding
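The cited article walks through the sinusoidal positional encoding. As a quick reference, a minimal sketch of building that table, assuming max_position_len and model_dim (an even number) are already defined; both names are placeholders here:

# Hypothetical sketch (not from the original post): sinusoidal position embedding table
pos_mat = torch.arange(max_position_len).reshape((-1, 1))                # (max_position_len, 1)
i_mat = torch.pow(10000, torch.arange(0, model_dim, 2) / model_dim)      # (model_dim / 2,)
pe_embedding_table = torch.zeros(max_position_len, model_dim)
pe_embedding_table[:, 0::2] = torch.sin(pos_mat / i_mat)                 # even dimensions: sin
pe_embedding_table[:, 1::2] = torch.cos(pos_mat / i_mat)                 # odd dimensions: cos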

(Figures 1-4 in the original post.)
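As a text substitute for the figures, a minimal sketch of the Encoder Self-Attention Mask, assuming src_len holds the valid length of each source sentence in the batch (same pattern as the intra-attention mask below):

# Hypothetical sketch: Encoder Self-Attention Mask (shown only as images in the original post)
# Q @ K^T shape [batch_size, src_seq_len, src_seq_len]
valid_encoder_pos = torch.stack([F.pad(torch.ones(L), (0, max(src_len) - L)) for L in src_len])
valid_encoder_pos = torch.unsqueeze(valid_encoder_pos, dim=2)
valid_encoder_pos_matrix = torch.bmm(valid_encoder_pos, valid_encoder_pos.transpose(1, 2))
invalid_encoder_pos_matrix = 1 - valid_encoder_pos_matrix
mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool)  # True = padding position to mask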

Intra-Attention Mask

# 2022.8.2
# Step 5: build the Intra-Attention Mask, used by the Decoder (cross-attention)
# Q @ K^T shape [batch_size, tgt_seq_len, src_seq_len]
valid_encoder_pos = torch.stack([F.pad(torch.ones(L), (0, max(src_len) - L)) for L in src_len])
valid_encoder_pos = torch.unsqueeze(valid_encoder_pos, dim=2)
print(f'[Info] valid_encoder_pos.shape: {valid_encoder_pos.shape}')
print(f'[Info] valid_encoder_pos: {valid_encoder_pos}')  # valid positions are 1, padded positions are 0, relative to the max length in the batch

valid_decoder_pos = torch.stack([F.pad(torch.ones(L), (0, max(tgt_len) - L)) for L in tgt_len])
valid_decoder_pos = torch.unsqueeze(valid_decoder_pos, dim=2)
print(f'[Info] valid_decoder_pos.shape: {valid_decoder_pos.shape}')
print(f'[Info] valid_decoder_pos: {valid_decoder_pos}')  # valid positions are 1, padded positions are 0, relative to the max length in the batch

# relevance between target and source positions: 1 where both are valid, 0 otherwise; bmm is batched matrix multiplication
# decoder @ encoder^T; the Decoder provides Q, the Encoder provides K and V
valid_cross_pos_matrix = torch.bmm(valid_decoder_pos, valid_encoder_pos.transpose(1, 2))
print(f'[Info] valid_cross_pos_matrix.shape: {valid_cross_pos_matrix.shape}')
print(f'[Info] valid_cross_pos_matrix: {valid_cross_pos_matrix}')

invalid_cross_pos_matrix = 1 - valid_cross_pos_matrix
mask_cross_attention = invalid_cross_pos_matrix.to(torch.bool)
print(f"mask_cross_attention: \n{mask_cross_attention}")

The Decoder provides Q; the Encoder provides K and V.

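A quick sanity check of the cross-attention mask, analogous to the decoder-mask test below (a hypothetical addition, reusing batch_size, src_len, and tgt_len from the earlier steps):

# Hypothetical check: apply the cross-attention mask to random decoder-to-encoder scores
cross_score = torch.randn(batch_size, max(tgt_len), max(src_len))
masked_cross_score = cross_score.masked_fill(mask_cross_attention, -1e9)
cross_prob = F.softmax(masked_cross_score, -1)  # padded source positions receive ~0 probability
print(f"cross_prob: \n{cross_prob}")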

Decoder Self-Attention Mask

# Step 6: build the Decoder Self-Attention Mask
# tri means triangular; tril / triu give the lower / upper triangular parts
# for streaming / autoregressive decoding, the Transformer always needs this causal mask
# F.pad order: (left, right, top, bottom)
valid_decoder_tri_matrix = [F.pad(torch.tril(torch.ones((L, L))), (0, max(tgt_len)-L, 0, max(tgt_len)-L)) for L in tgt_len]
valid_decoder_tri_matrix = torch.stack(valid_decoder_tri_matrix, dim=0)
print(f"[Info] valid_decoder_tri_matrix: \n{valid_decoder_tri_matrix}")
print(f"[Info] valid_decoder_tri_matrix.shape: \n{valid_decoder_tri_matrix.shape}")
invalid_decoder_tri_matrix = 1 - valid_decoder_tri_matrix
invalid_decoder_tri_matrix = invalid_decoder_tri_matrix.to(torch.bool)
print(f"[Info] invalid_decoder_tri_matrix: \n{invalid_decoder_tri_matrix}")

# test: apply the mask to random scores and normalize
score = torch.randn(batch_size, max(tgt_len), max(tgt_len))
masked_score = score.masked_fill(invalid_decoder_tri_matrix, -1e9)
prob = F.softmax(masked_score, -1)
print(f"tgt_len: {tgt_len}")
print(f"prob: \n{prob}")

Multi-Head Self-Attention

Scaled dot-product attention:

# Step 7: build scaled dot-product attention
import math

def scaled_dot_product_attention(Q, K, V, attn_mask):
    # shape of Q, K, V: (batch_size*num_head, seq_len, model_dim/num_head)
    score = torch.bmm(Q, K.transpose(-2, -1)) / math.sqrt(model_dim)
    masked_score = score.masked_fill(attn_mask, -1e9)
    prob = F.softmax(masked_score, -1)
    context = torch.bmm(prob, V)
    return context
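A hypothetical single-head call for sanity checking, reusing the decoder self-attention mask from Step 6 (so num_head = 1 and the per-head dimension equals model_dim):

# Hypothetical single-head usage (not from the original post)
Q = torch.randn(batch_size, max(tgt_len), model_dim)
K = torch.randn(batch_size, max(tgt_len), model_dim)
V = torch.randn(batch_size, max(tgt_len), model_dim)
context = scaled_dot_product_attention(Q, K, V, invalid_decoder_tri_matrix)
print(f"context.shape: {context.shape}")  # (batch_size, max(tgt_len), model_dim)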

Transformer source code: torch.nn.modules.transformer.py

forward inputs:

  • tgt_mask: Decoder Self-Attention Mask
  • memory_mask: Intra-Attention Mask

Core logic: F.multi_head_attention_forward
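A minimal sketch of how these masks map onto nn.Transformer's forward arguments. This is an assumption-laden toy example (d_model=8, nhead=2, default (seq_len, batch, d_model) layout), and padding is passed via boolean key_padding_mask tensors (True = ignore) rather than the 3D masks built above:

# Hypothetical sketch: feeding the masks into nn.Transformer (not from the original post)
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Transformer(d_model=8, nhead=2)                      # toy sizes
src = torch.randn(max(src_len), batch_size, 8)                  # (S, N, E)
tgt = torch.randn(max(tgt_len), batch_size, 8)                  # (T, N, E)

# causal mask for the decoder, shape (T, T), additive float mask
tgt_mask = model.generate_square_subsequent_mask(max(tgt_len))

# key padding masks, shape (N, S) / (N, T), True marks padded positions
src_key_padding_mask = torch.stack(
    [F.pad(torch.zeros(L), (0, max(src_len) - L), value=1) for L in src_len]).to(torch.bool)
tgt_key_padding_mask = torch.stack(
    [F.pad(torch.zeros(L), (0, max(tgt_len) - L), value=1) for L in tgt_len]).to(torch.bool)

out = model(src, tgt,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask)
print(out.shape)  # (T, N, E) = (max(tgt_len), batch_size, 8)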
