本文衔接TrainingHelper,也可以衔接BasicDecoder。先说明一下,GreedyEmbeddingHelper主要作用是接收开始符,然后生成指定长度大小的句子。
GreedyEmbeddingHelper代码传送门
class GreedyEmbeddingHelper(Helper):
"""A helper for use during inference.
Uses the argmax of the output (treated as logits) and passes the
result through an embedding layer to get the next input.
"""
def __init__(self, embedding, start_tokens, end_token):
"""Initializer.
Args:
embedding: A callable that takes a vector tensor of `ids` (argmax ids),
or the `params` argument for `embedding_lookup`. The returned tensor
will be passed to the decoder input.
start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
end_token: `int32` scalar, the token that marks end of decoding.
Raises:
ValueError: if `start_tokens` is not a 1D tensor or `end_token` is not a
scalar.
"""
if callable(embedding):
self._embedding_fn = embedding
else:
self._embedding_fn = (
lambda ids: embedding_ops.embedding_lookup(embedding, ids))
self._start_tokens = ops.convert_to_tensor(
start_tokens, dtype=dtypes.int32, name="start_tokens")
self._end_token = ops.convert_to_tensor(
end_token, dtype=dtypes.int32, name="end_token")
if self._start_tokens.get_shape().ndims != 1:
raise ValueError("start_tokens must be a vector")
self._batch_size = array_ops.size(start_tokens)
if self._end_token.get_shape().ndims != 0:
raise ValueError("end_token must be a scalar")
self._start_inputs = self._embedding_fn(self._start_tokens)
在GreedyEmbeddingHelper初始阶段,接收一个embedding矩阵,以便后面的embedding_lookup。可以注意到在TrainingHelper并不需要这个,是因为在训练阶段,我们给TrainingHelper的就是[batch_size, seq_len, embed_size]的输入,已经是词向量了。而在推理阶段,我们只给了一个开始符,给了我们需要的句子长度,所以我们在输出一个词的时候还需要进行embedding_lookup成词向量作为下一个时刻的输入。
def initialize(self, name=None):
finished = array_ops.tile([False], [self._batch_size])
return (finished, self._start_inputs)
第一个输入,在TrainingHelper的第一个输入是inputs[0],而这里的第一个输入是开始符向量(注意开始符是一个[batch_size]的向量,里面的元素不一定都一样。因为有时候我们可能在生成到一半的句子中才开始推理,这时候的第一个开始符生成一半句子的最后一个词)。当然,这里的finished肯定是都是False的。
def sample(self, time, outputs, state, name=None):
"""sample for GreedyEmbeddingHelper."""
del time, state # unused by sample_fn
# Outputs are logits, use argmax to get the most probable id
if not isinstance(outputs, ops.Tensor):
raise TypeError("Expected outputs to be a single Tensor, got: %s" %
type(outputs))
sample_ids = math_ops.argmax(outputs, axis=-1, output_type=dtypes.int32)
return sample_ids
这里是采样的意思,判断一个词根据什么情况来在这里,Greedy是贪婪的意思,也就是这个采样遵循贪心算法,选取最大概率输出对应词作为采样的词。
def next_inputs(self, time, outputs, state, sample_ids, name=None):
"""next_inputs_fn for GreedyEmbeddingHelper."""
del time, outputs # unused by next_inputs_fn
finished = math_ops.equal(sample_ids, self._end_token)
all_finished = math_ops.reduce_all(finished)
next_inputs = control_flow_ops.cond(
all_finished,
# If we're finished, the next_inputs value doesn't matter
lambda: self._start_inputs,
lambda: self._embedding_fn(sample_ids))
return (finished, next_inputs, state)
但是GreedyEmbeddingHelper其实也关注next_inputs,因为上一个采样的词需要当成当前的输入。
Helper类型很多,SampleEmbeddingHelper,CustomHelper,ScheduledEmbeddingTrainingHelper,ScheduledOutputTrainingHelper,InferenceHelper,其实大多大同小异,学会了训练阶段的Helper和推理阶段的Helper的典型,也就是上面两个,就可以触类旁通。
全部的代码在Helper.py这里,有需要延伸的可以继续看看。