TensorFlow: character-level recurrent neural network (1)

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.models.rnn import rnn, rnn_cell
from tensorflow.models.rnn import seq2seq

import collections

# Character-level vocabulary setup, after @karpathy's min-char-rnn.
# Read the whole text; the context manager closes the file handle instead of
# leaving it open until the interpreter happens to collect it.
with open('ThreeMusketeers.txt') as corpus_file:
    data = corpus_file.read()

# NOTE(review): set iteration order is not guaranteed to be identical across
# interpreters/runs, and the index mapping built here has to match the one in
# effect when the restored checkpoint was trained -- confirm before touching
# this line (e.g. switching to sorted()).
chars = list(set(data))
data_size, vocab_size = len(data), len(chars)
print('data has %d characters, %d unique.' % (data_size, vocab_size))
char_to_ix = {ch: i for i, ch in enumerate(chars)}
ix_to_char = {i: ch for i, ch in enumerate(chars)}

# (character, count) pairs, most frequent first.
counter = collections.Counter(data)
counter = sorted(counter.items(), key=lambda x: -x[1])
# Show up to five of the most frequent characters; slicing (rather than
# indexing 0..4) avoids an IndexError on texts with fewer than five
# distinct characters.
for entry in counter[:5]:
    print(entry)

# The whole text as a list of vocabulary indices.
corpus = [char_to_ix[c] for c in data]

# The graph below is built for one sequence of one character per run:
# both placeholders have shape [batch_size, seq_length] = [1, 1].
batch_size = 1
seq_length = 1


# Width of each LSTM layer (also used as the embedding width below) and the
# number of stacked layers.
# NOTE(review): presumably these must equal the values used when the
# checkpoint was trained -- confirm against the training script.
hidden_size = 128
num_layers = 2
# Not referenced anywhere else in the code visible in this file.
max_grad_norm = 5.0

an_lstm = rnn_cell.BasicLSTMCell(hidden_size)
# The same cell object is repeated num_layers times to form the stack.
multi_lstm = rnn_cell.MultiRNNCell([an_lstm] * num_layers)
# x holds input character indices; y has the same shape but is not used by
# any other code visible in this file.
x = tf.placeholder(tf.int32, [batch_size, seq_length])
y = tf.placeholder(tf.int32, [batch_size, seq_length])
# Zero initial state; the sampling code feeds its own value for this tensor
# so the recurrent state carries over from one run to the next.
init_state = multi_lstm.zero_state(batch_size, tf.float32)

# Output-projection and embedding variables, all created under the 'rnn'
# variable scope.
with tf.variable_scope('rnn'):
    # Projection from an LSTM output (hidden_size) to vocabulary scores.
    softmax_w = tf.get_variable('softmax_w', [hidden_size, vocab_size])
    softmax_b = tf.get_variable('softmax_b', [vocab_size])
    # The embedding table and the lookup are pinned to the CPU.
    with tf.device('/cpu:0'):
        embedding = tf.get_variable('embedding', [vocab_size, hidden_size])
        # Look up one embedding row per index in x:
        # [batch_size, seq_length, hidden_size].
        inputs = tf.nn.embedding_lookup(embedding, x)
        # Split along the time axis (dim 1) into seq_length pieces ...
        inputs = tf.split(1, seq_length, inputs)
        # ... and drop that axis from each piece, giving a list of
        # seq_length tensors of shape [batch_size, hidden_size].
        inputs = [tf.squeeze(input_, [1]) for input_ in inputs]

# Disabled loop function, kept for reference: it would project the previous
# output to vocabulary scores, take the argmax symbol (with gradients
# stopped) and return that symbol's embedding. The decoder below is built
# with loop_function=None, so this is not used.
#def loop(prev):
#    prev = tf.nn.xw_plus_b(prev, softmax_w, softmax_b)
#    prev_symbol = tf.stop_gradient(tf.arg_max(prev, 1))
#    return tf.nn.embedding_lookup(embedding, prev_symbol)

# Run the stacked LSTM over the per-step inputs, starting from init_state.
outputs, last_state = seq2seq.rnn_decoder(inputs, init_state, 
                                          multi_lstm, 
                                          loop_function=None, 
                                          scope='rnn')
# outputs is a list of 2-D tensors, each of shape [batch_size, hidden_size];
# len(outputs) is seq_length.

# First, concatenate the per-step outputs of each sequence along dim 1.
out_conca = tf.concat(1, outputs) # [batch_size, hidden_size*seq_length]
# Second, reshape so the last dimension is hidden_size again; this is the
# shape the fully connected softmax layer below expects.
# [batch_size*seq_length, hidden_size]
output = tf.reshape(out_conca, [-1, hidden_size])
# Unnormalised scores over the vocabulary.
# [batch_size*seq_length, vocab_size]
score = tf.nn.xw_plus_b(output, softmax_w, softmax_b)
# Scores normalised into a probability distribution over the vocabulary.
# [batch_size*seq_length, vocab_size]
probs = tf.nn.softmax(score)


