1. rnn_decoder
import tensorflow as tf
# TF 1.x contrib APIs used throughout this section
from tensorflow.contrib.rnn import EmbeddingWrapper, OutputProjectionWrapper
from tensorflow.contrib.legacy_seq2seq import (rnn_decoder,
                                               embedding_rnn_decoder,
                                               embedding_rnn_seq2seq)


def _extract_argmax_and_embed(embedding,
                              output_projection=None,
                              update_embedding=True):
    """Get a loop_function that extracts the previous symbol and embeds it.

    Args:
        embedding: embedding tensor for symbols.
        output_projection: None or a pair (W, B). If provided, each fed previous
            output will first be multiplied by W and added B.
        update_embedding: Boolean; if False, the gradients will not propagate
            through the embeddings.

    Returns:
        A loop function.
    """
    def loop_function(prev, _):
        if output_projection is not None:
            prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1])
        prev_symbol = tf.argmax(prev, 1)
        # Note that gradients will not propagate through the second parameter of
        # embedding_lookup.
        emb_prev = tf.nn.embedding_lookup(embedding, prev_symbol)
        if not update_embedding:
            emb_prev = tf.stop_gradient(emb_prev)
        return emb_prev

    return loop_function
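To see what the loop function does at a single decode step, here is a minimal
sketch with hypothetical toy sizes (separate from the real graph below): it feeds
fake logits for the "previous" step and returns the embedding row of the argmax symbol.

import numpy as np

toy_vocab, toy_emb_size = 5, 3  # hypothetical toy sizes, not the real hyperparameters
toy_embedding = tf.constant(np.eye(toy_vocab, toy_emb_size), dtype=tf.float32)
loop_fn = _extract_argmax_and_embed(toy_embedding)

fake_logits = tf.constant([[0.1, 2.0, 0.3, 0.0, 0.0]])  # argmax -> symbol 1
emb_prev = loop_fn(fake_logits, 0)  # the second argument (time step) is unused here

with tf.Session() as sess:
    print(sess.run(emb_prev))  # row 1 of toy_embedding: [[0. 1. 0.]]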
tf.reset_default_graph()

enc_inputs = tf.placeholder(
    tf.int32,
    shape=[None, enc_sentence_length],
    name='input_sentences')
sequence_lengths = tf.placeholder(
    tf.int32,
    shape=[None],
    name='sentences_length')
dec_inputs = tf.placeholder(
    tf.int32,
    shape=[None, dec_sentence_length+1],
    name='output_sentences')

# batch_major => time_major
enc_inputs_t = tf.transpose(enc_inputs, [1, 0])
dec_inputs_t = tf.transpose(dec_inputs, [1, 0])

with tf.device('/cpu:0'):
    dec_Wemb = tf.get_variable(
        'dec_word_emb',
        initializer=tf.random_uniform([dec_vocab_size+2, dec_emb_size]))

with tf.variable_scope('encoder'):
    enc_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
    enc_cell = EmbeddingWrapper(enc_cell, enc_vocab_size+1, enc_emb_size)

    # enc_outputs: [enc_sentence_length x batch_size x hidden_size]
    enc_outputs, enc_last_state = tf.contrib.rnn.static_rnn(
        cell=enc_cell,
        inputs=tf.unstack(enc_inputs_t),
        sequence_length=sequence_lengths,
        dtype=tf.float32)

dec_outputs = []
dec_predictions = []

with tf.variable_scope('decoder'):
    dec_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
    dec_cell = OutputProjectionWrapper(dec_cell, dec_vocab_size+2)

    # Combining EmbeddingWrapper with tf.unstack(dec_inputs_t) raises a dimension
    # error, so the decoder inputs are embedded manually with embedding_lookup.
    dec_emb_inputs = tf.nn.embedding_lookup(dec_Wemb, dec_inputs_t)

    # dec_outputs: [dec_sentence_length+1 x batch_size x dec_vocab_size+2]
    dec_outputs, dec_last_state = rnn_decoder(
        decoder_inputs=tf.unstack(dec_emb_inputs),
        initial_state=enc_last_state,
        cell=dec_cell,
        loop_function=_extract_argmax_and_embed(dec_Wemb))

# predictions: [batch_size x dec_sentence_length+1]
predictions = tf.transpose(tf.argmax(tf.stack(dec_outputs), axis=-1), [1, 0])

# labels & logits: [dec_sentence_length+1 x batch_size x dec_vocab_size+2]
labels = tf.one_hot(dec_inputs_t, dec_vocab_size+2)
logits = tf.stack(dec_outputs)

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
    labels=labels, logits=logits))

# training_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)
training_op = tf.train.RMSPropOptimizer(learning_rate=0.0001).minimize(loss)
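The graph above only defines the ops; a minimal training sketch looks like the
following, where input_batch, input_lens, target_batch, and n_epoch are
hypothetical names for data and settings prepared elsewhere in the notebook.

n_epoch = 2000  # hypothetical
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(n_epoch):
        _, epoch_loss, epoch_preds = sess.run(
            [training_op, loss, predictions],
            feed_dict={
                enc_inputs: input_batch,       # [batch_size, enc_sentence_length]
                sequence_lengths: input_lens,  # [batch_size]
                dec_inputs: target_batch,      # [batch_size, dec_sentence_length+1]
            })
        if epoch % 200 == 0:
            print('epoch', epoch, 'loss', epoch_loss)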
2. embedding_rnn_decoder

tf.reset_default_graph()

enc_inputs = tf.placeholder(
    tf.int32,
    shape=[None, enc_sentence_length],
    name='input_sentences')
sequence_lengths = tf.placeholder(
    tf.int32,
    shape=[None],
    name='sentences_length')
dec_inputs = tf.placeholder(
    tf.int32,
    shape=[None, dec_sentence_length+1],
    name='output_sentences')

# batch_major => time_major
enc_inputs_t = tf.transpose(enc_inputs, [1, 0])
dec_inputs_t = tf.transpose(dec_inputs, [1, 0])

with tf.variable_scope('encoder'):
    enc_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
    enc_cell = EmbeddingWrapper(enc_cell, enc_vocab_size+1, enc_emb_size)

    # enc_outputs: [enc_sentence_length x batch_size x hidden_size]
    enc_outputs, enc_last_state = tf.contrib.rnn.static_rnn(
        cell=enc_cell,
        inputs=tf.unstack(enc_inputs_t),
        sequence_length=sequence_lengths,
        dtype=tf.float32)

dec_outputs = []
dec_predictions = []

with tf.variable_scope('decoder'):
    dec_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
    dec_cell = OutputProjectionWrapper(dec_cell, dec_vocab_size+2)

    # dec_outputs: [dec_sentence_length+1 x batch_size x dec_vocab_size+2]
    dec_outputs, dec_last_state = embedding_rnn_decoder(
        decoder_inputs=tf.unstack(dec_inputs_t),
        initial_state=enc_last_state,
        cell=dec_cell,
        num_symbols=dec_vocab_size+2,
        embedding_size=dec_emb_size,
        feed_previous=True)

# predictions: [batch_size x dec_sentence_length+1]
predictions = tf.transpose(tf.argmax(tf.stack(dec_outputs), axis=-1), [1, 0])

# labels & logits: [dec_sentence_length+1 x batch_size x dec_vocab_size+2]
labels = tf.one_hot(dec_inputs_t, dec_vocab_size+2)
logits = tf.stack(dec_outputs)

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
    labels=labels, logits=logits))

# training_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)
training_op = tf.train.RMSPropOptimizer(learning_rate=0.0001).minimize(loss)
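With feed_previous=True, only the first decoder input (the GO symbol) is consumed;
later steps feed back the decoder's own argmax predictions. Greedy decoding at
inference time can therefore feed dummy targets, as in this sketch. GO_ID, idx2dec,
input_batch, input_lens, and the trained sess are hypothetical names assumed from
the surrounding notebook.

import numpy as np

dummy_targets = np.full((input_batch.shape[0], dec_sentence_length + 1),
                        GO_ID, dtype=np.int32)
pred_ids = sess.run(predictions, feed_dict={
    enc_inputs: input_batch,       # [batch_size, enc_sentence_length]
    sequence_lengths: input_lens,  # [batch_size]
    dec_inputs: dummy_targets,     # only position 0 (GO) is used with feed_previous=True
})
decoded = [[idx2dec[i] for i in row] for row in pred_ids]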
3. embedding_rnn_seq2seq

tf.reset_default_graph()

enc_inputs = tf.placeholder(
    tf.int32,
    shape=[None, enc_sentence_length],
    name='input_sentences')
sequence_lengths = tf.placeholder(
    tf.int32,
    shape=[None],
    name='sentences_length')
dec_inputs = tf.placeholder(
    tf.int32,
    shape=[None, dec_sentence_length+1],
    name='output_sentences')

# batch_major => time_major
enc_inputs_t = tf.transpose(enc_inputs, [1, 0])
dec_inputs_t = tf.transpose(dec_inputs, [1, 0])

rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)

with tf.variable_scope("embedding_rnn_seq2seq"):
    # dec_outputs: [dec_sentence_length+1 x batch_size x dec_vocab_size+2]
    dec_outputs, dec_last_state = embedding_rnn_seq2seq(
        encoder_inputs=tf.unstack(enc_inputs_t),
        decoder_inputs=tf.unstack(dec_inputs_t),
        cell=rnn_cell,
        num_encoder_symbols=enc_vocab_size+1,
        num_decoder_symbols=dec_vocab_size+2,
        embedding_size=enc_emb_size,
        feed_previous=True)

# predictions: [batch_size x dec_sentence_length+1]
predictions = tf.transpose(tf.argmax(tf.stack(dec_outputs), axis=-1), [1, 0])

# labels & logits: [dec_sentence_length+1 x batch_size x dec_vocab_size+2]
labels = tf.one_hot(dec_inputs_t, dec_vocab_size+2)
logits = tf.stack(dec_outputs)

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
    labels=labels, logits=logits))

# training_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)
training_op = tf.train.RMSPropOptimizer(learning_rate=0.0001).minimize(loss)
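A quick sanity check for this graph is to run it once on random integer ids and
inspect the tensor shapes; batch_size below is a hypothetical value, while the
other sizes come from the hyperparameters defined earlier.

import numpy as np

batch_size = 4  # hypothetical
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    pred_vals, logit_vals = sess.run(
        [predictions, logits],
        feed_dict={
            enc_inputs: np.random.randint(0, enc_vocab_size + 1,
                                          size=(batch_size, enc_sentence_length),
                                          dtype=np.int32),
            dec_inputs: np.random.randint(0, dec_vocab_size + 2,
                                          size=(batch_size, dec_sentence_length + 1),
                                          dtype=np.int32),
        })
    print(pred_vals.shape)   # (batch_size, dec_sentence_length + 1)
    print(logit_vals.shape)  # (dec_sentence_length + 1, batch_size, dec_vocab_size + 2)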