I've been working on seq2seq with TensorFlow recently and ran into quite a few problems. The first one: should I use tf.contrib.seq2seq or tf.contrib.legacy_seq2seq? The latest API docs show that tf.contrib.legacy_seq2seq is deprecated, so you'd think the choice is obvious: just use tf.contrib.seq2seq. Unfortunately, nearly every example on GitHub and CSDN uses tf.contrib.legacy_seq2seq, and running the tf.contrib.legacy_seq2seq example under tensorflow/models fails with "can't pickle _thread.lock objects". In the spirit of tackling problems head-on, I started exploring tf.contrib.seq2seq, recording the pitfalls I stepped into along the way. For brevity, an unprefixed name below refers to tf.contrib.seq2seq; for example, GreedyEmbeddingHelper means tf.contrib.seq2seq.GreedyEmbeddingHelper.
>>> import sys
>>> import tensorflow as tf
>>> print(sys.version)
3.6.0 |Anaconda 4.3.1 (64-bit)| (default, Dec 23 2016, 12:22:00)
[GCC 4.4.7 20120313 (Red Hat 4.4.7-1)]
>>> print(tf.__version__)
1.3.0
This section records the pitfalls I hit while using GreedyEmbeddingHelper.
Introducing GreedyEmbeddingHelper has to start with Helper, because every "...Helper" class derives from it. Helper is the seq2seq interface that abstracts decoder sampling, and its instances are consumed by BasicDecoder. In short, the developers factored the decoder's sampling process out into an interface so that later users can reuse it (nice work!). TensorFlow derives a number of classes from Helper, with the structure shown below, where the red lines indicate inheritance. Classes like TrainingHelper drive the training process, in both scheduled and non-scheduled variants. Classes like GreedyEmbeddingHelper perform greedy decoding and are generally used for inference.
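To make the sampling abstraction concrete, here is a minimal NumPy sketch (my own illustration, not TensorFlow code) of the per-step logic a greedy helper implements: take the decoder's output logits, argmax to get the next token ids, mark sequences that emitted the end token as finished, and look up the chosen ids' embeddings as the next decoder input. All names here (greedy_step, logits, etc.) are hypothetical.

```python
import numpy as np

def greedy_step(logits, embedding, eos_id):
    """One greedy decoding step: argmax over logits, then embed the
    chosen token ids to produce the next decoder input."""
    sample_ids = np.argmax(logits, axis=-1)   # (batch,) chosen token ids
    finished = sample_ids == eos_id           # which sequences hit EOS
    next_inputs = embedding[sample_ids]       # (batch, embed_dim)
    return sample_ids, finished, next_inputs

vocab_size, embed_dim = 5, 3
embedding = np.arange(vocab_size * embed_dim, dtype=np.float32).reshape(
    vocab_size, embed_dim)
logits = np.array([[0.1, 0.2, 0.9, 0.0, 0.0],
                   [0.0, 0.0, 0.0, 0.0, 1.0]])  # batch of 2
ids, done, nxt = greedy_step(logits, embedding, eos_id=4)
print(ids)        # [2 4]
print(done)       # [False  True]
print(nxt.shape)  # (2, 3)
```

This is essentially what GreedyEmbeddingHelper's sample/next_inputs methods do inside the dynamic_decode loop, which is why the helper needs the embedding matrix itself as its first argument.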
The problem I hit while using GreedyEmbeddingHelper was the following error:
Traceback (most recent call last):
  File "tutorial#2.py", line 217, in <module>
    model = Model(vocab_size)
  File "tutorial#2.py", line 172, in __init__
    output_time_major=False)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 286, in dynamic_decode
    swap_memory=swap_memory)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2775, in while_loop
    result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2604, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2554, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 234, in body
    decoder_finished) = decoder.step(time, inputs, state)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py", line 139, in step
    cell_outputs, cell_state = self._cell(inputs, state)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 180, in __call__
    return super(RNNCell, self).__call__(inputs, state)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 450, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 401, in call
    concat = _linear([inputs, h], 4 * self._num_units, True)
  File "/data0/ads_dm/yewen1/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 1027, in _linear
    "but saw %s" % (shape, shape[1]))
ValueError: linear expects shape[1] to be provided for shape (3, ?), but saw ?
Debugging pointed to line 172, which corresponds to this code:
pred_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
    decoder=decoder, maximum_iterations=30,
    impute_finished=False,
    output_time_major=False)
At first I kept combing through the TensorFlow API, wondering whether the call itself was wrong; after a long search, it wasn't. At that point I began to suspect the decoder, so I checked it next:
decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=self.decoder_cell,
    helper=pred_helper,
    initial_state=encoder_final_state)
Again I checked the API first and confirmed this definition was fine, so by elimination I guessed the problem was pred_helper, defined as follows:
self.decoder_inputs = tf.placeholder(
    shape=(None, None), dtype=tf.int32, name='decoder_inputs')
pred_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
    self.decoder_inputs,
    start_tokens=tf.fill([batch_size], SOS_ID),
    end_token=EOS_ID)
Anyone familiar with the API will spot it at a glance: the first argument of GreedyEmbeddingHelper is embedding, not the decoder inputs, which is why the tensor shapes never matched. From the standpoint of seq2seq theory it obviously shouldn't be decoder_inputs either, since the targets are unknown at inference time. I learned a lot from this bug, haha...