rnn decoder

1.rnn_decoder

def _extract_argmax_and_embed(embedding,
                              output_projection=None,
                              update_embedding=True):
  """Get a loop_function that extracts the previous symbol and embeds it.
  Args:
    embedding: embedding tensor for symbols.
    output_projection: None or a pair (W, B). If provided, each fed previous
      output will first be multiplied by W and added B.
    update_embedding: Boolean; if False, the gradients will not propagate
      through the embeddings.
  Returns:
    A loop function.
  """

  def loop_function(prev, _):
    if output_projection is not None:
      prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1])
    prev_symbol = tf.argmax(prev, 1)
    # Note that gradients will not propagate through the second parameter of
    # embedding_lookup.
    emb_prev = tf.nn.embedding_lookup(embedding, prev_symbol)
    if not update_embedding:
      emb_prev = tf.stop_gradient(emb_prev)
    return emb_prev

  return loop_function
# Section 1: seq2seq graph built on `rnn_decoder`.
# - Decoder embedding: a hand-made matrix, `dec_Wemb`.
# - Loop function: `_extract_argmax_and_embed`.
# Relies on names defined elsewhere:
# - hyperparameters: enc_sentence_length, dec_sentence_length,
#   enc/dec_vocab_size, enc/dec_emb_size, hidden_size;
# - rnn_decoder, EmbeddingWrapper and OutputProjectionWrapper.
# Start from an empty default graph so re-running this section does not
# collide with variables created by a previous run.
tf.reset_default_graph()

# Encoder token ids, batch-major: [batch_size x enc_sentence_length].
enc_inputs = tf.placeholder(
    tf.int32,
    shape=[None, enc_sentence_length],
    name='input_sentences')

# Per-example encoder lengths, [batch_size]; passed to static_rnn below.
sequence_lengths = tf.placeholder(
    tf.int32,
    shape=[None],
    name='sentences_length')

# Decoder token ids, batch-major: [batch_size x dec_sentence_length+1].
# The same tensor is used as decoder input and as training labels below.
dec_inputs = tf.placeholder(
    tf.int32,
    shape=[None, dec_sentence_length+1],
    name='output_sentences')

# batch_major => time_major
enc_inputs_t = tf.transpose(enc_inputs, [1,0])
dec_inputs_t = tf.transpose(dec_inputs, [1,0])

# Decoder embedding matrix [dec_vocab_size+2 x dec_emb_size], placed on CPU.
# The "+2" presumably reserves ids for special tokens (e.g. GO/EOS). Confirm
# against the vocabulary builder.
with tf.device('/cpu:0'):
    dec_Wemb = tf.get_variable('dec_word_emb',
        initializer=tf.random_uniform([dec_vocab_size+2, dec_emb_size]))
    
with tf.variable_scope('encoder'):
    enc_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
    # Embeds raw token ids inside the cell (vocabulary size enc_vocab_size+1).
    enc_cell = EmbeddingWrapper(enc_cell, enc_vocab_size+1, enc_emb_size)
    
    # enc_outputs: list of enc_sentence_length tensors. Each is presumably
    # [batch_size x hidden_size], since the wrapper only embeds the inputs.
    # Confirm against EmbeddingWrapper's definition.
    # enc_last_state seeds the decoder below.
    enc_outputs, enc_last_state = tf.contrib.rnn.static_rnn(
        cell=enc_cell,
        inputs=tf.unstack(enc_inputs_t),
        sequence_length=sequence_lengths,
        dtype=tf.float32)

# NOTE(review): both lists are dead here. dec_outputs is reassigned by
# rnn_decoder below, and dec_predictions is never used in this section.
dec_outputs = []
dec_predictions = []
with tf.variable_scope('decoder'):
    dec_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
    # Output size dec_vocab_size+2, matching the one-hot labels below.
    dec_cell = OutputProjectionWrapper(dec_cell, dec_vocab_size+2)
    
    # EmbeddingWrapper & tf.unstack(dec_inputs_t) raises dimension error,
    # so the decoder ids are embedded up front instead.
    # dec_emb_inputs: [dec_sentence_length+1 x batch_size x dec_emb_size].
    dec_emb_inputs = tf.nn.embedding_lookup(dec_Wemb, dec_inputs_t)
    
    # dec_outputs: list of dec_sentence_length+1 logits tensors, each
    # [batch_size x dec_vocab_size+2]; the loss below requires this size.
    # With a loop_function, rnn_decoder presumably does the following
    # (confirm against rnn_decoder's definition):
    # - uses only the first decoder input;
    # - afterwards feeds the embedded argmax of its own previous output.
    dec_outputs, dec_last_state = rnn_decoder(
        decoder_inputs=tf.unstack(dec_emb_inputs),
        initial_state=enc_last_state,
        cell=dec_cell,
        loop_function=_extract_argmax_and_embed(dec_Wemb))

