Seq2Seq模型应用案例之ScheduledEmbeddingTrainingHelper:
TensorFlow 最新的 Seq2Seq 案例请参考官方教程：https://github.com/tensorflow/nmt ，这里不再赘述。
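在之前的博客（https://blog.csdn.net/duan_zhihua/article/details/87114665）中提到了模型训练与模型预测之间的差异：训练时解码器每一步输入的是真实目标词，而预测时只能输入模型自己上一步的输出。为此，TensorFlow 提供了计划采样（Scheduled Sampling）机制 ScheduledEmbeddingTrainingHelper。训练时，解码器在每一步以一定的抽样概率，用从模型上一步输出分布中采样得到的词代替真实目标词作为输入，并随着 epoch 的推进逐渐提高抽样概率。抽样概率为 0 时，ScheduledEmbeddingTrainingHelper 相当于 TrainingHelper。抽样概率为 1 时，它总是把模型自己的预测作为输入，近似于 GreedyEmbeddingHelper；严格地说，它是按 softmax 分布随机采样，而不是取 argmax。抽样概率在 0 到 1 之间时，按该概率决定每一步是否使用采样得到的词。实践中，使用 ScheduledEmbeddingTrainingHelper 进行计划采样通常比不使用计划采样效果更好。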
# 0.0 ≤ sampling_probability ≤ 1.0
# 0.0: no sampling => `ScheduledEmbeddingTrainingHelper` is equivalent to `TrainingHelper` 可能过拟合!
# 1.0: always sampling => `ScheduledEmbeddingTrainingHelper` always feeds back its own sampled predictions (close to `GreedyEmbeddingHelper`, but samples instead of taking argmax)
# Increasing sampling over steps => Curriculum Learning
Seq2SeqModel代码
https://github.com/duanzhihua/tf_tutorial_plus/blob/master/RNN_seq2seq/contrib_seq2seq/02_ScheduledEmbeddingTrainingHelper.ipynb
class Seq2SeqModel(object):
    """Seq2seq model whose decoder is trained with scheduled sampling (TF1 graph mode).

    In ``'training'`` mode the decoder uses
    ``tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper``. The sampling
    probability for epoch ``i`` is ``config.sampling_probability_list[i]``.
    In ``'inference'`` mode the decoder decodes greedily with
    ``GreedyEmbeddingHelper``. ``'evaluation'`` is accepted by the constructor,
    but ``add_decoder`` builds no decoding ops for it.

    NOTE(review): the graph code relies on module-level names defined elsewhere
    in the notebook: ``tf``, ``Dense``, ``os``, ``tqdm``, ``sent2idx``,
    ``idx2sent``, ``enc_sentence_length``, ``dec_sentence_length``,
    ``enc_vocab_size``, ``dec_vocab_size``, ``dec_vocab``, ``dec_reverse_vocab``.
    """

    def __init__(self, config, mode='training'):
        """Copy hyper-parameters from ``config``.

        Args:
            config: object exposing ``hidden_size``, ``enc_emb_size``,
                ``dec_emb_size``, ``cell`` (called as ``cell(hidden_size)`` to
                create an RNN cell), ``optimizer`` (called as
                ``optimizer(learning_rate, name=...)``), ``n_epoch``,
                ``learning_rate``, ``sampling_probability_list`` (indexed by
                epoch) and ``ckpt_dir``.
            mode: one of ``'training'``, ``'evaluation'``, ``'inference'``.
        """
        assert mode in ['training', 'evaluation', 'inference']
        self.mode = mode

        # Model
        self.hidden_size = config.hidden_size
        self.enc_emb_size = config.enc_emb_size
        self.dec_emb_size = config.dec_emb_size
        self.cell = config.cell

        # Training
        self.optimizer = config.optimizer
        self.n_epoch = config.n_epoch
        self.learning_rate = config.learning_rate

        # Scheduled-sampling probability, one entry per epoch
        self.sampling_probability_list = config.sampling_probability_list

        # Checkpoint / summary directory
        self.ckpt_dir = config.ckpt_dir

    def add_placeholders(self):
        """Create input placeholders (decoder/sampling ones only in training mode)."""
        self.enc_inputs = tf.placeholder(
            tf.int32,
            shape=[None, enc_sentence_length],
            name='input_sentences')

        self.enc_sequence_length = tf.placeholder(
            tf.int32,
            shape=[None,],
            name='input_sequence_length')

        if self.mode == 'training':
            # One extra column compared with dec_sentence_length (commented in
            # add_decoder as the GO symbol).
            self.dec_inputs = tf.placeholder(
                tf.int32,
                shape=[None, dec_sentence_length+1],
                name='target_sentences')

            self.dec_sequence_length = tf.placeholder(
                tf.int32,
                shape=[None,],
                name='target_sequence_length')

            # Scalar fed once per batch from sampling_probability_list.
            # 0.0 <= sampling_probability <= 1.0
            # 0.0: no sampling => behaves like `TrainingHelper`
            # 1.0: always sampling => always feeds back its own sampled predictions
            # Increasing sampling over epochs => curriculum learning
            self.sampling_probability = tf.placeholder(
                tf.float32,
                shape=[],
                name='sampling_probability')

    def add_encoder(self):
        """Embed the encoder inputs and run them through a dynamic RNN.

        Sets ``self.enc_Wemb`` and ``self.enc_last_state``.
        """
        with tf.variable_scope('Encoder'):
            with tf.device('/cpu:0'):
                self.enc_Wemb = tf.get_variable(
                    'embedding',
                    initializer=tf.random_uniform([enc_vocab_size+1, self.enc_emb_size]),
                    dtype=tf.float32)

            # [batch_size x enc_sent_len x enc_emb_size]
            enc_emb_inputs = tf.nn.embedding_lookup(
                self.enc_Wemb, self.enc_inputs, name='emb_inputs')
            enc_cell = self.cell(self.hidden_size)

            # enc_outputs: [batch_size x enc_sent_len x hidden_size] (unused)
            # enc_last_state: final cell state, used as the decoder's initial
            # state (e.g. [batch_size x hidden_size] for a GRU cell; LSTM-style
            # cells return a state tuple).
            enc_outputs, self.enc_last_state = tf.nn.dynamic_rnn(
                cell=enc_cell,
                inputs=enc_emb_inputs,
                sequence_length=self.enc_sequence_length,
                time_major=False,
                dtype=tf.float32)

    def add_decoder(self):
        """Build the decoder for the current mode.

        In ``'training'`` mode this builds a scheduled-sampling decoder and sets
        ``self.batch_loss`` (masked sequence loss) and ``self.valid_predictions``
        (argmax of the logits). In ``'inference'`` mode it builds a greedy
        decoder and sets ``self.predictions``. In any other mode only the
        decoder embedding variable is created.
        """
        with tf.variable_scope('Decoder'):
            with tf.device('/cpu:0'):
                self.dec_Wemb = tf.get_variable(
                    'embedding',
                    initializer=tf.random_uniform([dec_vocab_size+2, self.dec_emb_size]),
                    dtype=tf.float32)

            dec_cell = self.cell(self.hidden_size)

            # Output projection onto the decoder vocabulary
            # (replaces `OutputProjectionWrapper`).
            output_layer = Dense(dec_vocab_size+2, name='output_projection')

            if self.mode == 'training':
                # Maximum unrollings in the current batch:
                # max(dec_sequence_length) + 1 (GO symbol).
                max_dec_len = tf.reduce_max(self.dec_sequence_length+1, name='max_dec_len')

                dec_emb_inputs = tf.nn.embedding_lookup(
                    self.dec_Wemb, self.dec_inputs, name='emb_inputs')

                # At each step, with probability `sampling_probability`, the
                # helper feeds back a token sampled from the previous output
                # instead of the ground-truth input.
                training_helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
                    inputs=dec_emb_inputs,
                    sequence_length=self.dec_sequence_length+1,
                    embedding=self.dec_Wemb,
                    sampling_probability=self.sampling_probability,
                    time_major=False,
                    name='training_helper')

                training_decoder = tf.contrib.seq2seq.BasicDecoder(
                    cell=dec_cell,
                    helper=training_helper,
                    initial_state=self.enc_last_state,
                    output_layer=output_layer)

                train_dec_outputs, train_dec_last_state = tf.contrib.seq2seq.dynamic_decode(
                    training_decoder,
                    output_time_major=False,
                    impute_finished=True,
                    maximum_iterations=max_dec_len)

                # train_dec_outputs is a BasicDecoderOutput(rnn_output, sample_id):
                #   rnn_output: [batch_size x max_dec_len x dec_vocab_size+2], tf.float32
                #   sample_id:  [batch_size x max_dec_len], tf.int32
                #               (-1 wherever no sampling took place)
                logits = tf.identity(train_dec_outputs.rnn_output, name='logits')

                # targets: [batch_size x max_dec_len], tf.int32 token ids
                # NOTE(review): targets are sliced from the decoder *inputs*
                # with no one-step shift, so logits[:, t] (computed from input
                # token t) is scored against token t rather than t+1. Whether
                # this is intended depends on the layout that
                # sent2idx(..., is_target=True) produces, which is not visible
                # here. TODO confirm.
                targets = tf.slice(self.dec_inputs, [0, 0], [-1, max_dec_len], 'targets')

                # masks: [batch_size x max_dec_len]; steps at or beyond
                # dec_sequence_length+1 get zero weight in the loss.
                masks = tf.sequence_mask(
                    self.dec_sequence_length+1, max_dec_len,
                    dtype=tf.float32, name='masks')

                # sequence_loss uses sparse softmax cross-entropy internally.
                # The reduction is controlled by `average_across_timesteps` and
                # `average_across_batch` (left at their defaults here).
                self.batch_loss = tf.contrib.seq2seq.sequence_loss(
                    logits=logits,
                    targets=targets,
                    weights=masks,
                    name='batch_loss')

                # sample_id holds -1s where no sampling happened, so the
                # predictions used for logging come from the logits instead.
                self.valid_predictions = tf.argmax(logits, axis=2, name='valid_predictions')

            elif self.mode == 'inference':
                batch_size = tf.shape(self.enc_inputs)[0:1]
                # NOTE(review): assumes id 0 is the GO symbol and id 1 the
                # end-of-sequence symbol in the decoder vocabulary. TODO confirm.
                start_tokens = tf.zeros(batch_size, dtype=tf.int32)

                inference_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                    embedding=self.dec_Wemb,
                    start_tokens=start_tokens,
                    end_token=1)

                inference_decoder = tf.contrib.seq2seq.BasicDecoder(
                    cell=dec_cell,
                    helper=inference_helper,
                    initial_state=self.enc_last_state,
                    output_layer=output_layer)

                infer_dec_outputs, infer_dec_last_state = tf.contrib.seq2seq.dynamic_decode(
                    inference_decoder,
                    output_time_major=False,
                    impute_finished=True,
                    maximum_iterations=dec_sentence_length)

                # [batch_size x decoded_len], tf.int32, with
                # decoded_len <= dec_sentence_length. Equivalent to
                # tf.argmax(infer_dec_outputs.rnn_output, axis=2).
                self.predictions = tf.identity(infer_dec_outputs.sample_id, name='predictions')

    def add_training_op(self):
        """Create the op that minimizes ``self.batch_loss`` with the configured optimizer."""
        self.training_op = self.optimizer(self.learning_rate, name='training_op').minimize(self.batch_loss)

    def save(self, sess, var_list=None, save_path=None):
        """Save variables to ``save_path`` without writing the meta graph.

        If ``self.training_variables`` has been set, it takes precedence over
        ``var_list``. A ``var_list`` of None saves every saveable variable.
        """
        print(f'Saving model at {save_path}')
        if hasattr(self, 'training_variables'):
            var_list = self.training_variables
        saver = tf.train.Saver(var_list)
        saver.save(sess, save_path, write_meta_graph=False)

    def restore(self, sess, var_list=None, ckpt_path=None):
        """Restore variables from ``ckpt_path`` into ``sess``.

        If ``self.training_variables`` has been set, it takes precedence over
        ``var_list``. The Saver used is kept on ``self.restorer``.
        """
        if hasattr(self, 'training_variables'):
            var_list = self.training_variables
        self.restorer = tf.train.Saver(var_list)
        self.restorer.restore(sess, ckpt_path)
        print('Restore Finished!')

    def summary(self):
        """Write the default graph to ``self.ckpt_dir`` for TensorBoard."""
        summary_writer = tf.summary.FileWriter(
            logdir=self.ckpt_dir,
            graph=tf.get_default_graph())
        # Close (and thereby flush) the writer so the graph event is actually
        # written to disk and the file handle is released.
        summary_writer.close()

    def build(self):
        """Build placeholders, encoder and decoder (the training op is added by ``train``)."""
        self.add_placeholders()
        self.add_encoder()
        self.add_decoder()

    @staticmethod
    def _checkpoint_exists(ckpt_path):
        """Return True if ``ckpt_path`` names an existing checkpoint.

        Accepts either a plain file path or a TF checkpoint prefix, for which
        ``tf.train.Saver`` writes a ``<prefix>.index`` file. None or an empty
        path returns False.
        """
        if not ckpt_path:
            return False
        ckpt_path = os.fspath(ckpt_path)
        return os.path.isfile(ckpt_path) or os.path.isfile(ckpt_path + '.index')

    @staticmethod
    def _sents_to_idx(sents, **sent2idx_kwargs):
        """Apply ``sent2idx`` to every sentence; return (token lists, lengths)."""
        batch_tokens = []
        batch_sent_lens = []
        for sent in sents:
            tokens, sent_len = sent2idx(sent, **sent2idx_kwargs)
            batch_tokens.append(tokens)
            batch_sent_lens.append(sent_len)
        return batch_tokens, batch_sent_lens

    def train(self, sess, data, from_scratch=False, load_ckpt=None, save_path=None, log_every=400):
        """Train for ``self.n_epoch`` epochs with scheduled sampling.

        Args:
            sess: session holding this model's training graph (``build()`` done).
            data: ``(input_batches, target_batches)``, parallel lists of
                batches of raw sentences.
            from_scratch: if falsy and ``load_ckpt`` names an existing
                checkpoint, the model weights are restored from it after
                variable initialization.
            load_ckpt: checkpoint file path or prefix; may be None.
            save_path: if truthy, the model is saved there after training.
            log_every: print predictions every ``log_every`` epochs (positive
                int; the default of 400 matches the original behavior).

        Returns:
            List with one entry per epoch: the sum of that epoch's batch losses.
        """
        # Snapshot the model variables before the optimizer adds its own slot
        # variables, so only the model weights are restored from the checkpoint.
        model_variables = tf.global_variables()

        # Add the optimizer to the current graph.
        self.add_training_op()

        # Initialize first, then restore. Restoring before the initializer runs
        # would have the restored weights overwritten by fresh initial values.
        sess.run(tf.global_variables_initializer())
        if not from_scratch and self._checkpoint_exists(load_ckpt):
            self.restore(sess, var_list=model_variables, ckpt_path=load_ckpt)

        input_batches, target_batches = data
        loss_history = []

        for epoch in tqdm(range(self.n_epoch)):
            sampling_probability = self.sampling_probability_list[epoch]
            all_preds = []
            epoch_loss = 0

            for input_batch, target_batch in zip(input_batches, target_batches):
                input_batch_tokens, input_batch_sent_lens = self._sents_to_idx(input_batch)
                target_batch_tokens, target_batch_sent_lens = self._sents_to_idx(
                    target_batch,
                    vocab=dec_vocab,
                    max_sentence_length=dec_sentence_length,
                    is_target=True)

                # Evaluate 3 ops in the graph:
                # valid_predictions, loss, training_op (optimizer step).
                batch_valid_preds, batch_loss, _ = sess.run(
                    [self.valid_predictions, self.batch_loss, self.training_op],
                    feed_dict={
                        self.enc_inputs: input_batch_tokens,
                        self.enc_sequence_length: input_batch_sent_lens,
                        self.dec_inputs: target_batch_tokens,
                        self.dec_sequence_length: target_batch_sent_lens,
                        self.sampling_probability: sampling_probability,
                    })

                epoch_loss += batch_loss
                all_preds.append(batch_valid_preds)

            loss_history.append(epoch_loss)

            if epoch % log_every == 0:
                print('Epoch', epoch)
                print(f'Sampling probability: {sampling_probability:.3f}')
                for input_batch, target_batch, batch_preds in zip(input_batches, target_batches, all_preds):
                    for input_sent, target_sent, pred in zip(input_batch, target_batch, batch_preds):
                        print(f'\tInput: {input_sent}')
                        print('\tPrediction:', idx2sent(pred, reverse_vocab=dec_reverse_vocab))
                        print(f'\tTarget: {target_sent}\n')
                print(f'\tepoch loss: {epoch_loss:.2f}\n')

        if save_path:
            self.save(sess, save_path=save_path)

        return loss_history

    def inference(self, sess, data, load_ckpt):
        """Restore ``load_ckpt`` and print greedy predictions for one batch.

        Args:
            sess: session holding this model's inference graph.
            data: ``(input_batch, target_batch)``, parallel lists of sentences.
                Targets are only printed for comparison.
            load_ckpt: checkpoint path passed to ``restore``.
        """
        self.restore(sess, ckpt_path=load_ckpt)

        input_batch, target_batch = data
        batch_tokens, batch_sent_lens = self._sents_to_idx(input_batch)

        batch_preds = sess.run(
            self.predictions,
            feed_dict={
                self.enc_inputs: batch_tokens,
                self.enc_sequence_length: batch_sent_lens,
            })

        for input_sent, target_sent, pred in zip(input_batch, target_batch, batch_preds):
            print('Input:', input_sent)
            print('Prediction:', idx2sent(pred, reverse_vocab=dec_reverse_vocab))
            print('Target:', target_sent, '\n')