# Create the session and give every variable an initial value; the restore
# below then overwrites them with the saved values.
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# None of the four settings below is read by the code visible in this file
# (the graph above was already built with batch_size = 1).
epoch = 20
batch_size = 100
snapshot = 5
save_step = 1
saver = tf.train.Saver()
# get_checkpoint_state returns None when the directory holds no checkpoint
# state; fail with an explicit message instead of an AttributeError on None.
ckpt = tf.train.get_checkpoint_state('net_snapshot/')
if ckpt is None or not ckpt.model_checkpoint_path:
    raise IOError('no checkpoint found in net_snapshot/')
saver.restore(sess, ckpt.model_checkpoint_path)
print(ckpt.model_checkpoint_path)

def weighted_pick(weights):
    """Draw one index at random, with probability proportional to its weight.

    Args:
        weights: sequence (or array) of non-negative numbers; they do not
            have to sum to 1.

    Returns:
        A Python int in ``[0, len(weights) - 1]``.

    Raises:
        ValueError: if ``weights`` is empty or its total is not positive.
    """
    # Accumulate in float64 so low-precision inputs do not lose resolution.
    t = np.cumsum(weights, dtype=np.float64)
    if t.size == 0:
        raise ValueError('weights must not be empty')
    # Use the last cumulative value as the total: a separately computed
    # np.sum can differ from it by rounding and push the draw past the end.
    total = t[-1]
    if not total > 0:
        raise ValueError('weights must have a positive total')
    # side='right' maps a draw that lands exactly on a boundary to the next
    # bucket, so zero-weight entries are never selected.
    idx = np.searchsorted(t, np.random.rand() * total, side='right')
    # Guard against the draw rounding up to the total itself.
    return int(min(idx, t.size - 1))

# Warm up the recurrent state on the priming text before sampling.
prime = 'The'
state = sess.run(multi_lstm.zero_state(1, tf.float32))
# Feed every priming character except the last one: the sampling loop starts
# by feeding prime[-1] itself, so including it here would feed it twice
# (and iterating over prime[-1] alone would skip the rest of the prime).
# A character that is not in the vocabulary raises KeyError.
for c in prime[:-1]:
    ix = np.zeros((1, 1))
    ix[0, 0] = char_to_ix[c]
    # Only the new state is needed here; the predictions are ignored.
    state = sess.run(last_state, feed_dict={x: ix, init_state: state})

def char_filter(pred):
    """Return ``pred`` if it is an accepted character, otherwise ``''``.

    Accepted are strings that compare within ``'a'..'z'`` or ``'A'..'Z'``
    and the punctuation marks listed below. The range checks are ordinary
    string comparisons, so the function is meant for single characters.
    """
    allowed_marks = ('!', '.', '?', '\"', ' ', ',', '%')
    is_lower = 'a' <= pred <= 'z'
    is_upper = 'A' <= pred <= 'Z'
    if is_lower or is_upper or pred in allowed_marks:
        return pred
    return ''

# Generate text one character at a time, carrying the LSTM state forward.
ret = prime
char = prime[-1]
num = 1000
for n in xrange(num):
    # One-step input: a [1, 1] array holding the current character's index.
    ix = np.zeros((1, 1))
    ix[0, 0] = char_to_ix[char]
    fetches = [probs, last_state]
    feed = {x: ix, init_state: state}
    probsval, state = sess.run(fetches, feed_dict=feed)

    # Coin flip: take the most likely character half of the time, otherwise
    # sample from the predicted distribution.
    greedy = np.random.rand(1) > 0.5
    sample = np.argmax(probsval[0]) if greedy else weighted_pick(probsval[0])

    pred = ix_to_char[sample]

    # char_filter returns '' for rejected characters; in that case the text
    # is unchanged and the same character is fed again on the next step.
    ret = ret + char_filter(pred)
    char = ret[-1]

print(ret)

Sampled characters (output of the script above):

The and the caral there turrid seing, acting, and I scapter the cardy and this tow in 
the carmant of the cardantg wit. The sesper the carmisted the camplessing unce you will 
the cardow shanded the entereled the haved is thuld a conturute of you wnocch forlack 
the dal would the cardersty a saiking the contter the man the this lackul not to the 
carster which he wat madame you was abver D he womple lead me the conleshs of the 
cardinous of the closed there obe in his had past, acking that time the cardinal same 
trrang the cant of the cat of the sir thim the deave, and the clome the ca in aftented 
of throse one and all be, and proming him are to they undersince, she dy hering thought 
o man the conter to musked the cardinal be at the porsess of a seet the for the 
cardinan and as the omemen exter the cardinal and to the wall take proming this, and 
the cands, and then? Tetter. Tno tay will see to see thougle. The chaned them then for 
a sra

References:
https://github.com/sjchoi86/tensorflow-101/blob/master/notebooks/char_rnn_train_tutorial.ipynb
https://github.com/sherjilozair/char-rnn-tensorflow

你可能感兴趣的:(tensorflow)