# predictions: [batch_size x dec_sentence_length+1] greedy token ids
# (argmax over the vocabulary axis, transposed back to batch-major).
predictions = tf.transpose(tf.argmax(tf.stack(dec_outputs), axis=-1), [1,0])

# labels & logits: [dec_sentence_length+1 x batch_size x dec_vocab_size+2]
# NOTE(review): labels come from the same dec_inputs_t that feeds the decoder.
# They are not shifted by one step relative to the decoder inputs. Confirm the
# data pipeline accounts for this.
labels = tf.one_hot(dec_inputs_t, dec_vocab_size+2)
logits = tf.stack(dec_outputs)

# Mean per-token cross-entropy over all time steps and batch entries. There is
# no padding mask, so any padded positions also count toward the loss.
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
    labels=labels, logits=logits))

# Alternative optimizer, kept for reference:
# training_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)
training_op = tf.train.RMSPropOptimizer(learning_rate=0.0001).minimize(loss)


2.embedding_rnn_decoder

# Section 2: same encoder as before, but the decoder is `embedding_rnn_decoder`.
# It takes raw token ids and, with feed_previous=True, presumably feeds back
# its own predictions.
# Relies on names defined elsewhere:
# - the hyperparameters;
# - embedding_rnn_decoder, EmbeddingWrapper and OutputProjectionWrapper.
# Start from an empty default graph so this section does not collide with
# variables created by the previous one.
tf.reset_default_graph()

# Encoder token ids, batch-major: [batch_size x enc_sentence_length].
enc_inputs = tf.placeholder(
    tf.int32,
    shape=[None, enc_sentence_length],
    name='input_sentences')

# Per-example encoder lengths, [batch_size]; passed to static_rnn below.
sequence_lengths = tf.placeholder(
    tf.int32,
    shape=[None],
    name='sentences_length')

# Decoder token ids, batch-major: [batch_size x dec_sentence_length+1].
# The same tensor is used as decoder input and as training labels below.
dec_inputs = tf.placeholder(
    tf.int32,
    shape=[None, dec_sentence_length+1],
    name='output_sentences')

# batch_major => time_major
enc_inputs_t = tf.transpose(enc_inputs, [1,0])
dec_inputs_t = tf.transpose(dec_inputs, [1,0])
    
with tf.variable_scope('encoder'):
    enc_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
    # Embeds raw token ids inside the cell (vocabulary size enc_vocab_size+1).
    enc_cell = EmbeddingWrapper(enc_cell, enc_vocab_size+1, enc_emb_size)
    
    # enc_outputs: list of enc_sentence_length tensors. Each is presumably
    # [batch_size x hidden_size], since the wrapper only embeds the inputs.
    # Confirm against EmbeddingWrapper's definition.
    # enc_last_state seeds the decoder below.
    enc_outputs, enc_last_state = tf.contrib.rnn.static_rnn(
        cell=enc_cell,
        inputs=tf.unstack(enc_inputs_t),
        sequence_length=sequence_lengths,
        dtype=tf.float32)

# NOTE(review): both lists are dead here. dec_outputs is reassigned by
# embedding_rnn_decoder below, and dec_predictions is never used in this section.
dec_outputs = []
dec_predictions = []
with tf.variable_scope('decoder'):
    dec_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
    # Output size dec_vocab_size+2, matching the one-hot labels below.
    dec_cell = OutputProjectionWrapper(dec_cell, dec_vocab_size+2)
    
    # dec_outputs: list of dec_sentence_length+1 logits tensors, each
    # [batch_size x dec_vocab_size+2]; the loss below requires this size.
    # Token ids are passed directly. The decoder embedding (dec_vocab_size+2
    # symbols x dec_emb_size) is presumably created inside
    # embedding_rnn_decoder.
    # feed_previous=True presumably means the following (confirm against the
    # callee's definition):
    # - only the first decoder input is used;
    # - later steps consume the embedded argmax of the previous output.
    dec_outputs, dec_last_state = embedding_rnn_decoder(
        decoder_inputs=tf.unstack(dec_inputs_t),
        initial_state=enc_last_state,
        cell=dec_cell,
        num_symbols=dec_vocab_size+2,
        embedding_size=dec_emb_size,
        feed_previous=True)

        
