TensorFlow (1): tf.contrib.seq2seq.GreedyEmbeddingHelper

Introduction

I have recently been building seq2seq models in TensorFlow and have run into quite a few problems. The first question was which module to use: tf.contrib.seq2seq or tf.contrib.legacy_seq2seq? The latest API docs show that tf.contrib.legacy_seq2seq has been deprecated, so you would think the choice is obvious: just use tf.contrib.seq2seq. Unfortunately, the examples on GitHub and CSDN are almost all written against tf.contrib.legacy_seq2seq, and running the tf.contrib.legacy_seq2seq example under tensorflow/models fails with "can't pickle _thread.lock objects". So, taking the harder road, I started exploring tf.contrib.seq2seq and recording the pitfalls I hit along the way. For brevity, unqualified names below refer to tf.contrib.seq2seq; for example, GreedyEmbeddingHelper means tf.contrib.seq2seq.GreedyEmbeddingHelper.

Environment

>>> import sys
>>> import tensorflow as tf
>>> print(sys.version)
3.6.0 |Anaconda 4.3.1 (64-bit)| (default, Dec 23 2016, 12:22:00)
[GCC 4.4.7 20120313 (Red Hat 4.4.7-1)]
>>> print(tf.__version__)
1.3.0

GreedyEmbeddingHelper

This section records the pitfalls I ran into while using GreedyEmbeddingHelper.

Background

To introduce GreedyEmbeddingHelper we have to start with Helper, since every "…Helper" class derives from it. Helper is the interface for decoder sampling in seq2seq, and its instances are consumed by BasicDecoder. In short, the developers abstracted away the decoder sampling process so that later users do not have to reimplement it (nicely done!). TensorFlow derives a number of classes from Helper, with the structure shown below.

[Figure 1: the Helper class hierarchy in tf.contrib.seq2seq]

The red lines indicate inheritance. The classes represented by TrainingHelper drive the training process, in both scheduled and non-scheduled variants. The classes represented by GreedyEmbeddingHelper perform greedy decoding and are typically used at prediction time.
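Conceptually, at each decoding step GreedyEmbeddingHelper takes the argmax of the decoder's output logits and feeds the embedding of that token back in as the next input. Here is a minimal NumPy sketch of one such step; the variable names and toy sizes are mine, not part of tf.contrib:

```python
import numpy as np

# Toy vocabulary and embedding matrix (illustrative sizes).
vocab_size, embed_dim = 5, 3
rng = np.random.default_rng(0)
embedding = rng.standard_normal((vocab_size, embed_dim))

def greedy_step(logits, embedding):
    """One GreedyEmbeddingHelper-style step: sample() takes the argmax
    of the logits, and next_inputs() embeds the sampled token id."""
    sample_id = int(np.argmax(logits))
    next_input = embedding[sample_id]  # the embedding-lookup step
    return sample_id, next_input

logits = np.array([0.1, 2.0, -1.0, 0.5, 0.3])
sample_id, next_input = greedy_step(logits, embedding)
# sample_id == 1 (the argmax); next_input is row 1 of the embedding matrix
```

This is why the helper's first argument must be the embedding: it is what maps each sampled id back into the decoder's input space.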

The error

The problem I ran into while using GreedyEmbeddingHelper was:

Traceback (most recent call last):
  File "tutorial#2.py", line 217, in 
    model = Model(vocab_size)
  File "tutorial#2.py", line 172, in __init__
    output_time_major=False)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 286, in dynamic_decode
    swap_memory=swap_memory)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2775, in while_loop
    result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2604, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2554, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 234, in body
    decoder_finished) = decoder.step(time, inputs, state)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py", line 139, in step
    cell_outputs, cell_state = self._cell(inputs, state)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 180, in __call__
    return super(RNNCell, self).__call__(inputs, state)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 450, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 401, in call
    concat = _linear([inputs, h], 4 * self._num_units, True)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 1027, in _linear
    "but saw %s" % (shape, shape[1]))
ValueError: linear expects shape[1] to be provided for shape (3, ?), but saw ?

The cause

Debugging localized the error to line 172, which corresponds to this code:

pred_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
    decoder=decoder, maximum_iterations=30,
    impute_finished=False,
    output_time_major=False)

At first I kept checking the TensorFlow API, wondering whether this call was defined incorrectly; after a long look it turned out to be fine. At that point I began to suspect the decoder itself, so I checked its definition:

decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=self.decoder_cell,
    helper=pred_helper,
    initial_state=encoder_final_state)

Again the API check showed that this definition was correct, so by elimination I guessed that pred_helper was the problem. It was defined as follows:

self.decoder_inputs = tf.placeholder(
    shape=(None, None), dtype=tf.int32, name='decoder_inputs')
pred_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
    self.decoder_inputs,
    start_tokens=tf.fill([batch_size], SOS_ID),
    end_token=EOS_ID)

Anyone familiar with the API will spot the bug immediately: the first argument to GreedyEmbeddingHelper is the embedding, not the decoder inputs, which is why the tensor shapes never matched. From a seq2seq-theory standpoint it obviously should not be decoder_inputs either, since the targets are unknown at prediction time. I learned a lot from this bug.
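The fix is to pass the embedding matrix (or a callable that performs the lookup) as the first argument. A sketch of the corrected call, where embedding_matrix stands for the decoder's embedding variable (that name is mine, not from the original code):

```python
# Correct usage: the first argument is the embedding matrix
# (shape [vocab_size, embed_dim]), not the decoder inputs.
pred_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
    embedding=embedding_matrix,
    start_tokens=tf.fill([batch_size], SOS_ID),
    end_token=EOS_ID)
```

With a matrix argument, the helper internally applies tf.nn.embedding_lookup to each sampled id, so the decoder's input shape is fully known and the shape error above goes away.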
