1. The model is a Seq2Seq model:
Input: a (0,1) sequence, e.g. x = (1,1,1,0,0,0,1,1)
Label: the input sequence shifted right by a fixed number of positions, e.g. shifting x right by 2 gives y = (0,0,1,1,1,0,0,0)
2. Given a (0,1) sequence, the model predicts the sequence obtained by shifting it right by that many positions (see the short numpy sketch after this list).
3. The model is a deep LSTM: three LSTM layers stacked together.
4. The model uses dynamic_rnn to implement the dynamic LSTM.
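To make the echo relationship concrete, here is a minimal numpy sketch of the 2-position example above (the model code below uses echo_step = 3 instead):

import numpy as np

x = np.array([1, 1, 1, 0, 0, 0, 1, 1])
y = np.roll(x, 2)   # rotate the sequence 2 positions to the right
y[:2] = 0           # the first 2 positions have no history, so zero them out
print(x)            # [1 1 1 0 0 0 1 1]
print(y)            # [0 0 1 1 1 0 0 0]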
import numpy as np
import tensorflow as tf
Parameters
num_epochs = 100
total_series_length = 50000
truncated_backprop_length = 15
state_size = 4
num_classes = 2
echo_step = 3
batch_size = 5
num_batches = total_series_length//batch_size//truncated_backprop_length
num_layers = 3
Data generation function
def generateData():
    x = np.array(np.random.choice(2, total_series_length, p=[0.5, 0.5]))
    y = np.roll(x, echo_step)        # shift the whole series right by echo_step
    y[:echo_step] = 0                # positions with no history are zeroed
    x = x.reshape((batch_size, -1))  # (batch_size, total_series_length // batch_size)
    y = y.reshape((batch_size, -1))
    return x, y
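As a quick sanity check (a sketch, assuming the parameters defined above), the generated arrays and the number of truncated-BPTT windows per epoch look like this:

x, y = generateData()
print(x.shape, y.shape)   # (5, 10000) (5, 10000): batch_size rows, total_series_length // batch_size columns
print(num_batches)        # 50000 // 5 // 15 = 666 windows of width truncated_backprop_length per epoch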
Define the computation graph
X = tf.placeholder(tf.float32,[batch_size,truncated_backprop_length])
Y = tf.placeholder(tf.int32,[batch_size,truncated_backprop_length])
# 2: one slot for the cell state c and one for the hidden state h of each layer
init_state = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
# Split init_state into per-layer pieces
state_per_layer_list = tf.unstack(init_state, axis=0)
# Pack the per-layer states into a tuple of LSTMStateTuple(c, h)
rnn_tuple_state = tuple(
    [tf.nn.rnn_cell.LSTMStateTuple(state_per_layer_list[idx][0], state_per_layer_list[idx][1])
     for idx in range(num_layers)])
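# rnn_tuple_state is now a tuple of num_layers LSTMStateTuple(c, h) pairs, each tensor
# of shape (batch_size, state_size); this is exactly the structure MultiRNNCell expects
# as its initial_state (and what it returns as its final state).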
W = tf.Variable(np.random.rand(state_size, num_classes),dtype=tf.float32)
b = tf.Variable(np.zeros((1,num_classes)), dtype=tf.float32)
# Build an LSTMCell whose output dimension is given by state_size
def get_lstm_cell():
    return tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
# A list containing num_layers cells
cells = [get_lstm_cell() for _ in range(num_layers)]
# Stack the LSTMCells together to serve as the basic cell
stacked_cell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)
# Since the cell is 3 LSTMCells stacked together, the initial state is a tuple of three LSTMStateTuples
# The unrolled length of this stacked LSTM is determined by the time dimension of X
# tf.nn.dynamic_rnn returns outputs as a Tensor, whereas tf.nn.static_rnn returns a list
# current_state is a tuple of length 3
outputs, current_state = tf.nn.dynamic_rnn(stacked_cell, tf.expand_dims(X, -1), initial_state=rnn_tuple_state)
# Before reshape, outputs: (batch_size, truncated_backprop_length, state_size) == (5, 15, 4)
# After reshape, outputs: (batch_size*truncated_backprop_length, state_size) == (75, 4)
outputs = tf.reshape(outputs, [-1, state_size])
# logits: (batch_size*truncated_backprop_length, num_classes) == (75, 2)
logits = tf.matmul(outputs, W) + b  # broadcasted addition
labels = tf.reshape(Y, [-1])
# Reshape logits to (batch_size, truncated_backprop_length, num_classes) and unstack along axis=1
logits_series = tf.unstack(tf.reshape(logits, [batch_size, truncated_backprop_length, num_classes]), axis=1)
predictions_series = [tf.nn.softmax(logit) for logit in logits_series]
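# predictions_series is a list of truncated_backprop_length tensors, each of shape
# (batch_size, num_classes) == (5, 2): the per-timestep softmax over the two classes.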
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
total_loss = tf.reduce_mean(losses)
train_step = tf.train.AdagradOptimizer(0.3).minimize(total_loss)
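The listing above only tracks the loss. Not part of the original graph, but a per-batch accuracy op could be added from the same logits and labels, roughly as follows:

# Hypothetical addition: fraction of correctly echoed bits in the current batch
correct = tf.equal(tf.cast(tf.argmax(logits, axis=1), tf.int32), labels)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))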
Training
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    loss_list = []
    for epoch_idx in range(num_epochs):
        x, y = generateData()
        # Zero the LSTM state at the start of every epoch
        _current_state = np.zeros((num_layers, 2, batch_size, state_size))
        print("New data, epoch", epoch_idx)
        for batch_idx in range(num_batches):
            start_idx = batch_idx * truncated_backprop_length
            end_idx = start_idx + truncated_backprop_length
            batchX = x[:, start_idx:end_idx]
            batchY = y[:, start_idx:end_idx]
            # Feed the previous window's final state back in, so the LSTM state is
            # carried across the truncated-BPTT windows of the same series
            _total_loss, _train_step, _current_state, _predictions_series = sess.run(
                [total_loss, train_step, current_state, predictions_series],
                feed_dict={
                    X: batchX,
                    Y: batchY,
                    init_state: _current_state
                })
            loss_list.append(_total_loss)
            if batch_idx % 100 == 0:
                print("Step", batch_idx, "Batch loss", _total_loss)
New data, epoch 0
Step 0 Batch loss 0.6948361
Step 100 Batch loss 0.6896327
Step 200 Batch loss 0.6090469
Step 300 Batch loss 0.4460393
Step 400 Batch loss 0.5721256
Step 500 Batch loss 0.14829773
Step 600 Batch loss 0.0068182033
......
New data, epoch 99
Step 0 Batch loss 0.29179895
Step 100 Batch loss 3.272678e-05
Step 200 Batch loss 3.4703935e-05
Step 300 Batch loss 1.8189417e-05
Step 400 Batch loss 1.6067552e-05
Step 500 Batch loss 1.7221488e-05
Step 600 Batch loss 1.9481524e-05
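loss_list is collected during training but never visualized; a minimal sketch for plotting it, assuming matplotlib is installed:

import matplotlib.pyplot as plt

plt.plot(loss_list)                      # one point per truncated-BPTT window
plt.xlabel("training step")
plt.ylabel("batch cross-entropy loss")
plt.show()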