This article follows up on the source-code analyses of tf.contrib.seq2seq.dynamic_decode and tf.contrib.seq2seq.BasicDecoder. Besides TrainingHelper, GreedyEmbeddingHelper will be covered later on.
The TrainingHelper code
First, be clear that during the training phase we feed the decoder the target sentences and collect the corresponding outputs, which are then used for training.
class TrainingHelper(Helper):
  """A helper for use during training.  Only reads inputs.

  Returned sample_ids are the argmax of the RNN output logits.
  """

  def __init__(self, inputs, sequence_length, time_major=False, name=None):
    """Initializer.

    Args:
      inputs: A (structure of) input tensors.
      sequence_length: An int32 vector tensor.
      time_major: Python bool.  Whether the tensors in `inputs` are time major.
        If `False` (default), they are assumed to be batch major.
      name: Name scope for any created operations.

    Raises:
      ValueError: if `sequence_length` is not a 1D tensor.
    """
    with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]):
      inputs = ops.convert_to_tensor(inputs, name="inputs")
      self._inputs = inputs
      if not time_major:
        inputs = nest.map_structure(_transpose_batch_time, inputs)

      self._input_tas = nest.map_structure(_unstack_ta, inputs)
      self._sequence_length = ops.convert_to_tensor(
          sequence_length, name="sequence_length")
      if self._sequence_length.get_shape().ndims != 1:
        raise ValueError(
            "Expected sequence_length to be a vector, but received shape: %s" %
            self._sequence_length.get_shape())

      self._zero_inputs = nest.map_structure(
          lambda inp: array_ops.zeros_like(inp[0, :]), inputs)

      self._batch_size = array_ops.size(sequence_length)
So TrainingHelper mainly takes two arguments: inputs, a tensor of shape [batch_size, seq_len, embed_size], and sequence_length, a [batch_size] vector giving the true length of each sentence; if time_major is True, seq_len is the first dimension instead. Note that sequence_length has one entry per example, specifying each sentence's real length (the remaining positions are padding).
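For orientation, here is a minimal sketch of constructing a TrainingHelper; the placeholder names, shapes, and sizes are assumptions made up for illustration:

import tensorflow as tf

# Hypothetical sizes, only for illustration.
batch_size, max_len, embed_size = 32, 20, 128

# Embedded decoder inputs (e.g. the shifted target sentences) and the
# true, unpadded length of every sentence in the batch.
decoder_inputs = tf.placeholder(tf.float32, [batch_size, max_len, embed_size])
decoder_lengths = tf.placeholder(tf.int32, [batch_size])

helper = tf.contrib.seq2seq.TrainingHelper(
    inputs=decoder_inputs,            # batch-major: [batch_size, max_len, embed_size]
    sequence_length=decoder_lengths,  # [batch_size] vector of true lengths
    time_major=False)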
  def initialize(self, name=None):
    with ops.name_scope(name, "TrainingHelperInitialize"):
      finished = math_ops.equal(0, self._sequence_length)
      all_finished = math_ops.reduce_all(finished)
      next_inputs = control_flow_ops.cond(
          all_finished, lambda: self._zero_inputs,
          lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas))
      return (finished, next_inputs)
This handles initialization and hands the first input to the outside world: any sentence of length 0 is marked finished from the start, and otherwise the inputs at time step 0 are read from the unstacked TensorArrays.
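The same logic expressed in plain NumPy makes the semantics easier to see (a sketch, not the actual graph code; inputs are assumed time-major here to match the unstacked TensorArrays):

import numpy as np

def initialize_sketch(inputs, sequence_length):
    # inputs: [max_len, batch_size, embed_size] (time-major)
    # A sentence of length 0 is finished before the first step.
    finished = (sequence_length == 0)                # [batch_size] bool
    if finished.all():
        next_inputs = np.zeros_like(inputs[0])       # all-zero first input
    else:
        next_inputs = inputs[0]                      # read time step 0
    return finished, next_inputs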
  def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
    """next_inputs_fn for TrainingHelper."""
    with ops.name_scope(name, "TrainingHelperNextInputs",
                        [time, outputs, state]):
      next_time = time + 1
      finished = (next_time >= self._sequence_length)
      all_finished = math_ops.reduce_all(finished)

      def read_from_ta(inp):
        return inp.read(next_time)

      next_inputs = control_flow_ops.cond(
          all_finished, lambda: self._zero_inputs,
          lambda: nest.map_structure(read_from_ta, self._input_tas))
      return (finished, next_inputs, state)
In TrainingHelper's next_inputs, each call reads the data at time + 1 from inputs and returns it as the next input, i.e. the ground-truth sequence is fed back regardless of what the RNN just produced (teacher forcing). Note the finished flag here: it checks whether the next time step has reached each sentence's seq_len; once it has, that sentence is done, and once the whole batch is finished the zero inputs are fed instead.
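Again as a NumPy sketch of the semantics (time-major inputs assumed; not the actual graph code), teacher forcing shows up in the fact that outputs is never used to build the next input:

import numpy as np

def next_inputs_sketch(time, outputs, inputs, sequence_length):
    # Teacher forcing: `outputs` (the RNN's own predictions) is ignored;
    # the ground-truth embedding at time + 1 is fed instead.
    next_time = time + 1
    finished = (next_time >= sequence_length)        # [batch_size] bool
    if finished.all():
        next_inputs = np.zeros_like(inputs[0])       # whole batch done
    else:
        next_inputs = inputs[next_time]              # read step time + 1
    return finished, next_inputs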
  def sample(self, time, outputs, name=None, **unused_kwargs):
    with ops.name_scope(name, "TrainingHelperSample", [time, outputs]):
      sample_ids = math_ops.cast(
          math_ops.argmax(outputs, axis=-1), dtypes.int32)
      return sample_ids
A sample function is also implemented here, used for sampling: it takes the word with the highest output probability (the argmax over the logits) as the current output word. In TrainingHelper we mostly care about next_inputs; at inference time, the sample function is what we care about.
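The training/inference contrast is easiest to see side by side. A simplified sketch: the argmax below is exactly what sample() computes, and the commented embedding lookup is how GreedyEmbeddingHelper (covered later) would feed it back:

import numpy as np

logits = np.array([[0.1, 2.3, 0.5],   # fake RNN outputs for a batch of 2
                   [1.7, 0.2, 0.9]])

sample_ids = logits.argmax(axis=-1)   # what sample() returns: [1, 0]

# TrainingHelper: sample_ids are recorded in the decoder outputs, but
# next_inputs still reads from the ground-truth `inputs` tensor.
# GreedyEmbeddingHelper: sample_ids would be looked up in the embedding
# matrix and become the next input instead.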
As you can see, Seq2seq provides a variety of Helpers, and essentially every one of them implements a next_inputs and a sample function. During training, though, next_inputs is the function we care about, because we only want the per-step outputs to use in the subsequent training.
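Putting the pieces together with the BasicDecoder and dynamic_decode analyzed in the earlier posts, a training-time decode looks roughly like this; the cell size, vocabulary size, and output layer are assumptions for illustration:

vocab_size = 10000  # hypothetical target vocabulary size

cell = tf.nn.rnn_cell.LSTMCell(num_units=256)
projection = tf.layers.Dense(vocab_size, use_bias=False)

decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=cell,
    helper=helper,  # the TrainingHelper constructed above
    initial_state=cell.zero_state(batch_size, tf.float32),
    output_layer=projection)

# outputs.rnn_output feeds the training loss;
# outputs.sample_ids are the per-step argmax ids from sample().
outputs, final_state, final_lengths = tf.contrib.seq2seq.dynamic_decode(decoder)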