Implementation of Helpers in the TensorFlow Seq2Seq Decoder Stage

BasicDecoder and dynamic_decode

For simplicity, let's start the analysis from the decoding entry point, the dynamic_decode function:

 dynamic_decode(
   decoder,
   output_time_major=False,
   impute_finished=False,
   maximum_iterations=None,
   parallel_iterations=32,
   swap_memory=False,
   scope=None
   )
   decoder: a BasicDecoder, BeamSearchDecoder, or self-defined decoder instance
   output_time_major: Python boolean. When False, outputs are returned batch-major, i.e. batch_size * time_step * ...; this layout requires extra transposes during the computation and therefore extra time. When True, outputs are returned as time_step * batch_size * ..., which computes faster.
   impute_finished: Python boolean. When True, the states of batch entries marked finished are copied through and their outputs are zeroed out. This makes each time step slower, but ensures the final states and outputs have correct values and that finished time steps are ignored, so the program runs more stably.
   maximum_iterations: the maximum number of decode steps. For training it is usually set to decoder_inputs_length; at inference just set whatever maximum sequence length you want. Decoding stops once <eos> is produced or the maximum step count is reached.
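
To make the call shape concrete, here is a minimal usage sketch under TF 1.x (tf.contrib.seq2seq); the `decoder` object and the `max_len` limit are assumed to have been built elsewhere:

# Minimal dynamic_decode call (sketch); `decoder` and `max_len` are assumed
final_outputs, final_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
    decoder=decoder,
    output_time_major=False,    # outputs come back as batch_size * time_step * ...
    impute_finished=True,       # zero finished outputs, copy finished states through
    maximum_iterations=max_len)
logits = final_outputs.rnn_output       # [batch_size, time_step, vocab_size]
sample_ids = final_outputs.sample_id    # [batch_size, time_step]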

Simply put, dynamic_decode first runs the decoder's initialization function to set up the decoding state and related variables, and then repeatedly calls the decoder's step function for multiple rounds of decoding. In short, the function body is essentially a for loop, and its main part is a control_flow_ops.while_loop:

while_loop(cond, body, loop_vars, shape_invariants=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)

Here cond is the loop condition and body is the loop body; both are functions, implemented below. loop_vars are the variables the loop works with; condition() and body() take the same arguments, namely loop_vars, but condition() typically uses only a few of them to decide whether the loop has ended, while most are used inside body(). parallel_iterations is the number of iterations allowed to run in parallel. condition() simply checks whether finished is True for every batch entry, and body() is the series of assignments and checks that follows the call decoder.step(time, inputs, state). (A standalone toy while_loop appears after the code below.)

# Loop condition
def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
              finished, unused_sequence_lengths):
    return math_ops.logical_not(math_ops.reduce_all(finished))

# Loop body
def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
    ## 1. Call step to get the outputs and state for this time step, the next
    ##    step's inputs (produced by the helper), and the decoder_finished flag
    (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(time, inputs, state)
    ## 2. Decide whether decoding has ended, combining decoder_finished with
    ##    whether time has reached maximum_iterations
    next_finished = math_ops.logical_or(decoder_finished, finished)
    if maximum_iterations is not None:
        next_finished = math_ops.logical_or(
            next_finished, time + 1 >= maximum_iterations)
    next_sequence_lengths = array_ops.where(
        math_ops.logical_and(math_ops.logical_not(finished), next_finished),
        array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
        sequence_lengths)

    nest.assert_same_structure(state, decoder_state)
    nest.assert_same_structure(outputs_ta, next_outputs)
    nest.assert_same_structure(inputs, next_inputs)
    ## 3. If impute_finished is True, zero out next_outputs for finished batch
    ##    entries so they do not flow into backpropagation, and copy the state
    ##    through for them below. Setting it to True costs some extra time per
    ##    step, but yields more accurate results
    if impute_finished:
        emit = nest.map_structure(
            lambda out, zero: array_ops.where(finished, zero, out),
            next_outputs, zero_outputs)
    else:
        emit = next_outputs

    # Copy through states past finish
    def _maybe_copy_state(new, cur):
        # TensorArrays and scalar states get passed through.
        if isinstance(cur, tensor_array_ops.TensorArray):
            pass_through = True
        else:
            new.set_shape(cur.shape)
            pass_through = (new.shape.ndims == 0)
        return new if pass_through else array_ops.where(finished, cur, new)
    ## 4. For finished entries, keep the previous state instead of the new one
    if impute_finished:
        next_state = nest.map_structure(_maybe_copy_state, decoder_state, state)
    else:
        next_state = decoder_state
    ## 5. Write this step's outputs into outputs_ta and return
    outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                    outputs_ta, emit)
    return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
            next_sequence_lengths)

# Run the decode loop with the cond and body defined above
res = control_flow_ops.while_loop(
    condition, body,
    loop_vars=[initial_time, initial_outputs_ta, initial_state, initial_inputs,
               initial_finished, initial_sequence_lengths],
    parallel_iterations=parallel_iterations, swap_memory=swap_memory)
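
As a standalone illustration of the cond/body contract (a toy example, not part of the decoder source), the following tf.while_loop sums the integers 0 through 4. As in dynamic_decode, cond and body receive the same loop_vars, and cond only inspects one of them:

import tensorflow as tf

def cond(i, total):
    # Only inspects i, just like condition() above only inspects finished
    return i < 5

def body(i, total):
    # Updates every loop variable, like body() updating the decode state
    return i + 1, total + i

i_final, total_final = tf.while_loop(
    cond, body, loop_vars=[tf.constant(0), tf.constant(0)])
with tf.Session() as sess:
    print(sess.run([i_final, total_final]))  # [5, 10]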

So what exactly does decoder.step() do? You can think of it as the RNN cell being rolled forward one step; on top of that, because we are decoding, it adds operations such as using the helper to obtain the output answer and convert it into the next step's input. It looks like this:

def step(self, time, inputs, state, name=None):
    with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
        cell_outputs, cell_state = self._cell(inputs, state)
        if self._output_layer is not None:
            # If an output layer was provided, project the cell outputs
            cell_outputs = self._output_layer(cell_outputs)
        # Pick the answer from the outputs: e.g. the greedy approach takes the
        # most probable word, Scheduled helpers sample from some distribution, etc.
        sample_ids = self._helper.sample(time=time, outputs=cell_outputs, state=cell_state)
        # Turn the result into the next step's input: during training it is just
        # the next step of decoder_inputs; at inference the chosen word is embedded
        (finished, next_inputs, next_state) = self._helper.next_inputs(
            time=time, outputs=cell_outputs, state=cell_state, sample_ids=sample_ids)
    outputs = BasicDecoderOutput(cell_outputs, sample_ids)  # a namedtuple bundling both as the outputs
    return (outputs, next_state, next_inputs, finished)
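
For reference, BasicDecoderOutput is nothing more than a namedtuple; in tf.contrib.seq2seq it is defined essentially as follows:

import collections

# rnn_output holds the (optionally projected) cell outputs,
# sample_id holds the ids chosen by the helper's sample()
class BasicDecoderOutput(
    collections.namedtuple("BasicDecoderOutput", ("rnn_output", "sample_id"))):
    pass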

TrainingHelper, GreedyEmbeddingHelper, and CustomHelper in the helper file

Next, let's look at what the initialize, sample, and next_inputs functions of the different helper classes each do.

TrainingHelper

Generally used to assist the Decoder's decoding process during the training stage (a usage sketch follows the three functions below).

# Initialize finished and the initial inputs
def initialize(self, name=None):
    with ops.name_scope(name, "TrainingHelperInitialize"):
        finished = math_ops.equal(0, self._sequence_length)
        all_finished = math_ops.reduce_all(finished)
        next_inputs = control_flow_ops.cond(
            all_finished, lambda: self._zero_inputs,
            lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas))
    return (finished, next_inputs)

def sample(self, time, outputs, name=None, **unused_kwargs):
    with ops.name_scope(name, "TrainingHelperSample", [time, outputs]):
        # Take the argmax over outputs
        sample_ids = math_ops.cast(math_ops.argmax(outputs, axis=-1), dtypes.int32)
    return sample_ids

def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
    """next_inputs_fn for TrainingHelper."""
    with ops.name_scope(name, "TrainingHelperNextInputs", [time, outputs, state]):
        next_time = time + 1
        # While the next step is still below decoder_sequence_length, the bool is False
        finished = (next_time >= self._sequence_length)
        all_finished = math_ops.reduce_all(finished)
        # Read the next step's value directly from decoder_inputs as the next decode input
        def read_from_ta(inp):
            return inp.read(next_time)
        next_inputs = control_flow_ops.cond(
            all_finished, lambda: self._zero_inputs,
            lambda: nest.map_structure(read_from_ta, self._input_tas))
    return (finished, next_inputs, state)
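
Putting the three functions together, a typical training-time wiring looks like the sketch below; `decoder_inputs_embedded` (shape [batch_size, max_time, embed_dim]), `decoder_lengths`, `decoder_cell`, `encoder_final_state`, and `output_layer` are assumed names from your own graph:

# Training-time wiring (sketch, assumed variable names)
helper = tf.contrib.seq2seq.TrainingHelper(
    inputs=decoder_inputs_embedded,
    sequence_length=decoder_lengths)
decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=decoder_cell,
    helper=helper,
    initial_state=encoder_final_state,
    output_layer=output_layer)   # e.g. tf.layers.Dense(vocab_size)
outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
training_logits = outputs.rnn_output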

GreedyEmbeddingHelper

Generally used to assist the Decoder's decoding process at the inference stage, greedily picking the most probable token at each step (a usage sketch follows below).

# Initialize finished and the initial inputs
def initialize(self, name=None):
    # Initialize finished: all False at the initial step
    finished = array_ops.tile([False], [self._batch_size])
    return (finished, self._start_inputs)

def sample(self, time, outputs, state, name=None):
    del time, state  # unused by sample_fn
    if not isinstance(outputs, ops.Tensor):
        raise TypeError("Expected outputs to be a single Tensor, got: %s" % type(outputs))
    # Take the argmax over outputs
    sample_ids = math_ops.cast(math_ops.argmax(outputs, axis=-1), dtypes.int32)
    return sample_ids

def next_inputs(self, time, outputs, state, sample_ids, name=None):
    del time, outputs  # unused by next_inputs_fn
    finished = math_ops.equal(sample_ids, self._end_token)
    all_finished = math_ops.reduce_all(finished)
    # Turn sample_ids into the next step's input word vectors via
    # embedding_lookup(embedding, ids)
    next_inputs = control_flow_ops.cond(all_finished,
        # If we're finished, the next_inputs value doesn't matter
        lambda: self._start_inputs,
        lambda: self._embedding_fn(sample_ids))
    return (finished, next_inputs, state)
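
The inference-time wiring is symmetric; in this sketch `embedding` is the decoder embedding matrix, `start_token_id` / `end_token_id` are vocabulary ids, and `batch_size`, `max_decode_len` are assumed to exist:

# Inference-time wiring (sketch, assumed variable names)
helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
    embedding=embedding,
    start_tokens=tf.fill([batch_size], start_token_id),
    end_token=end_token_id)
decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=decoder_cell, helper=helper,
    initial_state=encoder_final_state, output_layer=output_layer)
outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
    decoder, maximum_iterations=max_decode_len)
predicted_ids = outputs.sample_id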

CustomHelper

Generally we use CustomHelper when the Decoder stage needs to consume the previous time step's output, which means the inputs cannot be wrapped up in advance. A standard dynamic RNN computes $s_i = f(s_{i-1}, x_i)$; but sometimes the function's arguments need to be extended, for example to $s_i = f(s_{i-1}, y_{i-1}, h_i, c_i)$.

So we need a hack: use tf.contrib.seq2seq.CustomHelper and pass in three functions:
initial_fn(): produces the inputs for the first time step.
sample_fn(): maps the logits to a definite class id.
next_inputs_fn(): produces the inputs for a general time step.

# The three functions passed to CustomHelper
# Initialize inputs and finished
def initial_fn():
    # all False at the initial step
    initial_finished = (0 >= self.decoder_seq_length)
    return (initial_finished, self.start_inputs)

def sample_fn(time, outputs, state):
    # del time, state  # unused by sample_fn
    # Take the argmax over outputs
    sample_ids = tf.cast(tf.argmax(outputs, axis=-1), dtype=tf.int32)
    return sample_ids

def next_inputs_fn(time, outputs, state, sample_ids):
    # Look up the embedding of the previous step's output class and use it
    # as the next step's input
    next_input = tf.nn.embedding_lookup(decoder_embedding, sample_ids)
    time += 1  # next time is the input time + 1, otherwise logits end up with one extra time step
    # this operation produces a boolean tensor of [batch_size]
    elements_finished = (time >= self.decoder_seq_length)
    # -> boolean scalar marking that the whole batch has finished
    all_finished = tf.reduce_all(elements_finished)
    # If finished, the next_inputs value doesn't matter
    next_inputs = tf.cond(all_finished, lambda: self.start_inputs, lambda: next_input)
    return elements_finished, next_inputs, state

# Using the custom helper
helper = CustomHelper(initial_fn, sample_fn, next_inputs_fn)
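
The custom helper then plugs into the same pipeline as the built-in ones; a sketch, reusing the assumed `decoder_cell`, `encoder_final_state`, and `output_layer` names from the earlier examples:

# Wire the custom helper into the usual decode pipeline (sketch)
decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=decoder_cell,
    helper=helper,
    initial_state=encoder_final_state,
    output_layer=output_layer)
outputs, final_state, final_seq_lens = tf.contrib.seq2seq.dynamic_decode(decoder)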
