在学习nmt源码时对数据处理部分中的bucket有一些疑惑,现以官方示例中的 “tst2012.en”作为源数据集以及目标数据集,以“vocab.en”作为对应的词表,做了一些尝试,具体如下:
#-*- coding:utf-8 -*-
"""Demonstrate how the NMT input pipeline buckets and pads sentence pairs.

The same text file is used as both source and target, so the grouping done
by ``group_by_window`` can be observed directly from the printed lengths.
"""
import tensorflow as tf
from tensorflow.python.ops import lookup_ops

num_threads = 4
batch_size = 128
src_max_len = tgt_max_len = 50
source_reverse = False
num_buckets = 5
random_seed = 0

# Vocabulary table, shared by source and target. Words that are not in the
# vocabulary file map to id 0 (default_value).
src_vocab_table = tgt_vocab_table = lookup_ops.index_table_from_file(
    "./nmt_data/vocab.en", default_value=0)

# Build the datasets: every line of the text file becomes one element.
src_dataset = tf.contrib.data.TextLineDataset('./nmt_data/tst2012.en')
tgt_dataset = tf.contrib.data.TextLineDataset('./nmt_data/tst2012.en')

# reverse_src_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
#     "./nmt_data/vocab.en", default_value=0)

# Vocabulary ids of the sos / eos tokens. They are only needed by the
# commented-out id-conversion steps further down.
# NOTE(review): the token literals were lost when this post was extracted
# (the tag-like text was stripped); '<s>' and '</s>' are restored here on
# the assumption that the script follows the NMT tutorial's vocab
# convention -- confirm against vocab.en.
output_buffer_size = batch_size * 1000
src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant('</s>')), tf.int32)
tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant('<s>')), tf.int32)
tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant('</s>')), tf.int32)

# Pair every source line with its target line.
src_tgt_dataset = tf.contrib.data.Dataset.zip((src_dataset, tgt_dataset))

# Shuffling is disabled so the batches can be compared with the input file.
# src_tgt_dataset = src_tgt_dataset.shuffle(
#     output_buffer_size, random_seed)

# Split every line on whitespace into individual words.
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt: (
        tf.string_split([src]).values, tf.string_split([tgt]).values),
    num_threads=num_threads,
    output_buffer_size=output_buffer_size)

# Filter zero length input sequences.
src_tgt_dataset = src_tgt_dataset.filter(
    lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0))

# Truncate each sequence to the maximum length (50 here); words beyond the
# limit are dropped.
if src_max_len:
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (src[:src_max_len], tgt),
        num_threads=num_threads,
        output_buffer_size=output_buffer_size)
if tgt_max_len:
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (src, tgt[:tgt_max_len]),
        num_threads=num_threads,
        output_buffer_size=output_buffer_size)
if source_reverse:
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (tf.reverse(src, axis=[0]), tgt),
        num_threads=num_threads,
        output_buffer_size=output_buffer_size)

# The words are intentionally NOT converted to vocabulary ids, so that the
# batches stay human-readable.
# Convert the word strings to ids.  Word strings that are not in the
# vocab get the lookup table's default_value integer.
# src_tgt_dataset = src_tgt_dataset.map(
#     lambda src, tgt: (tf.cast(src_vocab_table.lookup(src), tf.int32),
#                       tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)),
#     num_threads=num_threads, output_buffer_size=output_buffer_size)
# Create a tgt_input prefixed with <s> and a tgt_output suffixed with </s>.
# src_tgt_dataset = src_tgt_dataset.map(
#     lambda src, tgt: (src,
#                       tf.concat(([tgt_sos_id], tgt), 0),
#                       tf.concat((tgt, [tgt_eos_id]), 0)),
#     num_threads=num_threads, output_buffer_size=output_buffer_size)

# Build the string-valued equivalents: tgt_input gets the start tag
# prepended, tgt_output gets the end tag appended.
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt: (src,
                      tf.concat((['<s>'], tgt), 0),
                      tf.concat((tgt, ['</s>']), 0)),
    num_threads=num_threads,
    output_buffer_size=output_buffer_size)

# Add in the word counts. tgt_len is tf.size(tgt_in), so it includes the
# start tag; since source and target come from the same file, tgt_len is
# src_len + 1 for every pair.
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt_in, tgt_out: (
        src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)),
    num_threads=num_threads,
    output_buffer_size=output_buffer_size)


