For simplicity, let's start the analysis from dynamic_decode, the entry point of decoding:
dynamic_decode(
    decoder,
    output_time_major=False,
    impute_finished=False,
    maximum_iterations=None,
    parallel_iterations=32,
    swap_memory=False,
    scope=None
)
decoder: a BasicDecoder, BeamSearchDecoder, or user-defined decoder instance.
output_time_major: Python boolean. When False, outputs are returned batch-major with shape batch_size * time_step * ...; this mode adds extra transposes during computation. When True, outputs are returned time-major with shape time_step * batch_size * ..., which computes faster.
impute_finished: Python boolean. When True, the states of batch entries marked as finished are copied through and their outputs are zeroed out. This makes each time step slower, but guarantees that the final states and outputs have the correct values, ignores time steps marked as finished, and makes the program run more stably.
maximum_iterations: the maximum number of decoding steps. During training this is usually set to decoder_inputs_length; during inference, set it to whatever maximum sequence length you want. Decoding stops once <eos> is produced or the maximum number of steps is reached.
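To make the call pattern concrete, here is a minimal usage sketch; decoder_cell, helper, encoder_state, and max_decode_len are hypothetical placeholders, not names from the library:

import tensorflow as tf

# Hypothetical wiring: decoder_cell, helper, encoder_state and
# max_decode_len would be defined elsewhere in your model.
decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=decoder_cell,            # an RNNCell, e.g. tf.nn.rnn_cell.LSTMCell
    helper=helper,                # e.g. a TrainingHelper (see below)
    initial_state=encoder_state)  # usually the encoder's final state

# outputs.rnn_output holds the logits, outputs.sample_id the chosen ids
outputs, final_state, final_seq_len = tf.contrib.seq2seq.dynamic_decode(
    decoder,
    impute_finished=True,
    maximum_iterations=max_decode_len)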
In short, dynamic_decode first calls the decoder's initialize function to initialize the decoding state and related variables, and then repeatedly calls the decoder's step function to decode round after round. In other words, the function body is essentially a for loop, and the core of the program is a control_flow_ops.while_loop:
while_loop(cond, body, loop_vars, shape_invariants=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)
Here cond is the loop condition and body is the loop body; both are functions, implemented below. loop_vars holds the variables the loop uses; condition() and body() take the same arguments, namely loop_vars, although condition() usually only needs a few of them to decide whether the loop is over, while most are consumed inside body(). parallel_iterations is the number of loop iterations allowed to run in parallel. condition() simply checks whether finished is True for every batch entry, and body() calls decoder.step(time, inputs, state) followed by a series of assignments and checks.
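As a toy illustration of the cond/body contract (unrelated to the decoder itself), the following loop sums the integers 0 through 4:

import tensorflow as tf

i0 = tf.constant(0)
total0 = tf.constant(0)

# condition and body take the same arguments, in loop_vars order
cond = lambda i, total: i < 5
body = lambda i, total: (i + 1, total + i)

i_final, total_final = tf.while_loop(cond, body, loop_vars=[i0, total0])

with tf.Session() as sess:
    print(sess.run([i_final, total_final]))  # [5, 10]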
# Loop condition
def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
              finished, unused_sequence_lengths):
    return math_ops.logical_not(math_ops.reduce_all(finished))

# Loop body
def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
    ## 1. Call the step function to get the next output and state, plus the
    ##    next input (produced by the helper) and the decoder_finished flag.
    (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(time, inputs, state)
    ## 2. Decide whether decoding is over, combining decoder_finished with
    ##    whether time has already reached maximum_iterations.
    next_finished = math_ops.logical_or(decoder_finished, finished)
    if maximum_iterations is not None:
        next_finished = math_ops.logical_or(
            next_finished, time + 1 >= maximum_iterations)
    next_sequence_lengths = array_ops.where(
        math_ops.logical_and(math_ops.logical_not(finished), next_finished),
        array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
        sequence_lengths)

    nest.assert_same_structure(state, decoder_state)
    nest.assert_same_structure(outputs_ta, next_outputs)
    nest.assert_same_structure(inputs, next_inputs)

    ## 3. If impute_finished is True, zero out next_outputs for finished
    ##    entries so they do not take part in backpropagation, and copy
    ##    decoder_state through as the next state. So setting this to True
    ##    costs some extra time per step, but gives more accurate results.
    if impute_finished:
        # (zero_outputs is defined earlier in dynamic_decode as zeros shaped like the outputs)
        emit = nest.map_structure(
            lambda out, zero: array_ops.where(finished, zero, out),
            next_outputs, zero_outputs)
    else:
        emit = next_outputs

    # Copy through states past finish
    def _maybe_copy_state(new, cur):
        # TensorArrays and scalar states get passed through.
        if isinstance(cur, tensor_array_ops.TensorArray):
            pass_through = True
        else:
            new.set_shape(cur.shape)
            pass_through = (new.shape.ndims == 0)
        return new if pass_through else array_ops.where(finished, cur, new)

    ## 4. Compute the next state: for finished entries, keep the previous
    ##    state instead of the freshly computed one.
    if impute_finished:
        next_state = nest.map_structure(_maybe_copy_state, decoder_state, state)
    else:
        next_state = decoder_state

    ## 5. Write this step's output into outputs_ta and return the loop variables.
    outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out), outputs_ta, emit)
    return (time + 1, outputs_ta, next_state, next_inputs, next_finished, next_sequence_lengths)
# Run the decoding loop with the condition and body defined above
res = control_flow_ops.while_loop(
    condition, body,
    loop_vars=[initial_time, initial_outputs_ta, initial_state,
               initial_inputs, initial_finished, initial_sequence_lengths],
    parallel_iterations=parallel_iterations, swap_memory=swap_memory)
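After the loop returns, dynamic_decode stacks each output TensorArray into a dense tensor and, if output_time_major is False, transposes the result back to batch-major. Roughly sketched (simplified from the TF source, where _transpose_batch_time is an internal helper that swaps the first two axes):

final_outputs_ta = res[1]
final_state = res[2]

# Stack each TensorArray into a [time, batch, ...] tensor
final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)
if not output_time_major:
    # Swap the time and batch axes back to [batch, time, ...]
    final_outputs = nest.map_structure(_transpose_batch_time, final_outputs)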
So what exactly does decoder.step() do? You can think of it as unrolling RNNCell by one step. The difference is that, since we are decoding, it adds a few operations on top, such as using the helper to pick an output answer and convert it into the next step's input. As shown below:
def step(self, time, inputs, state, name=None):
    with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
        cell_outputs, cell_state = self._cell(inputs, state)
        if self._output_layer is not None:
            # If an output layer was set, project the cell output through it
            cell_outputs = self._output_layer(cell_outputs)
        # Pick the desired answer from the output, e.g. greedily choosing the
        # highest-probability word, or sampling from some distribution in the
        # scheduled case, etc.
        sample_ids = self._helper.sample(time=time, outputs=cell_outputs, state=cell_state)
        # Convert the sampled result into the next step's input: during
        # training this is simply the next time step of decoder_inputs;
        # during inference the chosen word is embedded.
        (finished, next_inputs, next_state) = self._helper.next_inputs(
            time=time, outputs=cell_outputs, state=cell_state, sample_ids=sample_ids)
    outputs = BasicDecoderOutput(cell_outputs, sample_ids)  # a namedtuple, bundled together as the outputs variable
    return (outputs, next_state, next_inputs, finished)
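Since BasicDecoderOutput is a namedtuple with fields rnn_output and sample_id, the logits and predicted ids can be read off directly after decoding (variable names here are illustrative):

outputs, final_state, final_seq_len = tf.contrib.seq2seq.dynamic_decode(decoder)

logits = outputs.rnn_output        # shape: [batch_size, time_step, vocab_size]
predicted_ids = outputs.sample_id  # shape: [batch_size, time_step]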
Next, let's look at what the initialize, sample, and next_inputs functions of the different helper classes do.
TrainingHelper is typically used for decoding during the training stage, assisting the Decoder's decoding process.
# Initialize finished and initial_inputs
def initialize(self, name=None):
    with ops.name_scope(name, "TrainingHelperInitialize"):
        finished = math_ops.equal(0, self._sequence_length)
        all_finished = math_ops.reduce_all(finished)
        next_inputs = control_flow_ops.cond(
            all_finished, lambda: self._zero_inputs,
            lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas))
        return (finished, next_inputs)

def sample(self, time, outputs, name=None, **unused_kwargs):
    with ops.name_scope(name, "TrainingHelperSample", [time, outputs]):
        # Take the argmax over outputs
        sample_ids = math_ops.cast(math_ops.argmax(outputs, axis=-1), dtypes.int32)
        return sample_ids

def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
    """next_inputs_fn for TrainingHelper."""
    with ops.name_scope(name, "TrainingHelperNextInputs", [time, outputs, state]):
        next_time = time + 1
        # finished stays False as long as the next step is still smaller
        # than the decoder sequence length
        finished = (next_time >= self._sequence_length)
        all_finished = math_ops.reduce_all(finished)
        # Read the next time step directly from decoder_inputs and use it
        # as the next decoding input
        def read_from_ta(inp):
            return inp.read(next_time)
        next_inputs = control_flow_ops.cond(
            all_finished, lambda: self._zero_inputs,
            lambda: nest.map_structure(read_from_ta, self._input_tas))
        return (finished, next_inputs, state)
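For reference, constructing a TrainingHelper typically looks like this; decoder_inputs_embedded and decoder_lengths are hypothetical placeholders:

# decoder_inputs_embedded: [batch_size, max_time, embed_dim] float tensor
# decoder_lengths: [batch_size] int32 tensor of true target lengths
helper = tf.contrib.seq2seq.TrainingHelper(
    inputs=decoder_inputs_embedded,
    sequence_length=decoder_lengths,
    time_major=False)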
GreedyEmbeddingHelper is typically used for decoding during the inference stage; it decodes greedily, assisting the Decoder's decoding process.
# Initialize finished and initial_inputs
def initialize(self, name=None):
    # all False at the initial step
    finished = array_ops.tile([False], [self._batch_size])
    return (finished, self._start_inputs)

def sample(self, time, outputs, state, name=None):
    del time, state  # unused by sample_fn
    if not isinstance(outputs, ops.Tensor):
        raise TypeError("Expected outputs to be a single Tensor, got: %s" % type(outputs))
    # Take the argmax over outputs
    sample_ids = math_ops.cast(math_ops.argmax(outputs, axis=-1), dtypes.int32)
    return sample_ids

def next_inputs(self, time, outputs, state, sample_ids, name=None):
    del time, outputs  # unused by next_inputs_fn
    finished = math_ops.equal(sample_ids, self._end_token)
    all_finished = math_ops.reduce_all(finished)
    # Turn sample_ids into the next step's input word vectors via
    # embedding_lookup(embedding, ids)
    next_inputs = control_flow_ops.cond(
        all_finished,
        # If we're finished, the next_inputs value doesn't matter
        lambda: self._start_inputs,
        lambda: self._embedding_fn(sample_ids))
    return (finished, next_inputs, state)
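Constructing a GreedyEmbeddingHelper for inference might look like the following sketch; embedding, go_token_id, and eos_token_id are hypothetical placeholders:

# embedding: [vocab_size, embed_dim] matrix (or a callable that embeds ids)
# start_tokens: one <go> id per batch entry; end_token: scalar <eos> id
start_tokens = tf.fill([batch_size], go_token_id)
helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
    embedding=embedding,
    start_tokens=start_tokens,
    end_token=eos_token_id)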
Generally speaking, we use a CustomHelper when the decoder needs to consume the previous time step's output, which makes it impossible to package everything up in advance. A standard dynamic RNN computes s_i = f(s_{i-1}, x_i); but sometimes the function needs more arguments, for example in our case s_i = f(s_{i-1}, y_{i-1}, h_i, c_i).
So we resort to a hack: use tf.contrib.seq2seq.CustomHelper and pass in three functions:
initial_fn(): produces the input for the first time step.
sample_fn(): maps logits to a concrete class id.
next_inputs_fn(): produces the input for an ordinary time step.
# The three functions passed to CustomHelper
# Initialize inputs and finished
def initial_fn():
    # all False at the initial step
    initial_finished = (0 >= self.decoder_seq_length)
    return (initial_finished, self.start_inputs)

def sample_fn(time, outputs, state):
    # del time, state  # unused by sample_fn
    # Take the argmax over outputs
    sample_ids = tf.cast(tf.argmax(outputs, axis=-1), dtype=tf.int32)
    return sample_ids

def next_inputs_fn(time, outputs, state, sample_ids):
    # Embed the class predicted at the previous time step and feed it in
    # as the next time step's input
    next_input = tf.nn.embedding_lookup(decoder_embedding, sample_ids)
    time += 1  # the next time step is time + 1; otherwise logits would end up with one extra time step
    # this operation produces a boolean tensor of [batch_size]
    elements_finished = (time >= self.decoder_seq_length)
    # -> boolean scalar, marking that the whole batch has finished
    all_finished = tf.reduce_all(elements_finished)
    # If finished, the next_inputs value doesn't matter
    next_inputs = tf.cond(all_finished, lambda: self.start_inputs, lambda: next_input)
    return elements_finished, next_inputs, state
# Using the custom helper
helper = CustomHelper(initial_fn, sample_fn, next_inputs_fn)
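The resulting helper then plugs into BasicDecoder and dynamic_decode just like the built-in helpers; decoder_cell, encoder_state, output_layer, and max_decode_len are hypothetical placeholders:

decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=decoder_cell,
    helper=helper,
    initial_state=encoder_state,
    output_layer=output_layer)  # e.g. tf.layers.Dense(vocab_size)

outputs, final_state, final_seq_len = tf.contrib.seq2seq.dynamic_decode(
    decoder, impute_finished=True, maximum_iterations=max_decode_len)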