A Seq2Seq Model Application Example
Seq2Seq is an Encoder-Decoder model: the input is a sequence and the output is also a sequence, so it suits scenarios where the input and output sequences have different lengths, such as machine translation, human-machine dialogue, and chatbots.
This example, adapted from a fellow developer's GitHub code, implements a simple Seq2Seq dialogue model. The training corpus is as follows (the helper functions that turn these sentences into index sequences are sketched after the corpus):
Encoder sentences:
[['Hi What is your name?', 'Nice to meet you!'],
 ['Which programming language do you use?', 'See you later.'],
 ['Where do you live?', 'What is your major?'],
 ['What do you want to drink?', 'What is your favorite beer?']]
Decoder sentences:
[['Hi this is Jaemin.', 'Nice to meet you too!'],
 ['I like Python.', 'Bye Bye.'],
 ['I live in Seoul, South Korea.', 'I study industrial engineering.'],
 ['Beer please!', 'Leffe brown!']]
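The notebook builds word-level vocabularies from this corpus and converts every sentence into a padded index sequence through helper functions sent2idx and idx2sent, which are used later by train and inference but are not reproduced in this article. The following is only a minimal sketch under stated assumptions (a naive regex tokenizer; '_PAD' at index 0 on the encoder side; '_GO' at 0 and '_PAD' at 1 on the decoder side), not the original implementation:

import re
from collections import Counter

def tokenize(sentence):
    # naive tokenizer (assumption): lower-case and split on non-word characters
    return [t for t in re.split(r'\W+', sentence.lower()) if t]

def build_vocab(sentences, is_target=False):
    # encoder vocab reserves 0 for '_PAD'; decoder vocab reserves 0 for '_GO' and 1 for '_PAD'
    vocab = {'_GO': 0, '_PAD': 1} if is_target else {'_PAD': 0}
    for token in Counter(t for sent in sentences for t in tokenize(sent)):
        vocab.setdefault(token, len(vocab))
    reverse_vocab = {idx: token for token, idx in vocab.items()}
    return vocab, reverse_vocab

def sent2idx(sentence, vocab=None, max_sentence_length=None, is_target=False):
    # defaults fall back to the encoder-side globals used elsewhere in the notebook (assumption)
    vocab = enc_vocab if vocab is None else vocab
    max_sentence_length = enc_sentence_length if max_sentence_length is None else max_sentence_length
    tokens = [vocab[t] for t in tokenize(sentence)]
    length = len(tokens)
    if is_target:
        # decoder input: '_GO' + tokens, padded with '_PAD' to max_sentence_length + 1
        padded = [vocab['_GO']] + tokens + [vocab['_PAD']] * (max_sentence_length - length)
    else:
        # encoder input: tokens padded with '_PAD' to max_sentence_length
        padded = tokens + [vocab['_PAD']] * (max_sentence_length - length)
    return padded, length

def idx2sent(indices, reverse_vocab):
    # map predicted indices back to words, dropping the special tokens
    words = [reverse_vocab[i] for i in indices if i in reverse_vocab]
    return ' '.join(w for w in words if w not in ('_PAD', '_GO'))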
Seq2Seq model code:
I. Create a configuration class that holds the hyperparameters; the RNN cell is BasicLSTMCell. A sketch of such a configuration class follows the values below.
hidden_size = 30
enc_emb_size = 30
dec_emb_size = 30
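The configuration class itself is only summarized above. A minimal sketch consistent with the attributes that Seq2SeqModel.__init__ reads (hidden_size, enc_emb_size, dec_emb_size, cell, optimizer, n_epoch, learning_rate, ckpt_dir) could look like this; every value other than the three sizes quoted above is an assumption:

import tensorflow as tf

class Config(object):
    # model hyperparameters (the three sizes are the values quoted above)
    hidden_size = 30
    enc_emb_size = 30
    dec_emb_size = 30
    cell = tf.contrib.rnn.BasicLSTMCell      # cell class; instantiated inside the model
    # training hyperparameters (assumed values)
    optimizer = tf.train.RMSPropOptimizer    # optimizer class; instantiated in add_training_op
    n_epoch = 2000
    learning_rate = 0.001
    # checkpoint / TensorBoard directory (assumed path)
    ckpt_dir = './ckpt_dir/'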
II. Create the Seq2SeqModel model class:
1. __init__: initialization.
2. add_placeholders: data placeholders.
self.enc_inputs: shape=[None, enc_sentence_length] (encoder batch size, maximum encoder sentence length).
self.enc_sequence_length: shape=[None,] (encoder batch size,).
In training mode:
self.dec_inputs: shape=[None, dec_sentence_length+1] (decoder batch size, maximum decoder sentence length + 1). The +1 is needed because, when decoder sentences are converted from words to indices by sent2idx, the start token '_GO' (index 0) is prepended to every sentence.
self.dec_sequence_length: shape=[None,] (decoder batch size,).
3. add_encoder: the encoder function.
self.enc_Wemb: [enc_vocab_size+1, self.enc_emb_size]; the +1 accounts for the padding token '_PAD' (index 0) used when encoder sentences are padded.
enc_emb_inputs: [batch_size, enc_sent_len, embedding_size]
enc_cell = self.cell(self.hidden_size)
enc_outputs: [batch_size, enc_sent_len, hidden_size]
enc_last_state: [batch_size, hidden_size] (the final encoder state, used to initialize the decoder)
4. add_decoder: the decoder function.
self.dec_Wemb: [dec_vocab_size+2, self.dec_emb_size]; the +2 accounts for the start token '_GO' (index 0) prepended to every decoder sentence and the padding token '_PAD' (index 1) used to pad them.
max_dec_len: tf.reduce_max(self.dec_sequence_length+1), i.e. the longest decoder sentence in the batch plus 1 (for the '_GO' start symbol of each sentence).
training_helper = tf.contrib.seq2seq.TrainingHelper(
    inputs=dec_emb_inputs,
    sequence_length=self.dec_sequence_length+1,
    time_major=False,
    name='training_helper')
The training_helper is used while training because the target sequences are already known: at each step the ground-truth word of the target sequence is fed in as the next decoder input, rather than the word predicted at the previous step, which improves accuracy during training.
5. Inference is slightly different: instead of TrainingHelper, tf.contrib.seq2seq.GreedyEmbeddingHelper is used. Each predicted word becomes the RNN's input at the next step, so the decoder keeps feeding its own predictions back into itself to produce the next word, one result after another.
6. Model training:
add_training_op: uses the tf.train.RMSPropOptimizer optimizer.
save: saves the model.
restore: loads the model.
summary: writes the graph for display in TensorBoard.
build: wires add_placeholders, add_encoder and add_decoder together.
train: trains the model; tqdm displays a progress bar. (A usage sketch follows below.)
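Putting the pieces together, a training run could be driven as follows. This is only a usage sketch assuming the Config class and the sent2idx/idx2sent helpers sketched earlier, with the corpus grouped into batches of two sentences:

import tensorflow as tf

# the four encoder/decoder sentence pairs shown earlier, as batches of two
input_batches = [
    ['Hi What is your name?', 'Nice to meet you!'],
    ['Which programming language do you use?', 'See you later.'],
    ['Where do you live?', 'What is your major?'],
    ['What do you want to drink?', 'What is your favorite beer?']]
target_batches = [
    ['Hi this is Jaemin.', 'Nice to meet you too!'],
    ['I like Python.', 'Bye Bye.'],
    ['I live in Seoul, South Korea.', 'I study industrial engineering.'],
    ['Beer please!', 'Leffe brown!']]

tf.reset_default_graph()
config = Config()
model = Seq2SeqModel(config, mode='training')
model.build()

with tf.Session() as sess:
    loss_history = model.train(sess, (input_batches, target_batches),
                               from_scratch=True,
                               save_path=config.ckpt_dir + 'model.ckpt')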
Problem: the Seq2Seq model looks very good when evaluated during training, but its predictions on the test set at inference time are disappointing. The cause is the difference between how the model is trained and how it predicts: during training the decoder is fed the ground-truth words of the target sentence (teacher forcing), while at inference it must feed its own predictions back in, so an early mistake can propagate through the rest of the sentence. The difference is illustrated in the diagram below.
Seq2Seq diagram:
Full code for the Seq2Seq example: https://github.com/duanzhihua/tf_tutorial_plus/blob/master/RNN_seq2seq/contrib_seq2seq/01_TrainingHelper.ipynb
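The class below is taken from that notebook. It relies on a handful of imports and on names defined earlier in the notebook (enc_sentence_length, dec_sentence_length, enc_vocab_size, dec_vocab_size, dec_vocab, dec_reverse_vocab, sent2idx, idx2sent). Assuming TensorFlow 1.x, the imports it needs are roughly:

import os

import tensorflow as tf
from tensorflow.python.layers.core import Dense  # output projection layer for the decoder
from tqdm import tqdm                             # progress bar used in train()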
class Seq2SeqModel(object):
    def __init__(self, config, mode='training'):
        assert mode in ['training', 'evaluation', 'inference']
        self.mode = mode
        # Model
        self.hidden_size = config.hidden_size
        self.enc_emb_size = config.enc_emb_size
        self.dec_emb_size = config.dec_emb_size
        self.cell = config.cell
        # Training
        self.optimizer = config.optimizer
        self.n_epoch = config.n_epoch
        self.learning_rate = config.learning_rate
        # Checkpoint Path
        self.ckpt_dir = config.ckpt_dir
    def add_placeholders(self):
        self.enc_inputs = tf.placeholder(
            tf.int32,
            shape=[None, enc_sentence_length],
            name='input_sentences')
        self.enc_sequence_length = tf.placeholder(
            tf.int32,
            shape=[None,],
            name='input_sequence_length')

        if self.mode == 'training':
            self.dec_inputs = tf.placeholder(
                tf.int32,
                shape=[None, dec_sentence_length+1],
                name='target_sentences')
            self.dec_sequence_length = tf.placeholder(
                tf.int32,
                shape=[None,],
                name='target_sequence_length')
    def add_encoder(self):
        with tf.variable_scope('Encoder') as scope:
            with tf.device('/cpu:0'):
                self.enc_Wemb = tf.get_variable('embedding',
                    initializer=tf.random_uniform([enc_vocab_size+1, self.enc_emb_size]),
                    dtype=tf.float32)

            # enc_emb_inputs: [batch_size x enc_sent_len x embedding_size]
            enc_emb_inputs = tf.nn.embedding_lookup(
                self.enc_Wemb, self.enc_inputs, name='emb_inputs')
            enc_cell = self.cell(self.hidden_size)

            # enc_outputs: [batch_size x enc_sent_len x hidden_size]
            # enc_last_state: [batch_size x hidden_size] (per LSTM state component)
            enc_outputs, self.enc_last_state = tf.nn.dynamic_rnn(
                cell=enc_cell,
                inputs=enc_emb_inputs,
                sequence_length=self.enc_sequence_length,
                time_major=False,
                dtype=tf.float32)
    def add_decoder(self):
        with tf.variable_scope('Decoder') as scope:
            with tf.device('/cpu:0'):
                self.dec_Wemb = tf.get_variable('embedding',
                    initializer=tf.random_uniform([dec_vocab_size+2, self.dec_emb_size]),
                    dtype=tf.float32)

            dec_cell = self.cell(self.hidden_size)
            # output projection (replacing `OutputProjectionWrapper`)
            output_layer = Dense(dec_vocab_size+2, name='output_projection')

            if self.mode == 'training':
                # maximum unrollings in current batch = max(dec_sent_len) + 1 (GO symbol)
                max_dec_len = tf.reduce_max(self.dec_sequence_length+1, name='max_dec_len')

                dec_emb_inputs = tf.nn.embedding_lookup(
                    self.dec_Wemb, self.dec_inputs, name='emb_inputs')

                training_helper = tf.contrib.seq2seq.TrainingHelper(
                    inputs=dec_emb_inputs,
                    sequence_length=self.dec_sequence_length+1,
                    time_major=False,
                    name='training_helper')

                training_decoder = tf.contrib.seq2seq.BasicDecoder(
                    cell=dec_cell,
                    helper=training_helper,
                    initial_state=self.enc_last_state,
                    output_layer=output_layer)

                train_dec_outputs, train_dec_last_state = tf.contrib.seq2seq.dynamic_decode(
                    training_decoder,
                    output_time_major=False,
                    impute_finished=True,
                    maximum_iterations=max_dec_len)

                # dec_outputs: collections.namedtuple(rnn_output, sample_id)
                # dec_outputs.rnn_output: [batch_size x max(dec_sequence_len) x dec_vocab_size+2], tf.float32
                # dec_outputs.sample_id: [batch_size x max(dec_sequence_len)], tf.int32

                # logits: [batch_size x max_dec_len x dec_vocab_size+2]
                logits = tf.identity(train_dec_outputs.rnn_output, name='logits')

                # targets: [batch_size x max_dec_len]
                targets = tf.slice(self.dec_inputs, [0, 0], [-1, max_dec_len], 'targets')

                # masks: [batch_size x max_dec_len]
                # => ignore outputs after `dec_sequence_length+1` when calculating loss
                masks = tf.sequence_mask(self.dec_sequence_length+1, max_dec_len, dtype=tf.float32, name='masks')

                # Control loss dimensions with `average_across_timesteps` and `average_across_batch`
                # internal: `tf.nn.sparse_softmax_cross_entropy_with_logits`
                self.batch_loss = tf.contrib.seq2seq.sequence_loss(
                    logits=logits,
                    targets=targets,
                    weights=masks,
                    name='batch_loss')

                # prediction sample for validation
                self.valid_predictions = tf.identity(train_dec_outputs.sample_id, name='valid_preds')

                # List of training variables
                # self.training_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

            elif self.mode == 'inference':
                batch_size = tf.shape(self.enc_inputs)[0:1]
                start_tokens = tf.zeros(batch_size, dtype=tf.int32)

                # feed each predicted token back in as the next decoder input ('_GO'=0 starts, '_PAD'=1 ends)
                inference_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                    embedding=self.dec_Wemb,
                    start_tokens=start_tokens,
                    end_token=1)

                inference_decoder = tf.contrib.seq2seq.BasicDecoder(
                    cell=dec_cell,
                    helper=inference_helper,
                    initial_state=self.enc_last_state,
                    output_layer=output_layer)

                infer_dec_outputs, infer_dec_last_state = tf.contrib.seq2seq.dynamic_decode(
                    inference_decoder,
                    output_time_major=False,
                    impute_finished=True,
                    maximum_iterations=dec_sentence_length)

                # [batch_size x dec_sentence_length], tf.int32
                self.predictions = tf.identity(infer_dec_outputs.sample_id, name='predictions')
                # equivalent to tf.argmax(infer_dec_outputs.rnn_output, axis=2, name='predictions')

                # List of training variables
                # self.training_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    def add_training_op(self):
        self.training_op = self.optimizer(self.learning_rate, name='training_op').minimize(self.batch_loss)

    def save(self, sess, var_list=None, save_path=None):
        print(f'Saving model at {save_path}')
        if hasattr(self, 'training_variables'):
            var_list = self.training_variables
        saver = tf.train.Saver(var_list)
        saver.save(sess, save_path, write_meta_graph=False)

    def restore(self, sess, var_list=None, ckpt_path=None):
        if hasattr(self, 'training_variables'):
            var_list = self.training_variables
        self.restorer = tf.train.Saver(var_list)
        self.restorer.restore(sess, ckpt_path)
        print('Restore Finished!')

    def summary(self):
        summary_writer = tf.summary.FileWriter(
            logdir=self.ckpt_dir,
            graph=tf.get_default_graph())
    def build(self):
        self.add_placeholders()
        self.add_encoder()
        self.add_decoder()

    def train(self, sess, data, from_scratch=False,
              load_ckpt=None, save_path=None):
        # Restore Checkpoint
        if not from_scratch and load_ckpt is not None and os.path.isfile(load_ckpt):
            self.restore(sess, ckpt_path=load_ckpt)

        # Add Optimizer to current graph
        self.add_training_op()

        sess.run(tf.global_variables_initializer())

        input_batches, target_batches = data
        loss_history = []

        for epoch in tqdm(range(self.n_epoch)):
            all_preds = []
            epoch_loss = 0
            for input_batch, target_batch in zip(input_batches, target_batches):
                input_batch_tokens = []
                target_batch_tokens = []
                enc_sentence_lengths = []
                dec_sentence_lengths = []

                for input_sent in input_batch:
                    tokens, sent_len = sent2idx(input_sent)
                    input_batch_tokens.append(tokens)
                    enc_sentence_lengths.append(sent_len)

                for target_sent in target_batch:
                    tokens, sent_len = sent2idx(target_sent,
                                                vocab=dec_vocab,
                                                max_sentence_length=dec_sentence_length,
                                                is_target=True)
                    target_batch_tokens.append(tokens)
                    dec_sentence_lengths.append(sent_len)

                # Evaluate 3 ops in the graph
                # => valid_predictions, loss, training_op (optimizer)
                batch_preds, batch_loss, _ = sess.run(
                    [self.valid_predictions, self.batch_loss, self.training_op],
                    feed_dict={
                        self.enc_inputs: input_batch_tokens,
                        self.enc_sequence_length: enc_sentence_lengths,
                        self.dec_inputs: target_batch_tokens,
                        self.dec_sequence_length: dec_sentence_lengths,
                    })
                # loss_history.append(batch_loss)
                epoch_loss += batch_loss
                all_preds.append(batch_preds)

            loss_history.append(epoch_loss)

            # Logging every 400 epochs
            if epoch % 400 == 0:
                print('Epoch', epoch)
                for input_batch, target_batch, batch_preds in zip(input_batches, target_batches, all_preds):
                    for input_sent, target_sent, pred in zip(input_batch, target_batch, batch_preds):
                        print(f'\tInput: {input_sent}')
                        print('\tPrediction:', idx2sent(pred, reverse_vocab=dec_reverse_vocab))
                        print(f'\tTarget: {target_sent}')
                print(f'\tepoch loss: {epoch_loss:.2f}\n')

        if save_path:
            self.save(sess, save_path=save_path)

        return loss_history
    def inference(self, sess, data, load_ckpt):
        self.restore(sess, ckpt_path=load_ckpt)

        input_batch, target_batch = data

        batch_preds = []
        batch_tokens = []
        batch_sent_lens = []

        for input_sent in input_batch:
            tokens, sent_len = sent2idx(input_sent)
            batch_tokens.append(tokens)
            batch_sent_lens.append(sent_len)

        batch_preds = sess.run(
            self.predictions,
            feed_dict={
                self.enc_inputs: batch_tokens,
                self.enc_sequence_length: batch_sent_lens,
            })

        for input_sent, target_sent, pred in zip(input_batch, target_batch, batch_preds):
            print('Input:', input_sent)
            print('Prediction:', idx2sent(pred, reverse_vocab=dec_reverse_vocab))
            print('Target:', target_sent, '\n')
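To run the model on new input, a fresh graph is built in 'inference' mode and the checkpoint written during training is restored. A usage sketch under the same assumptions as the training example above:

import tensorflow as tf

# sample inputs (targets are passed only so inference() can print a side-by-side comparison)
test_inputs = ['Hi What is your name?', 'What do you want to drink?']
test_targets = ['Hi this is Jaemin.', 'Beer please!']

tf.reset_default_graph()
config = Config()
model = Seq2SeqModel(config, mode='inference')
model.build()

with tf.Session() as sess:
    model.inference(sess, (test_inputs, test_targets),
                    load_ckpt=config.ckpt_dir + 'model.ckpt')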