PyTorch Notes - Seq2Seq + Attention Source Code

Encoder: encodes the input sequence into context-aware representations. Its main pieces:

  • embedding_dim
  • hidden_size
  • embedding_table, which maps vocab ids to embedding vectors (see the sketch below)
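
As a minimal, standalone sketch (the sizes here are toy values, not taken from the model below), the embedding lookup works like this:

import torch
import torch.nn as nn

embedding_table = nn.Embedding(num_embeddings=100, embedding_dim=8)  # vocab size 100, 8-dim vectors
ids = torch.tensor([[3, 7, 42]])                                     # [bs=1, seq_len=3] vocab ids
vectors = embedding_table(ids)                                       # [1, 3, 8] embedding vectors
print(vectors.shape)  # torch.Size([1, 3, 8])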

Seq2SeqAttentionMechanism: the attention mechanism. It takes the decoder state at time step t and all of the encoder states; it is a pure computation with no learnable parameters.
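
For reference, writing the decoder state at step $t$ as $h^{dec}_t$ and the encoder states as $h^{enc}_1, \dots, h^{enc}_S$ (with $S$ the source length), the dot-product attention implemented below is:

$$
e_{t,i} = (h^{dec}_t)^\top h^{enc}_i, \qquad
\alpha_{t,i} = \frac{\exp(e_{t,i})}{\sum_{j=1}^{S}\exp(e_{t,j})}, \qquad
c_t = \sum_{i=1}^{S} \alpha_{t,i}\, h^{enc}_i
$$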

Decoder: the encoder uses a full-sequence nn.LSTM layer, while the decoder uses an nn.LSTMCell and runs one time step at a time.
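
The difference matters because decoding is autoregressive. A standalone sketch with toy sizes: nn.LSTM consumes a whole sequence in one call, while nn.LSTMCell advances a single time step and can therefore be driven token by token inside a decoding loop.

import torch
import torch.nn as nn

x = torch.randn(2, 5, 8)                        # [bs=2, seq_len=5, input_size=8]

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
outputs, (h_n, c_n) = lstm(x)                   # outputs: [2, 5, 16], h_n/c_n: [1, 2, 16]

cell = nn.LSTMCell(input_size=8, hidden_size=16)
h_t, c_t = cell(x[:, 0, :])                     # first step, zero initial state: [2, 16] each
h_t, c_t = cell(x[:, 1, :], (h_t, c_t))         # later steps reuse the previous state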

import torch
import torch.nn as nn
import torch.nn.functional as F

"""
以离散符号的分类任务为例,实现基于注意力机制的seq2seq模型
"""

class Seq2SeqEncoder(nn.Module):
    """
    实现基于LSTM的编码器,也可以是其他类型的,如CNN、Transformer-Encoder
    """
    def __init__(self, embedding_dim, hidden_size, source_vocab_size):
        super(Seq2SeqEncoder, self).__init__()
        self.lstm_layer = nn.LSTM(
            input_size=embedding_dim, 
            hidden_size=hidden_size, 
            batch_first=True
        )
        self.embedding_table = nn.Embedding(source_vocab_size, embedding_dim)
        
    def forward(self, input_ids):
        
        input_sequence = self.embedding_table(input_ids)  # [bs, source_length, embedding_dim]
        output_states, (final_h, final_c) = self.lstm_layer(input_sequence)  # output_states: [bs, source_length, hidden_size]
        
        return output_states, final_h
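
# A quick shape check for Seq2SeqEncoder (a sketch with toy sizes, not executed here):
#   encoder = Seq2SeqEncoder(embedding_dim=8, hidden_size=16, source_vocab_size=100)
#   ids = torch.randint(100, size=(2, 3))            # [bs=2, source_length=3]
#   output_states, final_h = encoder(ids)
#   output_states.shape -> torch.Size([2, 3, 16])    # [bs, source_length, hidden_size]
#   final_h.shape       -> torch.Size([1, 2, 16])    # [num_layers, bs, hidden_size]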
    
class Seq2SeqAttentionMechanism(nn.Module):
    """
    实现dot-product的Attention
    """
    def __init__(self):
        super(Seq2SeqAttentionMechanism, self).__init__()
        
    def forward(self, decoder_state_t, encoder_states):
        
        bs, source_length, hidden_size = encoder_states.shape
        
        decoder_state_t = decoder_state_t.unsqueeze(1)  # [bs, 1, hidden_size]
        decoder_state_t = torch.tile(decoder_state_t, dims=(1, source_length, 1))  # [bs, source_length, hidden_size]
        
        # dot-product attention scores
        score = torch.sum(decoder_state_t * encoder_states, dim=-1)  # [bs, source_length]
        
        attn_prob = F.softmax(score, dim=-1)  # attention weights over the source positions
        
        context = torch.sum(attn_prob.unsqueeze(-1) * encoder_states, dim=1)  # [bs, hidden_size]
        
        return attn_prob, context
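
# A quick shape check for Seq2SeqAttentionMechanism (a sketch with toy sizes, not executed here):
#   attention = Seq2SeqAttentionMechanism()
#   decoder_state_t = torch.randn(2, 16)             # [bs, hidden_size]
#   encoder_states  = torch.randn(2, 3, 16)          # [bs, source_length, hidden_size]
#   attn_prob, context = attention(decoder_state_t, encoder_states)
#   attn_prob.shape -> torch.Size([2, 3])            # one weight per source position
#   context.shape   -> torch.Size([2, 16])           # weighted sum of encoder states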
    

class Seq2SeqDecoder(nn.Module):
    def __init__(self, embedding_dim, hidden_size, num_classes, target_vocab_size, start_id, end_id):
        super(Seq2SeqDecoder, self).__init__()
        
        self.lstm_cell = torch.nn.LSTMCell(embedding_dim, hidden_size)
        
        # num_classes equals target_vocab_size (classification over the target vocabulary)
        self.proj_layer = nn.Linear(hidden_size*2, num_classes)  # projects [context vector; decoder hidden state]
        self.attention_mechanism = Seq2SeqAttentionMechanism()  # attention (no learnable parameters)
        
        self.num_classes = num_classes  # output size of the final classification layer
        self.embedding_table = torch.nn.Embedding(target_vocab_size, embedding_dim)
        
        # at inference time decoding starts from start_id and stops once end_id is produced
        self.start_id = start_id  # for training, the target sequence is shifted by one position (teacher forcing)
        self.end_id = end_id
        
    def forward(self, shifted_target_ids, encoder_states):
        """
        训练阶段调研,teacher-force mode
        """
        
        shifted_target = self.embedding_table(shifted_target_ids)  # [bs, target_length] -> [bs, target_length, embedding_dim]
        
        bs, target_length, embedding_dim = shifted_target.shape  # target sequence length
        bs, source_length, hidden_size = encoder_states.shape  # source sequence length
        
        logits = torch.zeros(bs, target_length, self.num_classes)  # per-step classification logits
        probs = torch.zeros(bs, target_length, source_length)  # per-step attention distributions
        
        # a context vector is recomputed at every decoding step
        for t in range(target_length):
            # ground-truth token embedding at step t (teacher forcing)
            decoder_input_t = shifted_target[:, t, :]  # [bs, embedding_dim]
            
            # run the LSTM cell for a single time step
            if t == 0:
                h_t, c_t = self.lstm_cell(decoder_input_t)
            else:
                h_t, c_t = self.lstm_cell(decoder_input_t, (h_t, c_t))
            
            # attend over all encoder states with the current decoder state
            attn_prob, context = self.attention_mechanism(h_t, encoder_states)
            
            # concatenate the context vector and the decoder hidden state
            decoder_output = torch.cat((context, h_t), -1) 
            
            logits[:, t, :] = self.proj_layer(decoder_output)
            probs[:, t, :] = attn_prob
        
        return probs, logits
    
    def inference(self, encoder_states): 
        """
        推理阶段调用
        """
        target_id = self.start_id  # 起始id
        h_t = None
        result = []
        
        while True:
            decoder_input_t = self.embedding_table(target_id)
            if h_t is None:
                h_t, c_t = self.lstm_cell(decoder_input_t)
            else:
                h_t, c_t = self.lstm_cell(decoder_input_t, (h_t, c_t))
                
            attn_prob, context = self.attention_mechanism(h_t, encoder_states)
            
            decoder_output = torch.cat((context, h_t), -1)
            logits = self.proj_layer(decoder_output)
            
            target_id = torch.argmax(logits, -1)  # the prediction at this step becomes the next decoder input
            result.append(target_id)
            
            if torch.any(target_id == self.end_id):  # stop once end_id is predicted
                print("stop decoding!")
                break
                
        predicted_ids = torch.stack(result, dim=0)  # [decoded_length, bs]
        
        return predicted_ids
    

class Model(nn.Module):
    def __init__(self, embedding_dim, hidden_size, num_classes,
                 source_vocab_size, target_vocab_size, start_id, end_id):
        super(Model, self).__init__()
        self.encoder = Seq2SeqEncoder(embedding_dim, hidden_size, source_vocab_size)
        self.decoder = Seq2SeqDecoder(embedding_dim, hidden_size, num_classes, target_vocab_size, start_id, end_id)
        
    def forward(self, input_sequence_ids, shifted_target_ids):
        """
        训练:input_sequence_ids输入句子的ids,shifted_target_ids输出句子的ids
        """
        encoder_states, final_h = self.encoder(input_sequence_ids)
        probs, logits = self.decoder(shifted_target_ids, encoder_states)
        return probs, logits
    
    def infer(self, input_sequence_ids):
        encoder_states, final_h = self.encoder(input_sequence_ids)
        return self.decoder.inference(encoder_states)
    
if __name__ == '__main__':
    """
    单步的模拟,如果要训练,需要引入dataloader、mini-batch training
    """
    source_length = 3
    target_length = 4
    embedding_dim = 8
    hidden_size = 16
    num_classes = 10
    bs = 2
    start_id = end_id = 0
    source_vocab_size = 100
    target_vocab_size = 100
    
    # source-sequence ids
    input_sequence_ids = torch.randint(source_vocab_size, size=(bs, source_length)).to(torch.int32)
    
    target_ids = torch.randint(target_vocab_size, size=(bs, target_length))
    target_ids = torch.cat((target_ids, end_id*torch.ones(bs, 1)), dim=1).to(torch.int32)  # append end_id as the last token
    
    # decoder input: start_id followed by the target without its final token (teacher forcing)
    shifted_target_ids = torch.cat((start_id*torch.ones(bs, 1), target_ids[:, :-1]), dim=1).to(torch.int32)
    
    model = Model(embedding_dim, hidden_size, num_classes, source_vocab_size, target_vocab_size, start_id, end_id)
    probs, logits = model(input_sequence_ids, shifted_target_ids)
    print(probs.shape)   # [bs, target_length, source_length]
    print(logits.shape)  # [bs, target_length, num_classes]
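
    # A minimal training-loss sketch (assumes num_classes == target_vocab_size so that the target
    # ids are valid class labels; in this toy demo they differ, so it is left as a comment):
    #   loss = F.cross_entropy(logits.reshape(-1, num_classes), target_ids.reshape(-1).long())
    #   loss.backward()  # followed by an optimizer step in a real training loop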