# predictions: [batch_size x dec_sentence_length+1] greedy token ids
# (argmax over the vocabulary axis, transposed back to batch-major).
predictions = tf.transpose(tf.argmax(tf.stack(dec_outputs), axis=-1), [1,0])

# labels & logits: [dec_sentence_length+1 x batch_size x dec_vocab_size+2]
# NOTE(review): labels come from the same dec_inputs_t that feeds the decoder.
# They are not shifted by one step relative to the decoder inputs. Confirm the
# data pipeline accounts for this.
labels = tf.one_hot(dec_inputs_t, dec_vocab_size+2)
logits = tf.stack(dec_outputs)
        
# Mean per-token cross-entropy over all time steps and batch entries. There is
# no padding mask, so any padded positions also count toward the loss.
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
    labels=labels, logits=logits))

# Alternative optimizer, kept for reference:
# training_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)
training_op = tf.train.RMSPropOptimizer(learning_rate=0.0001).minimize(loss)

3.embedding_rnn_seq2seq

# Section 3: the whole encoder/decoder in one call to `embedding_rnn_seq2seq`
# (defined elsewhere). It receives raw token ids for both sides.
# Start from an empty default graph so this section does not collide with
# variables created by the previous ones.
tf.reset_default_graph()

# Encoder token ids, batch-major: [batch_size x enc_sentence_length].
enc_inputs = tf.placeholder(
    tf.int32,
    shape=[None, enc_sentence_length],
    name='input_sentences')

# NOTE(review): sequence_lengths is not passed to the model in this section,
# so the encoder presumably also runs over any padding. Confirm this is
# intended.
sequence_lengths = tf.placeholder(
    tf.int32,
    shape=[None],
    name='sentences_length')

# Decoder token ids, batch-major: [batch_size x dec_sentence_length+1].
# The same tensor is used as decoder input and as training labels below.
dec_inputs = tf.placeholder(
    tf.int32,
    shape=[None, dec_sentence_length+1],
    name='output_sentences')

# batch_major => time_major
enc_inputs_t = tf.transpose(enc_inputs, [1,0])
dec_inputs_t = tf.transpose(dec_inputs, [1,0])

# The only cell passed to embedding_rnn_seq2seq. Whether the encoder and
# decoder get separate copies depends on the callee; confirm.
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)

with tf.variable_scope("embedding_rnn_seq2seq"):
    # dec_outputs: list of dec_sentence_length+1 logits tensors, each
    # [batch_size x dec_vocab_size+2]; the loss below requires this size.
    # Only one embedding_size is given, so enc_emb_size presumably sizes both
    # the encoder and decoder embeddings. Confirm dec_emb_size is not intended.
    # feed_previous=True presumably makes the decoder consume its own greedy
    # predictions after the first step.
    dec_outputs, dec_last_state = embedding_rnn_seq2seq(
        encoder_inputs=tf.unstack(enc_inputs_t),
        decoder_inputs=tf.unstack(dec_inputs_t),
        cell=rnn_cell,
        num_encoder_symbols=enc_vocab_size+1,
        num_decoder_symbols=dec_vocab_size+2,
        embedding_size=enc_emb_size,
        feed_previous=True)

# predictions: [batch_size x dec_sentence_length+1] greedy token ids
# (argmax over the vocabulary axis, transposed back to batch-major).
predictions = tf.transpose(tf.argmax(tf.stack(dec_outputs), axis=-1), [1,0])

# labels & logits: [dec_sentence_length+1 x batch_size x dec_vocab_size+2]
# NOTE(review): labels come from the same dec_inputs_t that feeds the decoder.
# They are not shifted by one step relative to the decoder inputs. Confirm the
# data pipeline accounts for this.
labels = tf.one_hot(dec_inputs_t, dec_vocab_size+2)
logits = tf.stack(dec_outputs)
        
# Mean per-token cross-entropy over all time steps and batch entries. There is
# no padding mask, so any padded positions also count toward the loss.
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
    labels=labels, logits=logits))

# Alternative optimizer, kept for reference:
# training_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)
training_op = tf.train.RMSPropOptimizer(learning_rate=0.0001).minimize(loss)


你可能感兴趣的:(tensorflow)