常见 tf.contrib.seq2seq API

BasicDecoder类和dynamic_decode

decoder文件中定义了Decoder抽象类和dynamic_decode函数,dynamic_decode可以视为整个解码过程的入口,需要传入的参数就是Decoder的一个实例,他会动态的调用Decoder的step函数按步执行decode,可以理解为Decoder类定义了单步解码(根据输入求出输出,并将该输出当做下一时刻输入)

basic_decoder文件定义了一个基本的Decoder类实例BasicDecoder,其初始化函数:

def __init__(self, cell, helper, initial_state, output_layer=None):

需要传入的参数就是cell类型、helper类型、初始化状态(encoder的最后一个隐层状态)、输出层(输出映射层,将rnn_size转化为vocab_size维)

AttentionWrapper

AttentionWrapper在原本RNNCell的基础上在封装一层attention

# 分为三步,第一步是定义attention机制,第二步是定义要是用的基础的RNNCell,第三步是使用AttentionWrapper进行封装

    #定义要使用的attention机制。

    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units=self.rnn_size, memory=encoder_outputs, memory_sequence_length=encoder_inputs_length)

    #attention_mechanism = tf.contrib.seq2seq.LuongAttention(num_units=self.rnn_size, memory=encoder_outputs, memory_sequence_length=encoder_inputs_length)

    # 定义decoder阶段要是用的LSTMCell,然后为其封装attention wrapper

    decoder_cell = self._create_rnn_cell()

    decoder_cell = tf.contrib.seq2seq.AttentionWrapper(cell=decoder_cell, attention_mechanism=attention_mechanism, attention_layer_size=self.rnn_size, name='Attention_Wrapper')

Helper类

helper其实就是decode阶段如何根据预测结果得到下一时刻的输入,比如训练过程中应该直接使用上一时刻的真实值作为下一时刻输入,预测过程中可以使用贪婪搜索选择概率最大的那个值作为下一时刻等等。所以Helper也就可以大致分为训练时helper和预测时helper两种

“TrainingHelper”:训练过程中最常使用的Helper,下一时刻输入就是上一时刻target的真实值

“GreedyEmbeddingHelper”:预测阶段最常使用的Helper,下一时刻输入是上一时刻概率最大的单词通过embedding之后的向量

#分为四步,第一步是定义cell类型,第二步是定义helper类型,第三步是定义BasicDecoder类实例,第四步是调用dynamic_decode函数进行解码

    decoder_cell = ***(上面的代码)

    training_helper = tf.contrib.seq2seq.TrainingHelper(inputs=decoder_inputs_embedded,

                                                        sequence_length=self.decoder_targets_length,

                                                        time_major=False, name='training_helper')

    training_decoder = tf.contrib.seq2seq.BasicDecoder(cell=decoder_cell, helper=training_helper,

                                                      initial_state=decoder_initial_state, output_layer=output_layer)
    #调用dynamic_decode进行解码,decoder_outputs是一个namedtuple,里面包含两项(rnn_outputs, sample_id)
    # rnn_output: [batch_size, decoder_targets_length, vocab_size],保存decode每个时刻每个单词的概率,可以用来计算loss
    # sample_id: [batch_size], tf.int32,保存最终的编码结果。可以表示最后的答案
    decoder_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder=training_decoder, impute_finished=True,

                                                        maximum_iterations=self.max_target_sequence_length)

Beam search decoder类

BeamSearchDecoder类,其实是一个Decoder的实例,跟BasicDecoder在一个等级上,但是二者又存在着不同,因为BasicDecoder需要指定helper参数,也就是指定decode阶段如何根据上一时刻输出获得下一时刻输入。但是BeamSearchDecoder不需要,因为其在内部实现了beam_search的功能,也就包含了helper的效果。

所以解码器有两种方式,直接调用BeamSearchDecoder,或者使用调用GreedyEmbeddingHelper+BasicDecoder的组合进行贪婪式解码

#分为三步,第一步是定义cell,第二步是定义BeamSearchDecoder,第三步是调用dynamic_decode函数进行解码

    docoder_cell = ***(上面代码)
    if self.beam_search:  
        inference_decoder = tf.contrib.seq2seq.BeamSearchDecoder(cell=decoder_cell, embedding=embedding,

                                                            start_tokens=start_tokens, end_token=end_token,

                                                            initial_state=decoder_initial_state,

                                                            beam_width=self.beam_size,

                                                            output_layer=output_layer)
    else:
         decoding_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding=embedding,
                                                                                   start_tokens=start_tokens, end_token=end_token)
         inference_decoder = tf.contrib.seq2seq.BasicDecoder(cell=decoder_cell, helper=decoding_helper,
                                                                            initial_state=decoder_initial_state,
                                                                            output_layer=output_layer)

    decoder_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder=inference_decoder,

                                                    maximum_iterations=self.max_target_sequence_length)

你可能感兴趣的:(常见 tf.contrib.seq2seq API)