(八)sequence to sequence —6

最后一关:

Encoder:多层双向lstm

Attention机制

Decoder:基于CustomHelper和dynamic_decode的动态解码(bidirectional_dynamic_rnn用于上面的Encoder)

基于tensorflow1.4 Seq2seq的实现

import helpers
import tensorflow as tf
from tensorflow.python.util import nest
from tensorflow.contrib import seq2seq,rnn

# Bare expression: only has a visible effect when run as a notebook cell,
# where it displays the installed TensorFlow version.
tf.__version__

# NOTE: the default graph is reset and the InteractiveSession is created
# further down, right before the model graph is built. An additional
# session used to be created here; it was never used or closed, and it was
# still active when the default graph was reset again, which TensorFlow
# documents as undefined behaviour.

# Reserved token ids: PAD fills unused positions, EOS marks the end of a
# target sequence (and is also fed as the first decoder input).
PAD = 0
EOS = 1


# Number of rows in the embedding matrix and width of the output projection.
vocab_size = 10
# Width of each embedding vector.
input_embedding_size = 20
# Units of each (forward / backward) encoder LSTM cell.
encoder_hidden_units = 25

decoder_hidden_units = encoder_hidden_units

import helpers as data_helpers
batch_size = 10

# Generator: every next() call produces one minibatch of random sequences.
batches = data_helpers.random_sequences(
    length_from=3, length_to=8,
    vocab_lower=2, vocab_upper=10,
    batch_size=batch_size,
)

print('产生%d个长度不一(最短3,最长8)的sequences, 其中前十个是:' % batch_size)
# Show at most ten sequences from one freshly drawn minibatch.
preview_count = min(batch_size, 10)
preview_batch = next(batches)
for seq in preview_batch[:preview_count]:
    print(seq)

# Start from an empty default graph and open the session used for training.
tf.reset_default_graph()
sess = tf.InteractiveSession()
mode = tf.contrib.learn.ModeKeys.TRAIN
产生10个长度不一(最短3,最长8)的sequences, 其中前十个是:
[8, 9, 2, 8, 5, 5]
[5, 5, 4, 5, 5, 3, 2]
[6, 8, 8, 9, 6, 2]
[3, 2, 8, 7, 7, 5]
[6, 2, 9, 3, 3, 8, 9]
[7, 6, 9]
[4, 4, 9, 4, 2, 4, 5]
[3, 6, 4, 3, 3]
[6, 6, 7, 7]
[3, 3, 9, 3, 7]

1.使用seq2seq库实现seq2seq模型

with tf.name_scope('minibatch'):
    # Token ids of the source sequences. The feed dict built later passes a
    # transposed array and prints row 0 as one sequence, so this is
    # batch-major: (batch, time).
    encoder_inputs = tf.placeholder(
        dtype=tf.int32, shape=(None, None), name='encoder_inputs')

    # Length of every source sequence in the batch.
    encoder_inputs_length = tf.placeholder(
        dtype=tf.int32, shape=(None,), name='encoder_inputs_length')

    # Token ids the decoder is trained to emit, same layout as encoder_inputs.
    decoder_targets = tf.placeholder(
        dtype=tf.int32, shape=(None, None), name='decoder_targets')

    # No decoder_inputs / decoder_inputs_length placeholders are defined:
    # the decoder inputs are produced step by step by the custom helper.

# Embedding matrix shared by the encoder and the decoder.
embedding = tf.get_variable(
    'embedding', shape=[vocab_size, input_embedding_size])
encoder_inputs_embedded = tf.nn.embedding_lookup(
    params=embedding, ids=encoder_inputs)

#fw_cell = bw_cell =  rnn.LSTMCell(encoder_hidden_units)

定义encoder,两层双向lstm

# Two stacked bidirectional LSTM layers. Each layer consumes the
# concatenated forward/backward outputs of the previous one.
_inputs = encoder_inputs_embedded
for _ in range(2):
    # A fresh, uniquely named variable scope per layer. Per the original
    # author's note, building the second layer without it raises an error
    # from the variable-scope check done when an RNN cell is called.
    with tf.variable_scope(None, default_name="bidirectional-rnn"):
        # Separate cell objects for the two directions. Passing one shared
        # LSTMCell instance as both cell_fw and cell_bw does not give the
        # forward and backward passes independent parameters.
        rnn_cell_fw = rnn.LSTMCell(encoder_hidden_units)
        rnn_cell_bw = rnn.LSTMCell(encoder_hidden_units)
        ((encoder_fw_outputs, encoder_bw_outputs),
         (encoder_fw_final_state, encoder_bw_final_state)) = \
            tf.nn.bidirectional_dynamic_rnn(
                cell_fw=rnn_cell_fw,
                cell_bw=rnn_cell_bw,
                inputs=_inputs,
                sequence_length=encoder_inputs_length,
                dtype=tf.float32)
        # Join both directions along the feature axis; this is the input of
        # the next layer (and, after the last layer, the encoder output).
        _inputs = tf.concat((encoder_fw_outputs, encoder_bw_outputs), 2)

# Final state of the last layer only: forward and backward c / h vectors are
# concatenated, giving a state twice as wide as one encoder cell.
encoder_final_state_h = tf.concat(
    (encoder_fw_final_state.h, encoder_bw_final_state.h), 1)
encoder_final_state_c = tf.concat(
    (encoder_fw_final_state.c, encoder_bw_final_state.c), 1)
encoder_final_state = rnn.LSTMStateTuple(
    c=encoder_final_state_c, h=encoder_final_state_h)
encoder_final_output = _inputs
encoder_final_state
LSTMStateTuple(c=, h=)
encoder_final_output

5.定义decoder 部分

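这里的Decoder中,每个输入除了上一个时间节点的输出以外,还有对应时间节点的Encoder的输出,以及attention的context。(注:在下面的代码中,拼接Encoder输出的那一行已被注释掉,helper实际给出的输入只有上一时刻预测结果的embedding;attention的context由AttentionWrapper在内部与输入拼接。)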

nct-seq2seq.png

常用的Helper:

TrainingHelper:适用于训练的helper。
InferenceHelper:适用于测试的helper。
GreedyEmbeddingHelper:适用于测试中采用Greedy策略sample的helper。
CustomHelper:用户自定义的helper。

这里着重介绍CustomHelper,要传入三个函数作为参数:
initialize_fn:返回finished,next_inputs。其中finished不是scalar(标量),而是一个一维向量。这个函数即获取第一个时间节点的输入。
sample_fn:接收参数(time, outputs, state) 返回sample_ids。即,根据每个cell的输出,如何sample。
next_inputs_fn:接收参数(time, outputs, state, sample_ids) 返回 (finished, next_inputs, next_state),根据上一个时刻的输出,决定下一个时刻的输入。

# The three callbacks handed to CustomHelper.
# Each target is decoded for 3 steps more than its source sequence is long.
decoder_lengths = encoder_inputs_length + 3
eos_time_slice = tf.ones([batch_size], dtype=tf.int32, name='EOS')
pad_time_slice = tf.zeros([batch_size], dtype=tf.int32, name='PAD')

# Embedded EOS / PAD ids for one decoding step of a whole batch.
eos_step_embedded = tf.nn.embedding_lookup(embedding, eos_time_slice)
pad_step_embedded = tf.nn.embedding_lookup(embedding, pad_time_slice)


def initial_fn():
    """Return ``(finished, inputs)`` for the first decoding step.

    ``finished`` is a per-sequence boolean vector; it is False wherever
    ``decoder_lengths`` is positive. The first input is the embedded EOS id.
    """
    initial_elements_finished = (0 >= decoder_lengths)
    initial_input = eos_step_embedded
    return initial_elements_finished, initial_input


def sample_fn(time, outputs, state):
    """Pick, for every sequence, the index of the largest output value.

    ``time`` and ``state`` are accepted to satisfy the callback signature and
    are not used.
    """
    prediction_id = tf.to_int32(tf.argmax(outputs, axis=1))
    return prediction_id


def next_inputs_fn(time, outputs, state, sample_ids):
    """Return ``(finished, next_inputs, next_state)`` for the next step.

    The next input is the embedding of the ids sampled at the current step.
    Once every sequence in the batch is finished, the embedded PAD id is fed
    instead. The state is passed through unchanged.
    """
    # Embed the class predicted at this step to feed it back in.
    pred_embedding = tf.nn.embedding_lookup(embedding, sample_ids)
    next_input = pred_embedding
    # Boolean vector with one entry per sequence.
    elements_finished = (time >= decoder_lengths)
    # Boolean scalar: True only when the whole batch is done.
    all_finished = tf.reduce_all(elements_finished)
    next_inputs = tf.cond(all_finished,
                          lambda: pad_step_embedded,
                          lambda: next_input)
    next_state = state
    return elements_finished, next_inputs, next_state


# Was indented by one stray space, which is an IndentationError at top level.
my_helper = tf.contrib.seq2seq.CustomHelper(initial_fn, sample_fn, next_inputs_fn)

定义Attention机制

Attention.png
# Bahdanau attention over the encoder outputs; source padding is masked out
# through memory_sequence_length.
attention_mechanism = seq2seq.BahdanauAttention(
    num_units=2 * encoder_hidden_units,
    memory=encoder_final_output,
    memory_sequence_length=encoder_inputs_length,
)

# The decoder LSTM is twice as wide as one encoder cell so that the
# concatenated (forward + backward) encoder state can seed it directly.
decoder_cell = rnn.LSTMCell(2 * encoder_hidden_units)
decoder_cell = seq2seq.AttentionWrapper(
    cell=decoder_cell,
    attention_mechanism=attention_mechanism,
    attention_layer_size=encoder_hidden_units,
)

# Projects every decoder output onto vocabulary logits.
projection_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.1)
output_layer = tf.layers.Dense(vocab_size, kernel_initializer=projection_init)

# Zero attention state, with the wrapped cell's state replaced by the
# encoder's final state.
zero_attention_state = decoder_cell.zero_state(
    batch_size=batch_size, dtype=tf.float32)
decoder_initial_state = zero_attention_state.clone(
    cell_state=encoder_final_state)

training_decoder = seq2seq.BasicDecoder(
    cell=decoder_cell,
    helper=my_helper,
    initial_state=decoder_initial_state,
    output_layer=output_layer,
)

# Decode for as many steps as the longest target in the batch.
max_target_sequence_length = tf.reduce_max(decoder_lengths, name='max_target_len')
decoder_outputs, _, _ = seq2seq.dynamic_decode(
    decoder=training_decoder,
    impute_finished=True,
    maximum_iterations=max_target_sequence_length,
)
decoder_logits_train = tf.identity(decoder_outputs.rnn_output)
sample_id = decoder_outputs.sample_id

定义计算loss的mask

# Per-position loss weights built from decoder_lengths: 1.0 inside each
# sequence's decoded length, 0.0 beyond it.
mask = tf.sequence_mask(
    lengths=decoder_lengths,
    maxlen=max_target_sequence_length,
    dtype=tf.float32,
    name='masks',
)

# Dump the tensors that enter the loss, for a quick shape / dtype check.
for _tensor in (decoder_logits_train, decoder_targets, sample_id):
    print('\t%s' % repr(_tensor))

loss = seq2seq.sequence_loss(
    logits=decoder_logits_train,
    targets=decoder_targets,
    weights=mask,
)

optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
train_op = optimizer.minimize(loss)
sess.run(tf.global_variables_initializer())
def next_feed():
    """Draw the next minibatch and return it as a feed dict.

    Targets are the source sequences followed by EOS and two PAD ids, i.e.
    3 tokens longer than the source, matching ``decoder_lengths``.

    Both arrays returned by ``data_helpers.batch`` are transposed before
    being fed (presumably the helper returns time-major data; confirm
    against the helpers module).
    """
    sequences = next(batches)

    source_ids, source_lengths = data_helpers.batch(sequences)

    padded_targets = [seq + [EOS] + [PAD] * 2 for seq in sequences]
    target_ids, _ = data_helpers.batch(padded_targets)

    # Feed dict keys may be the placeholder tensors themselves.
    feed = {}
    feed[encoder_inputs] = source_ids.T
    feed[decoder_targets] = target_ids.T
    feed[encoder_inputs_length] = source_lengths
    return feed
# Print the first example of one freshly drawn feed dict.
x = next_feed()
_first_example = (
    ('encoder_inputs:', x[encoder_inputs][0, :]),
    ('encoder_inputs_length:', x[encoder_inputs_length][0]),
    ('decoder_targets:', x[decoder_targets][0, :]),
)
for _label, _value in _first_example:
    print(_label)
    print(_value)
encoder_inputs:
[6 4 3 8 6 7 2 6]
encoder_inputs_length:
8
decoder_targets:
[6 4 3 8 6 7 2 6 1 0 0]
loss_track = []
max_batches = 3001
batches_in_epoch = 100

try:
    # One optimisation step per iteration; a report every batches_in_epoch
    # steps (step 0 included).
    for step in range(max_batches):
        feed = next_feed()
        _, step_loss = sess.run([train_op, loss], feed)
        loss_track.append(step_loss)

        if step % batches_in_epoch != 0:
            continue

        print('batch {}'.format(step))
        # Loss is re-evaluated on the same feed, after the update above.
        print('  minibatch loss: {}'.format(sess.run(loss, feed)))
        predict_ = sess.run(decoder_outputs.sample_id, feed)
        # Show the first three (input, prediction) pairs of the batch.
        shown = 0
        for inp, pred in zip(feed[encoder_inputs], predict_):
            shown += 1
            print('  sample {}:'.format(shown))
            print('    input     > {}'.format(inp))
            print('    predicted > {}'.format(pred))
            if shown == 3:
                break
        print()

except KeyboardInterrupt:
    print('training interrupted')
batch 0
  minibatch loss: 2.29461669921875
  sample 1:
    input     > [5 5 9 4 6 4 4 5]
    predicted > [3 3 2 3 2 1 0 0 1 0 0]
  sample 2:
    input     > [5 9 7 4 6 7 9 3]
    predicted > [0 0 0 0 0 1 0 0 0 1 0]
  sample 3:
    input     > [6 6 4 2 8 0 0 0]
    predicted > [0 1 0 1 0 0 1 0 1 0 0]

batch 100
  minibatch loss: 1.6221174001693726
  sample 1:
    input     > [3 4 5 2 0 0 0 0]
    predicted > [3 3 3 1 0 0 0 0 0 0 0]
  sample 2:
    input     > [3 8 4 9 6 0 0 0]
    predicted > [3 3 3 3 1 0 0 0 0 0 0]
  sample 3:
    input     > [6 4 2 8 0 0 0 0]
    predicted > [3 3 3 1 0 0 0 0 0 0 0]

batch 200
  minibatch loss: 1.388022780418396
  sample 1:
    input     > [9 5 3 3 4 3 0 0]
    predicted > [3 3 3 3 3 3 1 1 0 0 0]
  sample 2:
    input     > [9 6 2 0 0 0 0 0]
    predicted > [6 2 6 1 0 0 0 0 0 0 0]
  sample 3:
    input     > [7 8 4 9 6 8 0 0]
    predicted > [3 3 3 3 3 4 1 0 0 0 0]

batch 300
  minibatch loss: 1.270598292350769
  sample 1:
    input     > [2 8 7 0 0 0 0 0]
    predicted > [7 7 7 1 0 0 0 0 0 0 0]
  sample 2:
    input     > [4 7 3 0 0 0 0 0]
    predicted > [3 7 1 1 0 0 0 0 0 0 0]
  sample 3:
    input     > [9 3 8 0 0 0 0 0]
    predicted > [3 8 8 1 0 0 0 0 0 0 0]

batch 400
  minibatch loss: 1.1839914321899414
  sample 1:
    input     > [2 7 2 3 6 3 8 5]
    predicted > [2 2 2 2 2 8 8 8 1 0 0]
  sample 2:
    input     > [7 6 7 7 0 0 0 0]
    predicted > [7 7 7 7 1 0 0 0 0 0 0]
  sample 3:
    input     > [2 3 7 2 3 0 0 0]
    predicted > [2 8 8 8 8 1 0 0 0 0 0]

batch 500
  minibatch loss: 1.178087592124939
  sample 1:
    input     > [7 5 8 8 6 5 0 0]
    predicted > [8 8 8 8 8 8 1 0 0 0 0]
  sample 2:
    input     > [2 8 4 4 6 7 0 0]
    predicted > [2 2 4 4 4 4 1 0 0 0 0]
  sample 3:
    input     > [6 3 2 6 2 7 3 0]
    predicted > [7 6 6 6 6 7 7 1 0 0 0]

batch 600
  minibatch loss: 0.895336925983429
  sample 1:
    input     > [8 7 7 4 0 0 0 0]
    predicted > [7 7 7 7 1 0 0 0 0 0 0]
  sample 2:
    input     > [5 9 9 7 6 8 2 0]
    predicted > [3 3 7 7 7 7 2 1 0 0 0]
  sample 3:
    input     > [4 7 2 0 0 0 0 0]
    predicted > [2 2 2 1 0 0 0 0 0 0 0]

batch 700
  minibatch loss: 0.29237034916877747
  sample 1:
    input     > [6 8 2 0 0 0 0 0]
    predicted > [6 8 2 1 0 0 0 0 0 0 0]
  sample 2:
    input     > [5 8 7 6 7 4 5 6]
    predicted > [5 8 7 6 7 4 5 6 1 0 0]
  sample 3:
    input     > [9 6 5 0 0 0 0 0]
    predicted > [9 6 5 1 0 0 0 0 0 0 0]

batch 800
  minibatch loss: 0.1250164806842804
  sample 1:
    input     > [6 7 3 7 5 9 6 8]
    predicted > [6 7 3 7 5 9 6 8 1 0 0]
  sample 2:
    input     > [9 7 7 0 0 0 0 0]
    predicted > [9 7 7 1 0 0 0 0 0 0 0]
  sample 3:
    input     > [7 3 4 7 4 8 0 0]
    predicted > [7 3 4 7 4 8 1 0 0 0 0]

batch 900
  minibatch loss: 0.04103495180606842
  sample 1:
    input     > [4 7 5 2 5 2 6 0]
    predicted > [4 7 5 2 5 2 6 1 0 0 0]
  sample 2:
    input     > [5 7 8 0 0 0 0 0]
    predicted > [5 7 8 1 0 0 0 0 0 0 0]
  sample 3:
    input     > [4 3 2 6 0 0 0 0]
    predicted > [4 3 2 6 1 0 0 0 0 0 0]

batch 1000
  minibatch loss: 0.02061438001692295
  sample 1:
    input     > [5 9 9 6 4 6 9 0]
    predicted > [5 9 9 6 4 6 9 1 0 0 0]
  sample 2:
    input     > [3 7 7 2 3 9 0 0]
    predicted > [3 7 7 2 3 9 1 0 0 0 0]
  sample 3:
    input     > [4 2 4 6 8 7 6 3]
    predicted > [4 2 4 6 8 7 6 3 1 0 0]

batch 1100
  minibatch loss: 0.018973074853420258
  sample 1:
    input     > [5 3 9 5 7 2 5 6]
    predicted > [5 3 9 5 7 2 5 6 1 0 0]
  sample 2:
    input     > [2 7 4 8 8 9 0 0]
    predicted > [2 7 4 8 8 9 1 0 0 0 0]
  sample 3:
    input     > [3 8 3 7 5 0 0 0]
    predicted > [3 8 3 7 5 1 0 0 0 0 0]

batch 1200
  minibatch loss: 0.01220142375677824
  sample 1:
    input     > [6 8 7 0 0 0 0 0]
    predicted > [6 8 7 1 0 0 0 0 0 0 0]
  sample 2:
    input     > [5 6 6 0 0 0 0 0]
    predicted > [5 6 6 1 0 0 0 0 0 0 0]
  sample 3:
    input     > [6 5 4 0 0 0 0 0]
    predicted > [6 5 4 1 0 0 0 0 0 0 0]

batch 1300
  minibatch loss: 0.008632375858724117
  sample 1:
    input     > [8 9 3 4 9 2 0 0]
    predicted > [8 9 3 4 9 2 1 0 0 0 0]
  sample 2:
    input     > [5 9 2 4 7 9 3 0]
    predicted > [5 9 2 4 7 9 3 1 0 0 0]
  sample 3:
    input     > [5 6 6 0 0 0 0 0]
    predicted > [5 6 6 1 0 0 0 0 0 0 0]

batch 1400
  minibatch loss: 0.00644361088052392
  sample 1:
    input     > [8 2 8 7 0 0 0 0]
    predicted > [8 2 8 7 1 0 0 0 0 0 0]
  sample 2:
    input     > [9 5 9 0 0 0 0 0]
    predicted > [9 5 9 1 0 0 0 0 0 0 0]
  sample 3:
    input     > [6 5 6 9 4 6 8 4]
    predicted > [6 5 6 9 4 6 8 4 1 0 0]

batch 1500
  minibatch loss: 0.005314035806804895
  sample 1:
    input     > [7 2 5 9 8 4 4 0]
    predicted > [7 2 5 9 8 4 4 1 0 0 0]
  sample 2:
    input     > [4 9 9 4 3 0 0 0]
    predicted > [4 9 9 4 3 1 0 0 0 0 0]
  sample 3:
    input     > [8 6 4 5 8 0 0 0]
    predicted > [8 6 4 5 8 1 0 0 0 0 0]

batch 1600
  minibatch loss: 0.005184624344110489
  sample 1:
    input     > [3 4 2 0 0 0 0 0]
    predicted > [3 4 2 1 0 0 0 0 0 0 0]
  sample 2:
    input     > [6 8 5 3 5 3 0 0]
    predicted > [6 8 5 3 5 3 1 0 0 0 0]
  sample 3:
    input     > [9 2 4 8 8 0 0 0]
    predicted > [9 2 4 8 8 1 0 0 0 0 0]

batch 1700
  minibatch loss: 0.0034769047051668167
  sample 1:
    input     > [8 5 8 7 4 9 9]
    predicted > [8 5 8 7 4 9 9 1 0 0]
  sample 2:
    input     > [3 5 8 0 0 0 0]
    predicted > [3 5 8 1 0 0 0 0 0 0]
  sample 3:
    input     > [4 4 7 5 2 7 0]
    predicted > [4 4 7 5 2 7 1 0 0 0]

batch 1800
  minibatch loss: 0.002784730400890112
  sample 1:
    input     > [2 5 7 4 0 0 0 0]
    predicted > [2 5 7 4 1 0 0 0 0 0 0]
  sample 2:
    input     > [5 4 7 2 5 0 0 0]
    predicted > [5 4 7 2 5 1 0 0 0 0 0]
  sample 3:
    input     > [4 9 5 4 5 4 0 0]
    predicted > [4 9 5 4 5 4 1 0 0 0 0]

batch 1900
  minibatch loss: 0.002491097431629896
  sample 1:
    input     > [4 9 9 5 9 0 0 0]
    predicted > [4 9 9 5 9 1 0 0 0 0 0]
  sample 2:
    input     > [2 2 4 9 7 2 8 9]
    predicted > [2 2 4 9 7 2 8 9 1 0 0]
  sample 3:
    input     > [9 8 8 2 4 0 0 0]
    predicted > [9 8 8 2 4 1 0 0 0 0 0]

batch 2000
  minibatch loss: 0.0022913815919309855
  sample 1:
    input     > [3 4 9 4 5 6 0 0]
    predicted > [3 4 9 4 5 6 1 0 0 0 0]
  sample 2:
    input     > [2 8 4 4 6 0 0 0]
    predicted > [2 8 4 4 6 1 0 0 0 0 0]
  sample 3:
    input     > [2 9 5 8 7 0 0 0]
    predicted > [2 9 5 8 7 1 0 0 0 0 0]

batch 2100
  minibatch loss: 0.001817821292206645
  sample 1:
    input     > [8 3 6 2 0 0 0 0]
    predicted > [8 3 6 2 1 0 0 0 0 0 0]
  sample 2:
    input     > [4 8 6 2 0 0 0 0]
    predicted > [4 8 6 2 1 0 0 0 0 0 0]
  sample 3:
    input     > [8 2 5 6 9 3 4 0]
    predicted > [8 2 5 6 9 3 4 1 0 0 0]

batch 2200
  minibatch loss: 0.0017419641371816397
  sample 1:
    input     > [2 2 3 0 0 0 0 0]
    predicted > [2 2 3 1 0 0 0 0 0 0 0]
  sample 2:
    input     > [7 5 3 8 2 4 0 0]
    predicted > [7 5 3 8 2 4 1 0 0 0 0]
  sample 3:
    input     > [8 2 7 3 4 6 4 8]
    predicted > [8 2 7 3 4 6 4 8 1 0 0]

batch 2300
  minibatch loss: 0.001424401649273932
  sample 1:
    input     > [5 3 5 4 8 3 8 8]
    predicted > [5 3 5 4 8 3 8 8 1 0 0]
  sample 2:
    input     > [3 4 7 9 5 0 0 0]
    predicted > [3 4 7 9 5 1 0 0 0 0 0]
  sample 3:
    input     > [8 8 9 0 0 0 0 0]
    predicted > [8 8 9 1 0 0 0 0 0 0 0]

batch 2400
  minibatch loss: 0.0014329373370856047
  sample 1:
    input     > [6 6 2 8 7 0 0 0]
    predicted > [6 6 2 8 7 1 0 0 0 0 0]
  sample 2:
    input     > [7 4 5 5 0 0 0 0]
    predicted > [7 4 5 5 1 0 0 0 0 0 0]
  sample 3:
    input     > [8 9 5 9 6 6 7 6]
    predicted > [8 9 5 9 6 6 7 6 1 0 0]

batch 2500
  minibatch loss: 0.001270524924620986
  sample 1:
    input     > [7 7 8 9 7 9 9 5]
    predicted > [7 7 8 9 7 9 9 5 1 0 0]
  sample 2:
    input     > [6 8 4 0 0 0 0 0]
    predicted > [6 8 4 1 0 0 0 0 0 0 0]
  sample 3:
    input     > [8 5 7 3 2 0 0 0]
    predicted > [8 5 7 3 2 1 0 0 0 0 0]

batch 2600
  minibatch loss: 0.0008346579852513969
  sample 1:
    input     > [6 6 8 5 0 0 0]
    predicted > [6 6 8 5 1 0 0 0 0 0]
  sample 2:
    input     > [5 7 3 4 9 2 9]
    predicted > [5 7 3 4 9 2 9 1 0 0]
  sample 3:
    input     > [7 2 4 3 0 0 0]
    predicted > [7 2 4 3 1 0 0 0 0 0]

batch 2700
  minibatch loss: 0.0009551114053465426
  sample 1:
    input     > [4 8 4 4 3 8 5 8]
    predicted > [4 8 4 4 3 8 5 8 1 0 0]
  sample 2:
    input     > [2 3 8 8 0 0 0 0]
    predicted > [2 3 8 8 1 0 0 0 0 0 0]
  sample 3:
    input     > [2 9 5 8 7 6 3 0]
    predicted > [2 9 5 8 7 6 3 1 0 0 0]

batch 2800
  minibatch loss: 0.000799523142632097
  sample 1:
    input     > [3 8 8 0 0 0 0 0]
    predicted > [3 8 8 1 0 0 0 0 0 0 0]
  sample 2:
    input     > [2 4 8 5 0 0 0 0]
    predicted > [2 4 8 5 1 0 0 0 0 0 0]
  sample 3:
    input     > [6 6 7 0 0 0 0 0]
    predicted > [6 6 7 1 0 0 0 0 0 0 0]

batch 2900
  minibatch loss: 0.000805568415671587
  sample 1:
    input     > [9 7 6 4 8 6 4 0]
    predicted > [9 7 6 4 8 6 4 1 0 0 0]
  sample 2:
    input     > [8 8 3 3 8 2 6 9]
    predicted > [8 8 3 3 8 2 6 9 1 0 0]
  sample 3:
    input     > [8 7 3 8 5 9 4 0]
    predicted > [8 7 3 8 5 9 4 1 0 0 0]

batch 3000
  minibatch loss: 0.000771542196162045
  sample 1:
    input     > [7 3 3 0 0 0 0 0]
    predicted > [7 3 3 1 0 0 0 0 0 0 0]
  sample 2:
    input     > [8 4 4 4 7 7 0 0]
    predicted > [8 4 4 4 7 7 1 0 0 0 0]
  sample 3:
    input     > [2 8 7 5 9 3 6 0]
    predicted > [2 8 7 5 9 3 6 1 0 0 0]
%matplotlib inline
import matplotlib.pyplot as plt
# Loss curve over all training steps, plus a one-line summary.
plt.plot(loss_track)
final_loss = loss_track[-1]
examples_seen = len(loss_track) * batch_size
print('loss {:.4f} after {} examples (batch_size={})'.format(
    final_loss, examples_seen, batch_size))
loss 0.0008 after 30010 examples (batch_size=10)

[图片上传失败...(image-353872-1544603026851)]


你可能感兴趣的:((八)sequence to sequence —6)