(8) Sequence to Sequence, Part 2

Implementing the softmax_loss_function part

A Seq2seq implementation based on TensorFlow 1.4

import helpers
import tensorflow as tf
from tensorflow.contrib import seq2seq, rnn

tf.__version__
'1.4.0'
tf.reset_default_graph()
sess = tf.InteractiveSession()
PAD = 0
EOS = 1


vocab_size = 10
input_embedding_size = 20
encoder_hidden_units = 25

decoder_hidden_units = encoder_hidden_units

import helpers as data_helpers
batch_size = 50

# A generator that yields one minibatch of random sequences per call

batches = data_helpers.random_sequences(length_from=3, length_to=8,
                                        vocab_lower=2, vocab_upper=10,
                                        batch_size=batch_size)

print('Generated %d sequences of varying length (shortest 3, longest 8); the first ten are:' % batch_size)
for seq in next(batches)[:min(batch_size, 10)]:
    print(seq)
Generated 50 sequences of varying length (shortest 3, longest 8); the first ten are:
[5, 3, 9, 9, 5]
[3, 6, 7, 6]
[4, 3, 5, 7, 6, 3]
[6, 3, 4, 3, 6]
[2, 3, 9, 3, 3]
[4, 8, 2, 4, 9, 7, 8, 7]
[6, 9, 2, 7, 3, 3]
[9, 2, 7]
[2, 2, 5, 2, 5, 2]
[5, 4, 2, 7, 8, 5]
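The local helpers module is not shown in this post. If you do not have it, a minimal stand-in for the two functions used here might look like the sketch below; the function names match the calls above, but the exact padding and sampling behavior are assumptions, not the original implementation:

import random
import numpy as np

def random_sequences(length_from, length_to, vocab_lower, vocab_upper, batch_size):
    # Endlessly yield batches of random integer sequences of varying length.
    while True:
        yield [
            [random.randint(vocab_lower, vocab_upper - 1)
             for _ in range(random.randint(length_from, length_to))]
            for _ in range(batch_size)
        ]

def batch(inputs, max_sequence_length=None):
    # Pad with 0 (PAD) and return a time-major [max_len, batch_size] array
    # plus the original lengths; callers below transpose with .T to feed
    # the batch-major placeholders.
    sequence_lengths = [len(seq) for seq in inputs]
    max_len = max_sequence_length or max(sequence_lengths)
    padded = np.zeros([max_len, len(inputs)], dtype=np.int32)
    for i, seq in enumerate(inputs):
        for t, element in enumerate(seq):
            padded[t, i] = element
    return padded, sequence_lengths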

1. Implementing the seq2seq model with the seq2seq library (the encoder is unchanged)

tf.reset_default_graph()
sess = tf.InteractiveSession()
mode = tf.contrib.learn.ModeKeys.TRAIN

with tf.name_scope('minibatch'):
    encoder_inputs = tf.placeholder(tf.int32, [None, None], name='encoder_inputs')
    
    encoder_inputs_length = tf.placeholder(tf.int32, [None], name='encoder_inputs_length')
    
    decoder_targets = tf.placeholder(tf.int32, [None, None], name='decoder_targets')
    
    decoder_inputs = tf.placeholder(shape=(None, None),dtype=tf.int32,name='decoder_inputs')
    
    # decoder_inputs_length and decoder_targets_length are the same
    decoder_inputs_length = tf.placeholder(shape=(None,),
                                            dtype=tf.int32,
                                            name='decoder_inputs_length')
    

    
def _create_rnn_cell():
    def single_rnn_cell(encoder_hidden_units):
        # Build each cell through a factory function like single_rnn_cell;
        # putting one shared cell object directly into the MultiRNNCell list
        # eventually breaks the model
        single_cell = rnn.LSTMCell(encoder_hidden_units)
        # Add dropout
        single_cell = rnn.DropoutWrapper(single_cell, output_keep_prob=0.5)
        return single_cell
    # Every element of the list is a fresh cell returned by single_rnn_cell;
    # for a deeper network: rnn.MultiRNNCell([single_rnn_cell(...) for _ in range(num_layers)])
    cell = rnn.MultiRNNCell([single_rnn_cell(encoder_hidden_units) for _ in range(1)])
    return cell
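To see why the factory function matters, compare the two patterns below (a minimal sketch; num_layers is a hypothetical depth). Repeating one cell object makes every layer try to share a single set of weights, which typically fails with shape or variable-reuse errors in TF 1.x:

# Buggy: the same LSTMCell object appears num_layers times, so layer 2
# would reuse layer 1's kernel, whose input dimension no longer matches.
num_layers = 2  # hypothetical depth
bad_cell = rnn.LSTMCell(encoder_hidden_units)
# bad_stack = rnn.MultiRNNCell([bad_cell] * num_layers)  # typically breaks

# Correct: the list comprehension creates a fresh cell (fresh variables) per layer.
good_stack = rnn.MultiRNNCell(
    [rnn.LSTMCell(encoder_hidden_units) for _ in range(num_layers)])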

2. Candidate Sampling implementation

In practice vocab_size can be very large, and the conventional softmax becomes a problem when computing the loss over one-hot targets, because it has to compute a logit for every single class.

For details, see this paper: https://arxiv.org/pdf/1409.0473v7.pdf

num_samples = 5

# Projection from the decoder's hidden size back to the vocabulary;
# sampled_softmax_loss applies it internally.
w_sample = tf.get_variable('proj_w', [vocab_size, encoder_hidden_units])
b_sample = tf.get_variable('proj_b', [vocab_size])

# Call sampled_softmax_loss to compute the sampled loss; it only evaluates
# num_samples classes per step instead of all vocab_size classes, which
# saves computation.
def sample_loss(logits, labels):
    labels = tf.cast(labels, tf.int64)
    labels = tf.reshape(labels, [-1, 1])
    logits = tf.cast(logits, tf.float32)
    # Note: since the decoder below uses output_layer=None, the "logits"
    # passed in are really the raw cell outputs of size encoder_hidden_units,
    # which is exactly the input dimension w_sample expects.
    return tf.cast(tf.nn.sampled_softmax_loss(w_sample, b_sample, labels=labels, inputs=logits,
                                              num_sampled=num_samples, num_classes=vocab_size),
                   tf.float32)

softmax_loss_function = sample_loss

The TensorFlow seq2seq.sequence_loss interface:

[figure: seq2seq_loss接口.png]

The TensorFlow tf.nn.sampled_softmax_loss interface:

[figure: sampled_softmax_loss接口.png]
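In case the figures do not render, the two signatures as of TF 1.4 are roughly as follows (defaults abbreviated; consult the API docs for the authoritative versions):

# tf.contrib.seq2seq.sequence_loss (TF 1.4):
#   sequence_loss(logits, targets, weights,
#                 average_across_timesteps=True, average_across_batch=True,
#                 softmax_loss_function=None, name=None)
# logits:  [batch_size, sequence_length, num_decoder_symbols]
# targets: [batch_size, sequence_length], integer labels
# weights: [batch_size, sequence_length], e.g. a 0/1 padding mask
# softmax_loss_function is called as f(labels=..., logits=...) on flattened
#   tensors, which is why sample_loss above reshapes labels to [-1, 1].

# tf.nn.sampled_softmax_loss (TF 1.4):
#   sampled_softmax_loss(weights, biases, labels, inputs,
#                        num_sampled, num_classes, num_true=1, ...)
# weights: [num_classes, dim], biases: [num_classes]
# labels:  [batch_size, num_true], inputs: [batch_size, dim]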

3. Defining the encoder

with tf.variable_scope('encoder'):
    # Create the LSTM cell
    encoder_cell = _create_rnn_cell()
    # Build the embedding matrix, shared by the encoder and the decoder
    embedding = tf.get_variable('embedding', [vocab_size, input_embedding_size])
    encoder_inputs_embedded = tf.nn.embedding_lookup(embedding, encoder_inputs)
    # Run dynamic_rnn to encode the inputs into hidden vectors.
    # encoder_outputs can be used for attention: [batch_size, encoder_inputs_length, rnn_size]
    # encoder_state initializes the decoder: [batch_size, rnn_size]
    encoder_outputs, encoder_state = tf.nn.dynamic_rnn(encoder_cell, encoder_inputs_embedded,
                                                       sequence_length=encoder_inputs_length,
                                                       dtype=tf.float32)

The encoder above uses a plain unidirectional dynamic_rnn; a bidirectional encoder, as pictured below, is a common alternative:

[figure: RNN-bidirectional.png]

Image from Colah's blog
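A sketch of what the bidirectional variant might look like (hypothetical, not part of the model trained below; the forward and backward states would still have to be merged before they could initialize the decoder, and the merged size would no longer match decoder_hidden_units as configured here):

# Hypothetical bidirectional encoder; this post's model stays unidirectional.
with tf.variable_scope('bi_encoder'):
    cell_fw = _create_rnn_cell()
    cell_bw = _create_rnn_cell()
    bi_outputs, bi_states = tf.nn.bidirectional_dynamic_rnn(
        cell_fw, cell_bw, encoder_inputs_embedded,
        sequence_length=encoder_inputs_length, dtype=tf.float32)
    # Concatenate forward/backward outputs for attention:
    # [batch_size, max_time, 2 * encoder_hidden_units]
    bi_encoder_outputs = tf.concat(bi_outputs, -1)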

4. Defining the decoder (without attention for now)

No output_layer is added here; the projection to the vocabulary lives inside sample_loss (w_sample, b_sample).
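For contrast, a conventional full-softmax setup would attach a Dense projection to the decoder and drop softmax_loss_function entirely; a minimal sketch (not used in the model below):

# Hypothetical full-softmax alternative; the decoder below keeps output_layer=None.
output_layer = tf.layers.Dense(
    vocab_size,
    kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1))
# Pass output_layer=output_layer to seq2seq.BasicDecoder, then:
# loss = seq2seq.sequence_loss(logits=decoder_outputs.rnn_output,
#                              targets=decoder_targets, weights=mask)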

with tf.variable_scope('decoder'):
    decoder_cell = _create_rnn_cell()

    # Initialize the decoder with the encoder's final state
    decoder_initial_state = encoder_state

    decoder_inputs_embedded = tf.nn.embedding_lookup(embedding, decoder_inputs)

    # Training uses the standard TrainingHelper + BasicDecoder combination;
    # you can also subclass Helper to implement custom behavior
    training_helper = seq2seq.TrainingHelper(inputs=decoder_inputs_embedded,
                                             sequence_length=decoder_inputs_length,
                                             time_major=False, name='training_helper')
    training_decoder = seq2seq.BasicDecoder(cell=decoder_cell, helper=training_helper,
                                            initial_state=decoder_initial_state,
                                            output_layer=None)

    # dynamic_decode returns a namedtuple with two fields (rnn_output, sample_id).
    # rnn_output: [batch_size, decoder_targets_length, decoder_hidden_units] here,
    #   because output_layer=None means nothing projects to vocab_size;
    #   sample_loss performs that projection when computing the loss.
    # sample_id: [batch_size, decoder_targets_length], tf.int32, the argmax over
    #   rnn_output's last axis. Since that axis has decoder_hidden_units (25)
    #   entries rather than vocab_size (10), these ids are not vocabulary ids.
    max_target_sequence_length = tf.reduce_max(decoder_inputs_length, name='max_target_len')
    decoder_outputs, _, _ = seq2seq.dynamic_decode(decoder=training_decoder,
                                                   impute_finished=True,
                                                   maximum_iterations=max_target_sequence_length)

    # Copy decoder_outputs.rnn_output into decoder_logits_train
    decoder_logits_train = tf.identity(decoder_outputs.rnn_output)
    sample_id = decoder_outputs.sample_id

    # Build a 0/1 mask over the padded target positions from the target lengths.
    # An example of how sequence_mask works:
    #  tf.sequence_mask([1, 3, 2], 5)
    #  [[True, False, False, False, False],
    #   [True, True, True, False, False],
    #   [True, True, False, False, False]]
    mask = tf.sequence_mask(decoder_inputs_length, max_target_sequence_length,
                            dtype=tf.float32, name='masks')
    print('\t%s' % repr(decoder_logits_train))
    print('\t%s' % repr(decoder_targets))
    print('\t%s' % repr(sample_id))
    loss = seq2seq.sequence_loss(logits=decoder_logits_train, targets=decoder_targets,
                                 weights=mask, softmax_loss_function=softmax_loss_function)

train_op = tf.train.AdamOptimizer(learning_rate = 0.001).minimize(loss)
sess.run(tf.global_variables_initializer())
def next_feed():
    batch = next(batches)
    
    encoder_inputs_, encoder_inputs_length_ = data_helpers.batch(batch)
    decoder_targets_, decoder_targets_length_ = data_helpers.batch(
        [(sequence) + [EOS] for sequence in batch]
    )
    decoder_inputs_, decoder_inputs_length_ = data_helpers.batch(
        [[EOS] + (sequence) for sequence in batch]
    )
    
    # In the feed_dict, keys can be Tensors
    return {
        encoder_inputs: encoder_inputs_.T,
        decoder_inputs: decoder_inputs_.T,
        decoder_targets: decoder_targets_.T,
        encoder_inputs_length: encoder_inputs_length_,
        decoder_inputs_length: decoder_inputs_length_
    }
x = next_feed()
print('encoder_inputs:')
print(x[encoder_inputs][0,:])
print('encoder_inputs_length:')
print(x[encoder_inputs_length][0])
print('decoder_inputs:')
print(x[decoder_inputs][0,:])
print('decoder_inputs_length:')
print(x[decoder_inputs_length][0])
print('decoder_targets:')
print(x[decoder_targets][0,:])
encoder_inputs:
[6 9 7 7 3 7 0 0]
encoder_inputs_length:
6
decoder_inputs:
[1 6 9 7 7 3 7 0 0]
decoder_inputs_length:
7
decoder_targets:
[6 9 7 7 3 7 1 0 0]
loss_track = []
max_batches = 3001
batches_in_epoch = 100

try:
    # training loop over max_batches random minibatches
    for batch in range(max_batches):
        fd = next_feed()
        _, l = sess.run([train_op, loss], fd)
        loss_track.append(l)
        
        if batch == 0 or batch % batches_in_epoch == 0:
            print('batch {}'.format(batch))
            print('  minibatch loss: {}'.format(sess.run(loss, fd)))
            predict_ = sess.run(decoder_outputs.sample_id, fd)
            for i, (inp, pred) in enumerate(zip(fd[encoder_inputs], predict_)):
                print('  sample {}:'.format(i + 1))
                print('    input     > {}'.format(inp))
                print('    predicted > {}'.format(pred))
                if i >= 2:
                    break
            print()
        
except KeyboardInterrupt:
    print('training interrupted')
batch 0
  minibatch loss: 1.7552857398986816
  sample 1:
    input     > [4 3 2 0 0 0 0 0]
    predicted > [ 8 21  3 14  0  0  0  0  0]
  sample 2:
    input     > [2 8 8 2 4 0 0 0]
    predicted > [21 21 21  3  4  3  0  0  0]
  sample 3:
    input     > [8 5 4 4 4 0 0 0]
    predicted > [24 21 24 15 21 18  0  0  0]

batch 100
  minibatch loss: 1.4532331228256226
  sample 1:
    input     > [4 3 2 5 9 4 8 0]
    predicted > [ 0 12 13 13 12 13 13 11  0]
  sample 2:
    input     > [5 8 7 5 4 0 0 0]
    predicted > [24 24 21 18 13 13  0  0  0]
  sample 3:
    input     > [2 4 2 5 9 4 8 8]
    predicted > [24 12  0 13 13 12 13 11 13]

batch 200
  minibatch loss: 1.1243680715560913
  sample 1:
    input     > [9 3 9 2 5 4 0 0]
    predicted > [ 3 12  3 13 21 13 21  0  0]
  sample 2:
    input     > [9 8 5 2 5 4 0 0]
    predicted > [15 13 21  7 13 16 11  0  0]
  sample 3:
    input     > [3 8 5 3 8 2 6 5]
    predicted > [12 12 15  3  7  3  7 18 18]

batch 300
  minibatch loss: 1.1811888217926025
  sample 1:
    input     > [2 2 2 7 7 0 0 0]
    predicted > [ 0  9 12 12  6 18  0  0  0]
  sample 2:
    input     > [9 8 4 4 8 5 2 6]
    predicted > [ 0  0  0  0  0 17 11 13  2]
  sample 3:
    input     > [2 7 4 4 9 0 0 0]
    predicted > [ 9 13 12 12 17  6  0  0  0]

batch 400
  minibatch loss: 1.0878313779830933
  sample 1:
    input     > [9 9 3 8 3 0 0 0]
    predicted > [ 9 12  6 18 18 21  0  0  0]
  sample 2:
    input     > [5 6 2 6 8 0 0 0]
    predicted > [15  0 13 12  7 18  0  0  0]
  sample 3:
    input     > [2 9 4 7 6 0 0 0]
    predicted > [ 8 13 13  0 12 18  0  0  0]

batch 500
  minibatch loss: 0.8977712988853455
  sample 1:
    input     > [6 7 2 5 7 9 0 0]
    predicted > [ 5 12 12 15 15 21 18  0  0]
  sample 2:
    input     > [7 4 3 3 4 8 5 0]
    predicted > [24  9  9  9  0 12 13 13  0]
  sample 3:
    input     > [2 9 8 0 0 0 0 0]
    predicted > [15 21 18 21  0  0  0  0  0]

batch 600
  minibatch loss: 0.9903306365013123
  sample 1:
    input     > [9 5 3 8 4 5 2 0]
    predicted > [15 15 24 12 15  3  3  7  0]
  sample 2:
    input     > [5 2 3 2 5 3 3 8]
    predicted > [ 9  3 10 24  9 24  6 11 11]
  sample 3:
    input     > [6 3 4 6 0 0 0 0]
    predicted > [ 8 12 12 13 11  0  0  0  0]

batch 700
  minibatch loss: 1.0557962656021118
  sample 1:
    input     > [4 5 9 0 0 0 0 0]
    predicted > [15  7 10 18  0  0  0  0  0]
  sample 2:
    input     > [2 4 7 7 5 8 5 0]
    predicted > [ 9  9 11 12 13 13 13 11  0]
  sample 3:
    input     > [2 4 9 0 0 0 0 0]
    predicted > [21 13 18 21  0  0  0  0  0]

batch 800
  minibatch loss: 0.7463603019714355
  sample 1:
    input     > [8 3 8 7 0 0 0 0]
    predicted > [24 24 12 17 18  0  0  0  0]
  sample 2:
    input     > [7 2 5 6 9 0 0 0]
    predicted > [ 9 12  7  7 18  7  0  0  0]
  sample 3:
    input     > [2 8 5 9 2 2 7 0]
    predicted > [18 21  3  3 18  0  7 16  0]

batch 900
  minibatch loss: 0.57407546043396
  sample 1:
    input     > [2 2 3 0 0 0 0 0]
    predicted > [ 3 12 18 18  0  0  0  0  0]
  sample 2:
    input     > [6 9 2 9 0 0 0 0]
    predicted > [ 3 21 21 10 18  0  0  0  0]
  sample 3:
    input     > [2 5 4 3 0 0 0 0]
    predicted > [ 9  0 12 24 11  0  0  0  0]

batch 1000
  minibatch loss: 0.5782870650291443
  sample 1:
    input     > [3 4 8 5 5 4 8 8]
    predicted > [ 5  5  9 20 13 17 16 11 11]
  sample 2:
    input     > [4 8 3 5 3 7 0 0]
    predicted > [ 5  9 10  1  7  9 11  0  0]
  sample 3:
    input     > [4 6 5 4 4 2 9 6]
    predicted > [ 5  5  5 12 12 12 12  7  7]

batch 1100
  minibatch loss: 0.5811575055122375
  sample 1:
    input     > [9 4 8 7 6 4 7 3]
    predicted > [ 5 17  5 17 12 12 12 17 13]
  sample 2:
    input     > [3 3 9 5 2 5 0 0]
    predicted > [10 12 15 15 12  7 13  0  0]
  sample 3:
    input     > [4 5 6 2 3 4 5 0]
    predicted > [ 5  5  0 24  9 11 13 13  0]

batch 1200
  minibatch loss: 0.6396902203559875
  sample 1:
    input     > [2 3 5 4 4 9 0 0]
    predicted > [ 3  9  5  0 12 21 16  0  0]
  sample 2:
    input     > [2 8 9 9 5 6 7 5]
    predicted > [18  0 15 12  0 11  0 21 11]
  sample 3:
    input     > [4 6 9 7 4 0 0 0]
    predicted > [ 0  0 11 17 11 13  0  0  0]

batch 1300
  minibatch loss: 0.6583465337753296
  sample 1:
    input     > [3 6 8 9 8 7 6 0]
    predicted > [12  8  9  9  9 12  7 18  0]
  sample 2:
    input     > [9 8 7 8 0 0 0 0]
    predicted > [ 5 12 12 13 21  0  0  0  0]
  sample 3:
    input     > [6 9 9 4 9 5 2 7]
    predicted > [21 15 15 12 15  7 18 18  6]

batch 1400
  minibatch loss: 0.6087316870689392
  sample 1:
    input     > [9 6 8 2 3 9 9 0]
    predicted > [15 12 12  3  3 15  7 11  0]
  sample 2:
    input     > [2 8 7 7 8 0 0 0]
    predicted > [ 5  9  9 17 18 16  0  0  0]
  sample 3:
    input     > [4 2 7 5 4 3 9 0]
    predicted > [ 5  9 10 12  9 10 21  6  0]

batch 1500
  minibatch loss: 0.717374324798584
  sample 1:
    input     > [3 5 3 6 3 3 6 0]
    predicted > [ 8  6 17  8 12  0 22 22  0]
  sample 2:
    input     > [8 5 6 5 2 6 0 0]
    predicted > [ 8  8  8 15 12 15  7  0  0]
  sample 3:
    input     > [7 6 4 6 8 8 7 0]
    predicted > [ 0 13  0  0 13 12 22 13  0]

batch 1600
  minibatch loss: 0.4989091753959656
  sample 1:
    input     > [6 5 6 4 5 8 0 0]
    predicted > [ 1 15  5  0 11 13  7  0  0]
  sample 2:
    input     > [4 9 4 6 6 0 0 0]
    predicted > [ 5 12  0 15  0  2  0  0  0]
  sample 3:
    input     > [2 8 3 8 0 0 0 0]
    predicted > [21 12 13 13 21  0  0  0  0]

batch 1700
  minibatch loss: 0.5497146248817444
  sample 1:
    input     > [9 6 7 6 0 0 0 0]
    predicted > [ 8  0  0  0 11  0  0  0  0]
  sample 2:
    input     > [7 5 9 7 0 0 0 0]
    predicted > [15  1  0 11 11  0  0  0  0]
  sample 3:
    input     > [6 5 2 2 0 0 0 0]
    predicted > [21  3  3 13 13  0  0  0  0]

batch 1800
  minibatch loss: 0.606837272644043
  sample 1:
    input     > [8 6 2 6 0 0 0 0]
    predicted > [13 12 12 22 13  0  0  0  0]
  sample 2:
    input     > [3 6 4 5 5 9 0 0]
    predicted > [15  1  1 20  1 10 21  0  0]
  sample 3:
    input     > [8 7 2 4 0 0 0 0]
    predicted > [ 9 12  9 11  7  0  0  0  0]

batch 1900
  minibatch loss: 0.5147760510444641
  sample 1:
    input     > [7 4 9 2 0 0 0 0]
    predicted > [12 12 21 18 18  0  0  0  0]
  sample 2:
    input     > [3 3 6 3 6 2 5 0]
    predicted > [ 6  8  8  0 13  7  7 22  0]
  sample 3:
    input     > [8 4 9 0 0 0 0 0]
    predicted > [ 5 12  7 11  0  0  0  0  0]

batch 2000
  minibatch loss: 0.33478930592536926
  sample 1:
    input     > [4 8 9 2 9 9 0 0]
    predicted > [ 5 12 21 21 10  7 11  0  0]
  sample 2:
    input     > [8 5 2 8 6 2 3 0]
    predicted > [20  5 13 13 13 12  7 14  0]
  sample 3:
    input     > [4 9 5 4 5 5 4 2]
    predicted > [ 5  5  5 20 12 12 12 18  6]

batch 2100
  minibatch loss: 0.47149085998535156
  sample 1:
    input     > [2 3 3 8 9 3 4 0]
    predicted > [17  9 17  9  9 10 12  6  0]
  sample 2:
    input     > [7 2 8 6 8 9 2 0]
    predicted > [18 12 12 12 12  7 21 18  0]
  sample 3:
    input     > [9 3 6 6 9 3 0 0]
    predicted > [15  0 15 15 12 10 10  0  0]

batch 2200
  minibatch loss: 0.35661277174949646
  sample 1:
    input     > [3 6 2 2 7 4 5 3]
    predicted > [18  0 12 12  0 12 15 22 22]
  sample 2:
    input     > [4 8 7 5 6 0 0 0]
    predicted > [12 12 22 15 22  2  0  0  0]
  sample 3:
    input     > [9 3 3 2 4 0 0 0]
    predicted > [10  9 12 12 13 14  0  0  0]

batch 2300
  minibatch loss: 0.490815132856369
  sample 1:
    input     > [4 2 9 5 2 0 0 0]
    predicted > [ 0  3 12 15  7  7  0  0  0]
  sample 2:
    input     > [2 6 6 8 0 0 0 0]
    predicted > [15  8 18 12 13  0  0  0  0]
  sample 3:
    input     > [7 2 9 3 3 7 0 0]
    predicted > [12  3 24 10  6 22 18  0  0]

batch 2400
  minibatch loss: 0.45908093452453613
  sample 1:
    input     > [3 7 7 7 0 0 0 0]
    predicted > [24 18  0  0 18  0  0  0  0]
  sample 2:
    input     > [9 9 3 0 0 0 0 0]
    predicted > [ 3 18  6 18  0  0  0  0  0]
  sample 3:
    input     > [6 8 3 4 9 8 0 0]
    predicted > [13  5 10  5 12 12 20  0  0]

batch 2500
  minibatch loss: 0.3008703291416168
  sample 1:
    input     > [7 6 2 0 0 0 0 0]
    predicted > [12  8 12  7  0  0  0  0  0]
  sample 2:
    input     > [9 5 9 9 8 7 0 0]
    predicted > [15  3 21 12 12 24  7  0  0]
  sample 3:
    input     > [4 8 3 3 0 0 0 0]
    predicted > [17 12  6 17 11  0  0  0  0]

batch 2600
  minibatch loss: 0.5544220805168152
  sample 1:
    input     > [6 7 4 3 2 5 2 0]
    predicted > [13 13 13  3  7  3 12 11  0]
  sample 2:
    input     > [2 4 5 5 7 9 7 0]
    predicted > [ 9  1  5  0 10 18 10 18  0]
  sample 3:
    input     > [2 3 2 8 3 6 0 0]
    predicted > [ 3 12 21 13 12  7 18  0  0]

batch 2700
  minibatch loss: 0.3938275873661041
  sample 1:
    input     > [5 5 8 8 0 0 0 0]
    predicted > [ 1 20 13 16 13  0  0  0  0]
  sample 2:
    input     > [8 9 8 2 2 5 7 0]
    predicted > [21 12 12 18 24 12 18 18  0]
  sample 3:
    input     > [2 8 2 5 0 0 0 0]
    predicted > [12 12  3  7  7  0  0  0  0]

batch 2800
  minibatch loss: 0.42015042901039124
  sample 1:
    input     > [9 8 2 7 8 0 0 0]
    predicted > [21 12 12 17  7 11  0  0  0]
  sample 2:
    input     > [6 9 4 0 0 0 0 0]
    predicted > [15 12 11  4  0  0  0  0  0]
  sample 3:
    input     > [9 6 5 7 5 6 0 0]
    predicted > [15 15  1  0 20  0 11  0  0]

batch 2900
  minibatch loss: 0.30204904079437256
  sample 1:
    input     > [7 9 5 8 6 5 0 0]
    predicted > [10  1  1 13 12  7  7  0  0]
  sample 2:
    input     > [3 5 4 0 0 0 0 0]
    predicted > [ 3 20 16 14  0  0  0  0  0]
  sample 3:
    input     > [2 6 8 2 0 0 0 0]
    predicted > [21 12 12 12  7  0  0  0  0]

batch 3000
  minibatch loss: 0.2870854437351227
  sample 1:
    input     > [5 8 7 4 8 6 5 4]
    predicted > [24 13 12 12 12 15 12 12 21]
  sample 2:
    input     > [2 4 5 5 6 6 2 0]
    predicted > [20 20 15  0 15 18 21 22  0]
  sample 3:
    input     > [6 5 4 6 7 2 0 0]
    predicted > [15 15  0 13 12 18 18  0  0]
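Note that the "predicted" ids above frequently exceed vocab_size - 1 = 9. That is expected with this setup: because output_layer=None, sample_id is an argmax over the 25-dimensional cell output, not over vocabulary logits. To get vocabulary-space predictions at inference time, one would project the cell outputs back to the vocabulary; a minimal sketch reusing the sampled-softmax parameters:

# Hypothetical inference-time projection back to the vocabulary, reusing
# the sampled-softmax parameters (w_sample: [vocab_size, hidden], b_sample).
full_logits = tf.tensordot(decoder_logits_train, w_sample, axes=[[2], [1]]) + b_sample
# full_logits: [batch_size, decoder_targets_length, vocab_size]
vocab_predictions = tf.argmax(full_logits, axis=-1)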
