seq2seq模型可以运用到机器翻译、序列建模、词性标注、缺失值预测、音乐建模、语音翻译等等任务
seq2seq主要分为encoder和decoder
encoder对语言上下文表征的建模
decoder是解码器,将编码和注意力机制融合,解码器在每一步预测的时候需要对编码器输出状态做一个加权的求和,有forward函数和inference函数,forward函数用于训练,inference用于预测推理阶段
注意力机制有很多,如局部或全局、单调或非单调的,点乘的或加法的,可以根据特定的论文、特定复杂度、特定的任务去修改
难点也主要在decoder的inference函数
encoder比较简单,直接跳过
class Seq2SeqEncoder(nn.Module):
""" 实现基于LSTM,当然也可以用其他类型如CNN或者TransformerEncoder"""
def __init__(self, emb_dim, hidden_size, source_vocab_size):
super(Seq2SeqEncoder, self).__init__()
self.lstm_layer = nn.LSTM(input_size=emb_dim,
hidden_size=hidden_size,
batch_first=True)
self.emb_table = nn.Embedding(source_vocab_size, emb_dim)
def forward(self, input_ids):
input_sequence = self.emb_table(input_ids)
output_sates, (final_h, final_c) = self.lstm_layer(input_sequence)
return output_sates, final_h
decoder在训练中会传入真实的seq_ids作为输入,但是这个seq_ids是带偏移的,最前面会补一个start_id,最后面会添加一个end_id,推理的时候以start_id作为开始,以end_id作为结束
class Seq2SeqDecoder(nn.Module):
""" 解码器部分
为什么用LSTM cell,因为decode每一步都要计算注意力机制
"""
def __init__(self, emb_dim, hidden_size, num_class, target_vocab_size, start_id, end_id):
super().__init__()
self.start_id = start_id
self.end_id = end_id
self.lstm_cell = nn.LSTMCell(emb_dim, hidden_size)
# ht和context拼接起来,再送入MLP做分类
self.proj_layer = nn.Linear(hidden_size * 2, num_class)
self.attention_mechanism = Seq2SeqAttentionMechanism()
self.num_class = num_class
self.emb_tabel = nn.Embedding(target_vocab_size, emb_dim)
def forward(self, shifted_target_ids, encoder_states):
"""
训练阶段,teacher-force mode,有一个完整的目标序列作为teacher指导训练
输出:
probs解码器每一步注意力机制权重
logits每一步单词的概率
"""
shift_target = self.emb_tabel(shifted_target_ids)
bs, target_len, emb_dim = shift_target.shape
bs, source_len, hidden_size = encoder_states.shape
logits = torch.zeros(bs, target_len, self.num_class)
probs = torch.zeros(bs, target_len, source_len)
h_t, c_t = torch.zeros(bs, hidden_size)
for t in range(target_len):
decoder_input_t = shift_target[:, t, :] # [bs, emb_dim]
if t == 0:
# 第0时刻是没有init_h和init_c
h_t, c_t = self.lstm_cell(decoder_input_t)
else:
h_t, c_t = self.lstm_cell(decoder_input_t, (h_t, c_t))
prob, context = self.attention_mechanism(h_t, encoder_states)
decoder_output = torch.cat([context, h_t], dim=-1)
logits[:, t, :] = self.proj_layer(decoder_output)
probs[:, t, :] = prob
return probs, logits
def inference(self, encoder_states):
"""
推理阶段使用
"""
target_id = self.start_id
h_t = c_t = None
result = []
while True:
# 获得input_emb
decoder_input_t = self.emb_tabel(target_id)
# 送入lstm cell
if h_t is None:
h_t, c_t = self.lstm_cell(decoder_input_t)
else:
h_t, c_t = self.lstm_cell(decoder_input_t, (h_t, c_t))
# 计算注意力
attn_prob, context = self.attention_mechanism(h_t, encoder_states)
decoder_output = torch.cat([context, h_t], dim=-1)
# 送入MLP预测
logits = self.proj_layer(decoder_output)
# 看分到哪一类
target_id = torch.argmax(logits, -1)
result.append(target_id)
# 中止条件
if torch.any(target_id == self.end_id):
print("stop decoding")
break
predicated_ids = torch.stack(result, dim=0)
return predicated_ids
完整代码见github
在做seq2seq 离散符号任务时,如文本到文本的变换,图像语义分割,第一步要做的是tokenization分词
还有一些预处理工作,比如表情,数字的处理
第二步,token2idx,将token变成一个个的index
构建字典
第三步,添加SOS和EOS,start of sentence和end of sequence的索引,如假定padding value = 0,SOS=1,EOS=2
第四步,编写完dataset之后,传入dataloader的时候,传入collate_fn对原始的batch进行后处理
进行padding操作 target = pad_sentence, nn.utils.rnn.pad_sequence
label=target[:, 1:] 去掉SOS
decoder_input = target[:, :-1] ,去掉EOS
编码器Encoder是非流式的,编码器模型有很多种选择
注意力机制,计算解码器与编码器之间关联性,
计算score有content-based,考虑encoder和decoder之间的state
location-based,考虑decoder state和previous location,把上一步的对齐信息纳入当前score的计算
hybrid,既考虑encoder state也将previous对齐信息纳入score计算,混合模式
计算score方式,点乘方式,加法形式,乘形式
计算context,softmax score得到权重与value加权求和
解码器部分,可以使用RNN CELL,或者transformer等等
训练使用的是teacher-force,解码器每一步都有一个输入,输入是当前要预测目标的上一时刻的真实值
推理是自回归推理,上一时刻的预测值送入下一时刻解码器作为输入,再去预测下一时刻的输出值
因为每句话长短不一样,需要添加masked sequence loss