def batching_func(x):
    """Batch `x` into groups of `batch_size`, padding to a common length.

    Args:
      x: dataset of (src, tgt_input, tgt_output, src_len, tgt_len) tuples.

    Returns:
      The padded, batched dataset.
    """
    return x.padded_batch(
        batch_size,
        # The first three entries are the source and target line rows;
        # these have unknown-length vectors.  The last two entries are
        # the source and target row sizes; these are scalars.
        padded_shapes=(tf.TensorShape([None]),  # src
                       tf.TensorShape([None]),  # tgt_input
                       tf.TensorShape([None]),  # tgt_output
                       tf.TensorShape([]),      # src_len
                       tf.TensorShape([])),     # tgt_len
        # Pad each sequence with its own easily recognisable marker string
        # so the padding is visible in the printed batches.
        padding_values=('<%>',  # src
                        '<&>',  # tgt_input
                        '<*>',  # tgt_output
                        0,      # src_len -- unused
                        0))     # tgt_len -- unused


if num_buckets > 1:

    def key_func(unused_1, unused_2, unused_3, src_len, tgt_len):
        """Return the int64 bucket id for one sentence pair."""
        # Calculate bucket_width by maximum source sequence length.
        # Pairs with length [0, bucket_width) go to bucket 0, length
        # [bucket_width, 2 * bucket_width) go to bucket 1, etc.
        if src_max_len:
            bucket_width = (src_max_len + num_buckets - 1) // num_buckets
        else:
            bucket_width = 10

        # Bucket sentence pairs by the length of their source sentence and
        # target sentence.
        bucket_id = tf.maximum(src_len // bucket_width,
                               tgt_len // bucket_width)
        # NOTE: the id is clamped to num_buckets, not num_buckets - 1, so
        # ids 0..num_buckets are possible (num_buckets + 1 buckets). With
        # the settings above, a pair whose target length reaches 50 gets
        # id 5. Left unchanged on purpose so that the script reproduces
        # the output listed after it.
        return tf.to_int64(tf.minimum(num_buckets, bucket_id))

    def reduce_func(unused_key, windowed_data):
        """Pad and batch all elements collected in one bucket window."""
        return batching_func(windowed_data)

    batched_dataset = src_tgt_dataset.group_by_window(
        key_func=key_func, reduce_func=reduce_func, window_size=batch_size)
else:
    batched_dataset = batching_func(src_tgt_dataset)

# Iterator; get_next() yields the next batch of the dataset.
batched_iter = batched_dataset.make_initializable_iterator()
(src_ids, tgt_input_ids, tgt_output_ids, src_len, tgt_len) = (
    batched_iter.get_next())

with tf.Session() as sess:
    sess.run(tf.tables_initializer())
    sess.run(batched_iter.initializer)
    try:
        while True:
            ret = sess.run(
                (src_ids, tgt_input_ids, tgt_output_ids, src_len, tgt_len))
            # print(ret)
            print(ret[3])  # source lengths of this batch
            print('\n')
            print(ret[4])  # target lengths of this batch
            print("=====================================")
    except tf.errors.OutOfRangeError:
        # Raised by get_next() once the dataset is exhausted.
        print('end!')

# Running the code above produces the following output:
[10 10 13 12 10 9 11 11 16 13 15 13 12 16 17 13 10 17 16 9 10 10 10 10
15 10 11 18 18 9 13 13 9 15 13 15 9 14 15 9 15 9 15 17 9 13 16 16
10 16 14 12 12 18 9 12 11 13 11 13 12 15 11 10 12 11 17 10 9 12 16 9
15 14 9 15 13 17 10 12 16 18 13 14 13 12 14 10 14 10 12 9 12 11 15 9
14 10 14 10 15 12 11 10 12 14 15 10 16 12 15 15 17 15 12 18 10 10 15 11
15 9 10 12 12 15 12 11]
[11 11 14 13 11 10 12 12 17 14 16 14 13 17 18 14 11 18 17 10 11 11 11 11
16 11 12 19 19 10 14 14 10 16 14 16 10 15 16 10 16 10 16 18 10 14 17 17
11 17 15 13 13 19 10 13 12 14 12 14 13 16 12 11 13 12 18 11 10 13 17 10
16 15 10 16 14 18 11 13 17 19 14 15 14 13 15 11 15 11 13 10 13 12 16 10
15 11 15 11 16 13 12 11 13 15 16 11 17 13 16 16 18 16 13 19 11 11 16 12
16 10 11 13 13 16 13 12]
=====================================
[28 20 20 24 23 21 19 28 20 22 22 20 23 19 27 28 20 23 25 20 19 20 19 26
20 23 25 19 22 21 19 25 24 22 20 23 26 26 22 24 19 23 22 27 19 23 26 21
20 24 22 28 27 21 25 25 22 27 26 25 25 28 20 22 19 20 28 28 26 23 20 23
19 21 22 23 20 23 21 28 21 25 19 25 26 19 25 23 20 20 21 22 27 20 22 26
21 19 20 24 28 20 24 25 19 19 25 19 20 23 26 21 22 26 24 20 19 25 23 25
20 19 25 28 26 25 21 25]
[29 21 21 25 24 22 20 29 21 23 23 21 24 20 28 29 21 24 26 21 20 21 20 27
21 24 26 20 23 22 20 26 25 23 21 24 27 27 23 25 20 24 23 28 20 24 27 22
21 25 23 29 28 22 26 26 23 28 27 26 26 29 21 23 20 21 29 29 27 24 21 24
20 22 23 24 21 24 22 29 22 26 20 26 27 20 26 24 21 21 22 23 28 21 23 27
22 20 21 25 29 21 25 26 20 20 26 20 21 24 27 22 23 27 25 21 20 26 24 26
21 20 26 29 27 26 22 26]
=====================================
[ 9 9 17 14 16 11 13 9 10 15 15 17 16 14 10 14 18 11 13 16 11 18 18 15
15 13 16 17 12 12 11 15 15 17 11 17 15 13 15 15 18 13 11 15 10 9 15 10
16 17 10 11 17 13 11 13 13 10 12 12 14 14 14 12 12 10 11 15 15 18 16 18
10 10 18 13 13 15 18 16 12 14 14 9 14 12 18 10 16 11 12 10 16 18 11 16
11 11 18 14 12 18 17 16 16 12 13 12 16 15 13 18 13 12 14 16 15 15 11 12
18 16 12 12 13 9 13 17]
[10 10 18 15 17 12 14 10 11 16 16 18 17 15 11 15 19 12 14 17 12 19 19 16
16 14 17 18 13 13 12 16 16 18 12 18 16 14 16 16 19 14 12 16 11 10 16 11
17 18 11 12 18 14 12 14 14 11 13 13 15 15 15 13 13 11 12 16 16 19 17 19
11 11 19 14 14 16 19 17 13 15 15 10 15 13 19 11 17 12 13 11 17 19 12 17
12 12 19 15 13 19 18 17 17 13 14 13 17 16 14 19 14 13 15 17 16 16 12 13
19 17 13 13 14 10 14 18]
=====================================
[7 5 6 5 8 7 8 4 5 5 6 4 6 7 7 6 7 7 5 4 7 6 6 7 5 8 8 6 6 4 8 8 8 7 8 5 6
7 4 8 7 7 8 7 7 6 6 6 8 8 8 8 6 6 5 8 4 8 4 6 8 6 6 8 7 7 7 7 7 5 8 8 6 7
6 6 7 8 5 8 6 7 6 5 7 6 6 6 8 8 6 8 8 5 6 3 8 7 6 8 7 5 7 8 4 3 6 5 7 6 6
8 8 8 8 7 8 5 5 5 5 5 8 7 6 6 8 7]
[8 6 7 6 9 8 9 5 6 6 7 5 7 8 8 7 8 8 6 5 8 7 7 8 6 9 9 7 7 5 9 9 9 8 9 6 7
8 5 9 8 8 9 8 8 7 7 7 9 9 9 9 7 7 6 9 5 9 5 7 9 7 7 9 8 8 8 8 8 6 9 9 7 8
7 7 8 9 6 9 7 8 7 6 8 7 7 7 9 9 7 9 9 6 7 4 9 8 7 9 8 6 8 9 5 4 7 6 8 7 7
9 9 9 9 8 9 6 6 6 6 6 9 8 7 7 9 8]
=====================================
[16 16 9 14 12 12 10 10 11 18 9 11 15 11 16 14 9 15 14 18 15 15 13 15
14 13 14 15 9 13 17 14 9 11 18 18 10 16 10 15 18 14 13 13 13 15 11 14
10 11 17 15 10 17 9 11 9 9 11 11 9 11 14 10 10 10 10 12 12 12 11 17
14 15 16 13 13 15 11 14 13 9 14 10 9 15 13 14 12 16 18 12 15 14 17 9
16 14 13 9 10 11 18 15 14 14 11 10 10 17 15 14 11 14 17 11 10 10 15 11
14 13 13 11 9 11 9 17]
[17 17 10 15 13 13 11 11 12 19 10 12 16 12 17 15 10 16 15 19 16 16 14 16
15 14 15 16 10 14 18 15 10 12 19 19 11 17 11 16 19 15 14 14 14 16 12 15
11 12 18 16 11 18 10 12 10 10 12 12 10 12 15 11 11 11 11 13 13 13 12 18
15 16 17 14 14 16 12 15 14 10 15 11 10 16 14 15 13 17 19 13 16 15 18 10
17 15 14 10 11 12 19 16 15 15 12 11 11 18 16 15 12 15 18 12 11 11 16 12
15 14 14 12 10 12 10 18]
=====================================
[25 19 22 21 22 26 24 20 20 27 28 21 28 22 20 22 21 24 27 27 26 24 23 23
21 19 28 22 24 25 24 28 20 25 27 19 24 20 27 19 25 19 23 26 26 23 22 28
23 22 20 22 19 21 23 21 20 25 21 26 20 24 24 19 19 23 22 23 21 21 26 21
28 22 22 22 22 20 20 20 19 23 26 19 28 19 27 20 27 23 26 19 19 24 20 21
26 19 22 26 19 20 21 21 19 20 21 24 20 22 23 27 22 28 26 21 19 27 27 27
21 25 24 22 20 27 19 19]
[26 20 23 22 23 27 25 21 21 28 29 22 29 23 21 23 22 25 28 28 27 25 24 24
22 20 29 23 25 26 25 29 21 26 28 20 25 21 28 20 26 20 24 27 27 24 23 29
24 23 21 23 20 22 24 22 21 26 22 27 21 25 25 20 20 24 23 24 22 22 27 22
29 23 23 23 23 21 21 21 20 24 27 20 29 20 28 21 28 24 27 20 20 25 21 22
27 20 23 27 20 21 22 22 20 21 22 25 21 23 24 28 23 29 27 22 20 28 28 28
22 26 25 23 21 28 20 20]
=====================================
[14 13 11 13 10 9 17 11 18 13 17 10 15 11 13 17 14 10 9 13 13 10 13 10
16 13 18 14 9 14 14 9 14 18 12 16 16 14 16 15 13 10 11 12 13 9 12 18
10 17 11 12 18 10 11 9 15 16 13 16 11 13 11 12 17 15 13 9 17 17 13 15
15 16 13 12 18 18 17 10 17 10 13 18 18 17 11 11 17 13 15 9 11 12 17 13
11 17 13 12 16 11 9 13 18 16 13 17 11 16 17 16 15 15 14 11 18 13 11 16
12 10 17 14 11 9 12 15]
[15 14 12 14 11 10 18 12 19 14 18 11 16 12 14 18 15 11 10 14 14 11 14 11
17 14 19 15 10 15 15 10 15 19 13 17 17 15 17 16 14 11 12 13 14 10 13 19
11 18 12 13 19 11 12 10 16 17 14 17 12 14 12 13 18 16 14 10 18 18 14 16
16 17 14 13 19 19 18 11 18 11 14 19 19 18 12 12 18 14 16 10 12 13 18 14
12 18 14 13 17 12 10 14 19 17 14 18 12 17 18 17 16 16 15 12 19 14 12 17
13 11 18 15 12 10 13 16]
=====================================
[33 36 32 38 30 31 30 32 35 35 30 34 32 34 29 31 37 31 29 33 34 29 32 35
35 36 29 38 29 36 37 29 29 34 29 31 29 38 31 32 30 34 30 29 34 33 31 35
35 38 30 37 32 33 38 29 33 36 30 31 30 38 34 35 32 31 29 32 29 32 29 37
35 32 30 34 29 38 34 29 34 31 34 33 34 29 35 30 37 32 30 30 31 38 30 33
31 34 35 37 31 32 29 33 31 30 33 29 34 34 37 32 31 29 29 34 33 33 32 32
38 31 31 32 32 32 29 35]
[34 37 33 39 31 32 31 33 36 36 31 35 33 35 30 32 38 32 30 34 35 30 33 36
36 37 30 39 30 37 38 30 30 35 30 32 30 39 32 33 31 35 31 30 35 34 32 36
36 39 31 38 33 34 39 30 34 37 31 32 31 39 35 36 33 32 30 33 30 33 30 38
36 33 31 35 30 39 35 30 35 32 35 34 35 30 36 31 38 33 31 31 32 39 31 34
32 35 36 38 32 33 30 34 32 31 34 30 35 35 38 33 32 30 30 35 34 34 33 33
39 32 32 33 33 33 30 36]
=====================================
[6 5 7 5 5 8 6 8 8 6 7 7 7 6 5 8 8 7 7 8 4 7 6 5 7 8 6 8 8 8 6 6 8 8 7 6 6
8 4 6 7 5 2 4 8 5 4 7 8 7 4 7 6 5 2 5 8 3 6 7 8 2 4 2 4 8 7 8 7 7 7 8 8 7
8 8 3 8 7 8 7 8 7 7 8 8 6 8 4 5 7 6 8 5 4 7 5 5 5 8 6 8 8 6 8 5 3 6 8 8 6
6 3 7 8 6 7 7 3 8 5 8 7 7 5 7 6 8]
[7 6 8 6 6 9 7 9 9 7 8 8 8 7 6 9 9 8 8 9 5 8 7 6 8 9 7 9 9 9 7 7 9 9 8 7 7
9 5 7 8 6 3 5 9 6 5 8 9 8 5 8 7 6 3 6 9 4 7 8 9 3 5 3 5 9 8 9 8 8 8 9 9 8
9 9 4 9 8 9 8 9 8 8 9 9 7 9 5 6 8 7 9 6 5 8 6 6 6 9 7 9 9 7 9 6 4 7 9 9 7
7 4 8 9 7 8 8 4 9 6 9 8 8 6 8 7 9]
=====================================
[13 12 11 9 15 12 15 11 15 9 10 9 15 9 14 10 9 13 16 13 13 17 15 11
13 18 11 18 11 9 11 10 13 14 17 16 16 14 14 10 13 14 18 13 10 14 9 17
11 17 18 15 10 11 14 13 10 12 14 10 13 13 15 12 18 12 11 9 16 15 11 10
17 9 9 12 10 9 12 10 14 12 12 13 9 13 10 15 10 9 9 16 13 12 14 15
14 9 10 13 13 10 16 13 16 15 11 12 15 14 14 10 16 16 10 14 15 11 17 14
15 10 11 14 16 12 17 18]
[14 13 12 10 16 13 16 12 16 10 11 10 16 10 15 11 10 14 17 14 14 18 16 12
14 19 12 19 12 10 12 11 14 15 18 17 17 15 15 11 14 15 19 14 11 15 10 18
12 18 19 16 11 12 15 14 11 13 15 11 14 14 16 13 19 13 12 10 17 16 12 11
18 10 10 13 11 10 13 11 15 13 13 14 10 14 11 16 11 10 10 17 14 13 15 16
15 10 11 14 14 11 17 14 17 16 12 13 16 15 15 11 17 17 11 15 16 12 18 15
16 11 12 15 17 13 18 19]
=====================================
[3 8 5 8 5 7 8 5 7 3 7 6 5 5 5 5 5 6 5 4 7 5 7 8 7 5 8 4 7]
[4 9 6 9 6 8 9 6 8 4 8 7 6 6 6 6 6 7 6 5 8 6 8 9 8 6 9 5 8]
=====================================
[ 9 14 14 14 13 17 12 16 15 13 9 14 10 12 17 12 13 17 11 12 10 9 17]
[10 15 15 15 14 18 13 17 16 14 10 15 11 13 18 13 14 18 12 13 11 10 18]
=====================================
[22 25 22 20 24 26 23 26 26 21 26 28 24 28 26 20 26 28 25 27 19 27 23 25
19 24 25 26 26 19 21 28 19 19 24 25 22 19 22 21 20 19 19 19 24 19 26 22
23 23 21 21 19 21 19 28 23 23 20 25 23 22 23 20 19 26 20 27 25 25 20 19
20 19 23 25 26 21 21 23 20 20 23 22 21 21 22 19 26 23 22 19 26 23 24 26
23 22 23 24 23 24 23 24 22 23 21 25 21 24 22 22 27 24]
[23 26 23 21 25 27 24 27 27 22 27 29 25 29 27 21 27 29 26 28 20 28 24 26
20 25 26 27 27 20 22 29 20 20 25 26 23 20 23 22 21 20 20 20 25 20 27 23
24 24 22 22 20 22 20 29 24 24 21 26 24 23 24 21 20 27 21 28 26 26 21 20
21 20 24 26 27 22 22 24 21 21 24 23 22 22 23 20 27 24 23 20 27 24 25 27
24 23 24 25 24 25 24 25 23 24 22 26 22 25 23 23 28 25]
=====================================
[32 31 29 29 33 32 37 36 38 37 29 29 29 32 34 33 31 31 32 37 30 37 31 35
31 34 37]
[33 32 30 30 34 33 38 37 39 38 30 30 30 33 35 34 32 32 33 38 31 38 32 36
32 35 38]
=====================================
[44 47 46 41 45 43 39 44 39 41 42 45 48 46 44 40 39 44 41 39 39 41 48 41
40 47 45 42 43 45 42 47 39 44 48 40 40 42 41 48 40 42 39 41 40]
[45 48 47 42 46 44 40 45 40 42 43 46 49 47 45 41 40 45 42 40 40 42 49 42
41 48 46 43 44 46 43 48 40 45 49 41 41 43 42 49 41 43 40 42 41]
=====================================
[50 49 50 50 50 50 50 50 50 50 49 50 49 50 50 50 50 50 50 50 50 50 50 50
49 50 50 50 50 50 50 50 50 50 50]
[51 50 51 51 51 51 51 51 51 51 50 51 50 51 51 51 51 51 51 51 51 51 51 51
50 51 51 51 51 51 51 51 51 51 51]
=====================================
end!
为了与源数据集进行对比,简单修改代码,将batch_size改为10,并将ret[0]打印出来,具体结果为:
[['It' 'is' 'a' 'jigsaw' 'puzzle' 'still' 'being' 'put' 'together' '.'
'<%>' '<%>' '<%>' '<%>' '<%>' '<%>']
['Let' 'me' 'tell' 'you' 'about' 'some' 'of' 'the' 'pieces' '.' '<%>'
'<%>' '<%>' '<%>' '<%>' '<%>']
['Imagine' 'the' 'first' 'piece' ':' 'a' 'man' 'burning' 'his' 'life'
"'s" 'work' '.' '<%>' '<%>' '<%>']
['Words' ',' 'for' 'so' 'long' 'his' 'friends' ',' 'now' 'mocked' 'him'
'.' '<%>' '<%>' '<%>' '<%>']
['But' 'our' 'lives' 'are' 'much' 'more' 'than' 'our' 'memories' '.'
'<%>' '<%>' '<%>' '<%>' '<%>' '<%>']
['My' 'grandmother' 'never' 'let' 'me' 'forget' 'his' 'life' '.' '<%>'
'<%>' '<%>' '<%>' '<%>' '<%>' '<%>']
['It' 'was' 'inconceivable' 'to' 'her' 'that' 'she' 'would' 'not'
'succeed' '.' '<%>' '<%>' '<%>' '<%>' '<%>']
['The' 'greatest' 'fear' 'was' 'of' 'pirates' ',' 'rape' 'and' 'death'
'.' '<%>' '<%>' '<%>' '<%>' '<%>']
['Like' 'most' 'adults' 'on' 'the' 'boat' ',' 'my' 'mother' 'carried'
'a' 'small' 'bottle' 'of' 'poison' '.']
['After' 'three' 'months' 'in' 'a' 'refugee' 'camp' ',' 'we' 'landed'
'in' 'Melbourne' '.' '<%>' '<%>' '<%>']]
[10 10 13 12 10 9 11 11 16 13]
=====================================