Some learning notes on the TensorFlow Dataset API (bucketing) ---- NMT

While studying the NMT source code I was puzzled by the bucketing step in the data-processing part. Using "tst2012.en" from the official example as both the source and the target dataset, and "vocab.en" as the corresponding vocabulary, I ran a few experiments, as follows:

#-*- coding:utf-8 -*-

import tensorflow as tf
from tensorflow.python.ops import lookup_ops


num_threads = 4
batch_size = 128
src_max_len = tgt_max_len = 50
source_reverse = False
num_buckets = 5
random_seed = 0

# Vocabulary table (source and target share one vocabulary here)
src_vocab_table = tgt_vocab_table = lookup_ops.index_table_from_file(
    "./nmt_data/vocab.en", default_value=0)

# Build the datasets; each line of the file becomes one element
src_dataset = tf.contrib.data.TextLineDataset('./nmt_data/tst2012.en')
tgt_dataset = tf.contrib.data.TextLineDataset('./nmt_data/tst2012.en')
# reverse_src_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
#     "./nmt_data/vocab.en", default_value="<unk>")

output_buffer_size = batch_size * 1000

# ids of <s>, </s> (and, via default_value, unk) in the vocabulary
src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant('</s>')), tf.int32)
tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant('<s>')), tf.int32)
tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant('</s>')), tf.int32)
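# Note: these three ids are not actually used below, because the
# string-to-id conversion step is commented out for easier inspection.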

# Zip source and target into a dataset of pairs
src_tgt_dataset = tf.contrib.data.Dataset.zip((src_dataset, tgt_dataset))

# For easier inspection, do not shuffle the dataset
# src_tgt_dataset = src_tgt_dataset.shuffle(
#     output_buffer_size, random_seed)

# Split each element into individual words on whitespace
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt: (
        tf.string_split([src]).values, tf.string_split([tgt]).values),
    num_threads=num_threads,
    output_buffer_size=output_buffer_size)
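# tf.string_split returns a SparseTensor; .values flattens it into a dense
# 1-D vector of tokens for the single wrapped line.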

# Filter zero length input sequences.
src_tgt_dataset = src_tgt_dataset.filter(
    lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0))

# Truncate sequences to at most src_max_len (50) tokens; the excess is dropped
if src_max_len:
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (src[:src_max_len], tgt),
        num_threads=num_threads,
        output_buffer_size=output_buffer_size)
if tgt_max_len:
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (src, tgt[:tgt_max_len]),
        num_threads=num_threads,
        output_buffer_size=output_buffer_size)
if source_reverse:
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (tf.reverse(src, axis=[0]), tgt),
        num_threads=num_threads,
        output_buffer_size=output_buffer_size)

# For easier inspection, skip converting the tokens to vocabulary ids
# Convert the word strings to ids.  Word strings that are not in the
# vocab get the lookup table's default_value integer.
# src_tgt_dataset = src_tgt_dataset.map(
#     lambda src, tgt: (tf.cast(src_vocab_table.lookup(src), tf.int32),
#                       tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)),
#     num_threads=num_threads, output_buffer_size=output_buffer_size)
# Create a tgt_input prefixed with <s> and a tgt_output suffixed with </s>.
# src_tgt_dataset = src_tgt_dataset.map(
#     lambda src, tgt: (src,
#                       tf.concat(([tgt_sos_id], tgt), 0),
#                       tf.concat((tgt, [tgt_eos_id]), 0)),
#     num_threads=num_threads, output_buffer_size=output_buffer_size)


# Prepend the start marker and append the end marker to the target
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt: (src,
                      tf.concat((['<s>'], tgt), 0),
                      tf.concat((tgt, ['</s>']), 0)),
    num_threads=num_threads, output_buffer_size=output_buffer_size)

# Add in the word counts (sequence lengths) for source and target.
# Note that tf.size(tgt_in) counts the prepended <s> marker as well.
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt_in, tgt_out: (
        src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)),
    num_threads=num_threads,
    output_buffer_size=output_buffer_size)
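# Since source and target are the same file here, tgt_len = tf.size(tgt_in)
# is always src_len + 1 (the prepended <s>), which is exactly the +1 pattern
# visible between the two printed length vectors of every batch below.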

# Bucket the data and pad each batch to a uniform length
def batching_func(x):
    return x.padded_batch(
        batch_size,
        # The first three entries are the source and target line rows;
        # these have unknown-length vectors.  The last two entries are
        # the source and target row sizes; these are scalars.
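        # (TensorShape([None]) pads each batch only up to the length of its
        # own longest sequence, so batch widths differ across batches.)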
        padded_shapes=(tf.TensorShape([None]),  # src
                       tf.TensorShape([None]),  # tgt_input
                       tf.TensorShape([None]),  # tgt_output
                       tf.TensorShape([]),  # src_len
                       tf.TensorShape([])),  # tgt_len
        # Pad the sequences with distinctive, easy-to-spot markers so the
        # padding is visible in the printed batches.  (The original NMT code
        # pads with the eos token, which is later masked out anyway.)
        padding_values=('<%>',  # src
                        '<&>',  # tgt_input
                        '<*>',  # tgt_output
                        0,  # src_len -- unused
                        0))  # tgt_len -- unused

if num_buckets > 1:
    def key_func(unused_1, unused_2, unused_3, src_len, tgt_len):
        # Calculate bucket_width by maximum source sequence length.
        # Pairs with length [0, bucket_width) go to bucket 0, length
        # [bucket_width, 2 * bucket_width) go to bucket 1, etc.  Pairs with length
        # over ((num_buckets - 1) * bucket_width) words all go into the last bucket.
        if src_max_len:
            bucket_width = (src_max_len + num_buckets - 1) // num_buckets
        else:
            bucket_width = 10

        # Bucket sentence pairs by the length of their source sentence and target
        # sentence.
        bucket_id = tf.maximum(src_len // bucket_width, tgt_len // bucket_width)
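        # Because of the tf.minimum below, pairs at the maximum length can
        # still map to bucket num_buckets itself, so there are effectively
        # num_buckets + 1 buckets (ids 0..num_buckets); the last printed
        # batch (lengths 49-50) lands in that extra bucket.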
        return tf.to_int64(tf.minimum(num_buckets, bucket_id))


    def reduce_func(unused_key, windowed_data):
        return batching_func(windowed_data)


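    # group_by_window groups consecutive elements into per-bucket windows via
    # key_func and flushes a window through reduce_func (i.e. padded_batch)
    # once it holds window_size elements; smaller leftover windows are flushed
    # at the end of the data, which is why the last few batches in the output
    # below contain fewer than batch_size entries.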
    batched_dataset = src_tgt_dataset.group_by_window(
        key_func=key_func, reduce_func=reduce_func, window_size=batch_size)
else:
    batched_dataset = batching_func(src_tgt_dataset)

# Iterator; get_next() yields the next batch of the dataset
batched_iter = batched_dataset.make_initializable_iterator()
(src_ids, tgt_input_ids, tgt_output_ids, src_len, tgt_len) = (
      batched_iter.get_next())

with tf.Session() as sess:
    sess.run(tf.tables_initializer())
    sess.run(batched_iter.initializer)
    try:
        while True:
            ret = sess.run(
                (src_ids, tgt_input_ids, tgt_output_ids, src_len, tgt_len))

            # print(ret)
            print(ret[3])
            print('\n')
            print(ret[4])
            print("=====================================")
    except tf.errors.OutOfRangeError:
        print('end!')

Running the code above produces the following output (ret[3] is src_len and ret[4] is tgt_len for each batch):

[10 10 13 12 10  9 11 11 16 13 15 13 12 16 17 13 10 17 16  9 10 10 10 10
 15 10 11 18 18  9 13 13  9 15 13 15  9 14 15  9 15  9 15 17  9 13 16 16
 10 16 14 12 12 18  9 12 11 13 11 13 12 15 11 10 12 11 17 10  9 12 16  9
 15 14  9 15 13 17 10 12 16 18 13 14 13 12 14 10 14 10 12  9 12 11 15  9
 14 10 14 10 15 12 11 10 12 14 15 10 16 12 15 15 17 15 12 18 10 10 15 11
 15  9 10 12 12 15 12 11]


[11 11 14 13 11 10 12 12 17 14 16 14 13 17 18 14 11 18 17 10 11 11 11 11
 16 11 12 19 19 10 14 14 10 16 14 16 10 15 16 10 16 10 16 18 10 14 17 17
 11 17 15 13 13 19 10 13 12 14 12 14 13 16 12 11 13 12 18 11 10 13 17 10
 16 15 10 16 14 18 11 13 17 19 14 15 14 13 15 11 15 11 13 10 13 12 16 10
 15 11 15 11 16 13 12 11 13 15 16 11 17 13 16 16 18 16 13 19 11 11 16 12
 16 10 11 13 13 16 13 12]
=====================================
[28 20 20 24 23 21 19 28 20 22 22 20 23 19 27 28 20 23 25 20 19 20 19 26
 20 23 25 19 22 21 19 25 24 22 20 23 26 26 22 24 19 23 22 27 19 23 26 21
 20 24 22 28 27 21 25 25 22 27 26 25 25 28 20 22 19 20 28 28 26 23 20 23
 19 21 22 23 20 23 21 28 21 25 19 25 26 19 25 23 20 20 21 22 27 20 22 26
 21 19 20 24 28 20 24 25 19 19 25 19 20 23 26 21 22 26 24 20 19 25 23 25
 20 19 25 28 26 25 21 25]


[29 21 21 25 24 22 20 29 21 23 23 21 24 20 28 29 21 24 26 21 20 21 20 27
 21 24 26 20 23 22 20 26 25 23 21 24 27 27 23 25 20 24 23 28 20 24 27 22
 21 25 23 29 28 22 26 26 23 28 27 26 26 29 21 23 20 21 29 29 27 24 21 24
 20 22 23 24 21 24 22 29 22 26 20 26 27 20 26 24 21 21 22 23 28 21 23 27
 22 20 21 25 29 21 25 26 20 20 26 20 21 24 27 22 23 27 25 21 20 26 24 26
 21 20 26 29 27 26 22 26]
=====================================
[ 9  9 17 14 16 11 13  9 10 15 15 17 16 14 10 14 18 11 13 16 11 18 18 15
 15 13 16 17 12 12 11 15 15 17 11 17 15 13 15 15 18 13 11 15 10  9 15 10
 16 17 10 11 17 13 11 13 13 10 12 12 14 14 14 12 12 10 11 15 15 18 16 18
 10 10 18 13 13 15 18 16 12 14 14  9 14 12 18 10 16 11 12 10 16 18 11 16
 11 11 18 14 12 18 17 16 16 12 13 12 16 15 13 18 13 12 14 16 15 15 11 12
 18 16 12 12 13  9 13 17]


[10 10 18 15 17 12 14 10 11 16 16 18 17 15 11 15 19 12 14 17 12 19 19 16
 16 14 17 18 13 13 12 16 16 18 12 18 16 14 16 16 19 14 12 16 11 10 16 11
 17 18 11 12 18 14 12 14 14 11 13 13 15 15 15 13 13 11 12 16 16 19 17 19
 11 11 19 14 14 16 19 17 13 15 15 10 15 13 19 11 17 12 13 11 17 19 12 17
 12 12 19 15 13 19 18 17 17 13 14 13 17 16 14 19 14 13 15 17 16 16 12 13
 19 17 13 13 14 10 14 18]
=====================================
[7 5 6 5 8 7 8 4 5 5 6 4 6 7 7 6 7 7 5 4 7 6 6 7 5 8 8 6 6 4 8 8 8 7 8 5 6
 7 4 8 7 7 8 7 7 6 6 6 8 8 8 8 6 6 5 8 4 8 4 6 8 6 6 8 7 7 7 7 7 5 8 8 6 7
 6 6 7 8 5 8 6 7 6 5 7 6 6 6 8 8 6 8 8 5 6 3 8 7 6 8 7 5 7 8 4 3 6 5 7 6 6
 8 8 8 8 7 8 5 5 5 5 5 8 7 6 6 8 7]


[8 6 7 6 9 8 9 5 6 6 7 5 7 8 8 7 8 8 6 5 8 7 7 8 6 9 9 7 7 5 9 9 9 8 9 6 7
 8 5 9 8 8 9 8 8 7 7 7 9 9 9 9 7 7 6 9 5 9 5 7 9 7 7 9 8 8 8 8 8 6 9 9 7 8
 7 7 8 9 6 9 7 8 7 6 8 7 7 7 9 9 7 9 9 6 7 4 9 8 7 9 8 6 8 9 5 4 7 6 8 7 7
 9 9 9 9 8 9 6 6 6 6 6 9 8 7 7 9 8]
=====================================
[16 16  9 14 12 12 10 10 11 18  9 11 15 11 16 14  9 15 14 18 15 15 13 15
 14 13 14 15  9 13 17 14  9 11 18 18 10 16 10 15 18 14 13 13 13 15 11 14
 10 11 17 15 10 17  9 11  9  9 11 11  9 11 14 10 10 10 10 12 12 12 11 17
 14 15 16 13 13 15 11 14 13  9 14 10  9 15 13 14 12 16 18 12 15 14 17  9
 16 14 13  9 10 11 18 15 14 14 11 10 10 17 15 14 11 14 17 11 10 10 15 11
 14 13 13 11  9 11  9 17]


[17 17 10 15 13 13 11 11 12 19 10 12 16 12 17 15 10 16 15 19 16 16 14 16
 15 14 15 16 10 14 18 15 10 12 19 19 11 17 11 16 19 15 14 14 14 16 12 15
 11 12 18 16 11 18 10 12 10 10 12 12 10 12 15 11 11 11 11 13 13 13 12 18
 15 16 17 14 14 16 12 15 14 10 15 11 10 16 14 15 13 17 19 13 16 15 18 10
 17 15 14 10 11 12 19 16 15 15 12 11 11 18 16 15 12 15 18 12 11 11 16 12
 15 14 14 12 10 12 10 18]
=====================================
[25 19 22 21 22 26 24 20 20 27 28 21 28 22 20 22 21 24 27 27 26 24 23 23
 21 19 28 22 24 25 24 28 20 25 27 19 24 20 27 19 25 19 23 26 26 23 22 28
 23 22 20 22 19 21 23 21 20 25 21 26 20 24 24 19 19 23 22 23 21 21 26 21
 28 22 22 22 22 20 20 20 19 23 26 19 28 19 27 20 27 23 26 19 19 24 20 21
 26 19 22 26 19 20 21 21 19 20 21 24 20 22 23 27 22 28 26 21 19 27 27 27
 21 25 24 22 20 27 19 19]


[26 20 23 22 23 27 25 21 21 28 29 22 29 23 21 23 22 25 28 28 27 25 24 24
 22 20 29 23 25 26 25 29 21 26 28 20 25 21 28 20 26 20 24 27 27 24 23 29
 24 23 21 23 20 22 24 22 21 26 22 27 21 25 25 20 20 24 23 24 22 22 27 22
 29 23 23 23 23 21 21 21 20 24 27 20 29 20 28 21 28 24 27 20 20 25 21 22
 27 20 23 27 20 21 22 22 20 21 22 25 21 23 24 28 23 29 27 22 20 28 28 28
 22 26 25 23 21 28 20 20]
=====================================
[14 13 11 13 10  9 17 11 18 13 17 10 15 11 13 17 14 10  9 13 13 10 13 10
 16 13 18 14  9 14 14  9 14 18 12 16 16 14 16 15 13 10 11 12 13  9 12 18
 10 17 11 12 18 10 11  9 15 16 13 16 11 13 11 12 17 15 13  9 17 17 13 15
 15 16 13 12 18 18 17 10 17 10 13 18 18 17 11 11 17 13 15  9 11 12 17 13
 11 17 13 12 16 11  9 13 18 16 13 17 11 16 17 16 15 15 14 11 18 13 11 16
 12 10 17 14 11  9 12 15]


[15 14 12 14 11 10 18 12 19 14 18 11 16 12 14 18 15 11 10 14 14 11 14 11
 17 14 19 15 10 15 15 10 15 19 13 17 17 15 17 16 14 11 12 13 14 10 13 19
 11 18 12 13 19 11 12 10 16 17 14 17 12 14 12 13 18 16 14 10 18 18 14 16
 16 17 14 13 19 19 18 11 18 11 14 19 19 18 12 12 18 14 16 10 12 13 18 14
 12 18 14 13 17 12 10 14 19 17 14 18 12 17 18 17 16 16 15 12 19 14 12 17
 13 11 18 15 12 10 13 16]
=====================================
[33 36 32 38 30 31 30 32 35 35 30 34 32 34 29 31 37 31 29 33 34 29 32 35
 35 36 29 38 29 36 37 29 29 34 29 31 29 38 31 32 30 34 30 29 34 33 31 35
 35 38 30 37 32 33 38 29 33 36 30 31 30 38 34 35 32 31 29 32 29 32 29 37
 35 32 30 34 29 38 34 29 34 31 34 33 34 29 35 30 37 32 30 30 31 38 30 33
 31 34 35 37 31 32 29 33 31 30 33 29 34 34 37 32 31 29 29 34 33 33 32 32
 38 31 31 32 32 32 29 35]


[34 37 33 39 31 32 31 33 36 36 31 35 33 35 30 32 38 32 30 34 35 30 33 36
 36 37 30 39 30 37 38 30 30 35 30 32 30 39 32 33 31 35 31 30 35 34 32 36
 36 39 31 38 33 34 39 30 34 37 31 32 31 39 35 36 33 32 30 33 30 33 30 38
 36 33 31 35 30 39 35 30 35 32 35 34 35 30 36 31 38 33 31 31 32 39 31 34
 32 35 36 38 32 33 30 34 32 31 34 30 35 35 38 33 32 30 30 35 34 34 33 33
 39 32 32 33 33 33 30 36]
=====================================
[6 5 7 5 5 8 6 8 8 6 7 7 7 6 5 8 8 7 7 8 4 7 6 5 7 8 6 8 8 8 6 6 8 8 7 6 6
 8 4 6 7 5 2 4 8 5 4 7 8 7 4 7 6 5 2 5 8 3 6 7 8 2 4 2 4 8 7 8 7 7 7 8 8 7
 8 8 3 8 7 8 7 8 7 7 8 8 6 8 4 5 7 6 8 5 4 7 5 5 5 8 6 8 8 6 8 5 3 6 8 8 6
 6 3 7 8 6 7 7 3 8 5 8 7 7 5 7 6 8]


[7 6 8 6 6 9 7 9 9 7 8 8 8 7 6 9 9 8 8 9 5 8 7 6 8 9 7 9 9 9 7 7 9 9 8 7 7
 9 5 7 8 6 3 5 9 6 5 8 9 8 5 8 7 6 3 6 9 4 7 8 9 3 5 3 5 9 8 9 8 8 8 9 9 8
 9 9 4 9 8 9 8 9 8 8 9 9 7 9 5 6 8 7 9 6 5 8 6 6 6 9 7 9 9 7 9 6 4 7 9 9 7
 7 4 8 9 7 8 8 4 9 6 9 8 8 6 8 7 9]
=====================================
[13 12 11  9 15 12 15 11 15  9 10  9 15  9 14 10  9 13 16 13 13 17 15 11
 13 18 11 18 11  9 11 10 13 14 17 16 16 14 14 10 13 14 18 13 10 14  9 17
 11 17 18 15 10 11 14 13 10 12 14 10 13 13 15 12 18 12 11  9 16 15 11 10
 17  9  9 12 10  9 12 10 14 12 12 13  9 13 10 15 10  9  9 16 13 12 14 15
 14  9 10 13 13 10 16 13 16 15 11 12 15 14 14 10 16 16 10 14 15 11 17 14
 15 10 11 14 16 12 17 18]


[14 13 12 10 16 13 16 12 16 10 11 10 16 10 15 11 10 14 17 14 14 18 16 12
 14 19 12 19 12 10 12 11 14 15 18 17 17 15 15 11 14 15 19 14 11 15 10 18
 12 18 19 16 11 12 15 14 11 13 15 11 14 14 16 13 19 13 12 10 17 16 12 11
 18 10 10 13 11 10 13 11 15 13 13 14 10 14 11 16 11 10 10 17 14 13 15 16
 15 10 11 14 14 11 17 14 17 16 12 13 16 15 15 11 17 17 11 15 16 12 18 15
 16 11 12 15 17 13 18 19]
=====================================
[3 8 5 8 5 7 8 5 7 3 7 6 5 5 5 5 5 6 5 4 7 5 7 8 7 5 8 4 7]


[4 9 6 9 6 8 9 6 8 4 8 7 6 6 6 6 6 7 6 5 8 6 8 9 8 6 9 5 8]
=====================================
[ 9 14 14 14 13 17 12 16 15 13  9 14 10 12 17 12 13 17 11 12 10  9 17]


[10 15 15 15 14 18 13 17 16 14 10 15 11 13 18 13 14 18 12 13 11 10 18]
=====================================
[22 25 22 20 24 26 23 26 26 21 26 28 24 28 26 20 26 28 25 27 19 27 23 25
 19 24 25 26 26 19 21 28 19 19 24 25 22 19 22 21 20 19 19 19 24 19 26 22
 23 23 21 21 19 21 19 28 23 23 20 25 23 22 23 20 19 26 20 27 25 25 20 19
 20 19 23 25 26 21 21 23 20 20 23 22 21 21 22 19 26 23 22 19 26 23 24 26
 23 22 23 24 23 24 23 24 22 23 21 25 21 24 22 22 27 24]


[23 26 23 21 25 27 24 27 27 22 27 29 25 29 27 21 27 29 26 28 20 28 24 26
 20 25 26 27 27 20 22 29 20 20 25 26 23 20 23 22 21 20 20 20 25 20 27 23
 24 24 22 22 20 22 20 29 24 24 21 26 24 23 24 21 20 27 21 28 26 26 21 20
 21 20 24 26 27 22 22 24 21 21 24 23 22 22 23 20 27 24 23 20 27 24 25 27
 24 23 24 25 24 25 24 25 23 24 22 26 22 25 23 23 28 25]
=====================================
[32 31 29 29 33 32 37 36 38 37 29 29 29 32 34 33 31 31 32 37 30 37 31 35
 31 34 37]


[33 32 30 30 34 33 38 37 39 38 30 30 30 33 35 34 32 32 33 38 31 38 32 36
 32 35 38]
=====================================
[44 47 46 41 45 43 39 44 39 41 42 45 48 46 44 40 39 44 41 39 39 41 48 41
 40 47 45 42 43 45 42 47 39 44 48 40 40 42 41 48 40 42 39 41 40]


[45 48 47 42 46 44 40 45 40 42 43 46 49 47 45 41 40 45 42 40 40 42 49 42
 41 48 46 43 44 46 43 48 40 45 49 41 41 43 42 49 41 43 40 42 41]
=====================================
[50 49 50 50 50 50 50 50 50 50 49 50 49 50 50 50 50 50 50 50 50 50 50 50
 49 50 50 50 50 50 50 50 50 50 50]


[51 50 51 51 51 51 51 51 51 51 50 51 50 51 51 51 51 51 51 51 51 51 51 51
 50 51 51 51 51 51 51 51 51 51 51]
=====================================
end!
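Why does each printed batch cover roughly a ten-length span? Replaying key_func in plain Python makes it clear. The following is a minimal sketch; the helper name assign_bucket is mine, not part of the script:

def assign_bucket(src_len, tgt_len, src_max_len=50, num_buckets=5):
    # Same arithmetic as key_func: bucket_width = ceil(src_max_len / num_buckets).
    bucket_width = (src_max_len + num_buckets - 1) // num_buckets  # = 10
    return min(num_buckets, max(src_len // bucket_width, tgt_len // bucket_width))

# tgt_len is always src_len + 1 here (the prepended <s>), so the first batch
# (src_len 9..18, tgt_len 10..19) is entirely bucket 1:
assert assign_bucket(9, 10) == assign_bucket(18, 19) == 1
# ... and the final batch (src_len 49..50) falls into bucket 5, the extra
# bucket allowed by tf.minimum(num_buckets, bucket_id):
assert assign_bucket(49, 50) == assign_bucket(50, 51) == 5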

To compare against the source dataset, I modified the code slightly: batch_size was set to 10 and ret[0] (the padded source batch) was printed as well. The result:

[['It' 'is' 'a' 'jigsaw' 'puzzle' 'still' 'being' 'put' 'together' '.'
  '<%>' '<%>' '<%>' '<%>' '<%>' '<%>']
 ['Let' 'me' 'tell' 'you' 'about' 'some' 'of' 'the' 'pieces' '.' '<%>'
  '<%>' '<%>' '<%>' '<%>' '<%>']
 ['Imagine' 'the' 'first' 'piece' ':' 'a' 'man' 'burning' 'his' 'life'
  ''s' 'work' '.' '<%>' '<%>' '<%>']
 ['Words' ',' 'for' 'so' 'long' 'his' 'friends' ',' 'now' 'mocked' 'him'
  '.' '<%>' '<%>' '<%>' '<%>']
 ['But' 'our' 'lives' 'are' 'much' 'more' 'than' 'our' 'memories' '.'
  '<%>' '<%>' '<%>' '<%>' '<%>' '<%>']
 ['My' 'grandmother' 'never' 'let' 'me' 'forget' 'his' 'life' '.' '<%>'
  '<%>' '<%>' '<%>' '<%>' '<%>' '<%>']
 ['It' 'was' 'inconceivable' 'to' 'her' 'that' 'she' 'would' 'not'
  'succeed' '.' '<%>' '<%>' '<%>' '<%>' '<%>']
 ['The' 'greatest' 'fear' 'was' 'of' 'pirates' ',' 'rape' 'and' 'death'
  '.' '<%>' '<%>' '<%>' '<%>' '<%>']
 ['Like' 'most' 'adults' 'on' 'the' 'boat' ',' 'my' 'mother' 'carried'
  'a' 'small' 'bottle' 'of' 'poison' '.']
 ['After' 'three' 'months' 'in' 'a' 'refugee' 'camp' ',' 'we' 'landed'
  'in' 'Melbourne' '.' '<%>' '<%>' '<%>']]
[10 10 13 12 10  9 11 11 16 13]
=====================================
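This last printout makes the padding step concrete: each source row is padded with the '<%>' marker up to the longest sentence in the batch (16 tokens here), while ret[3] still reports the true, unpadded lengths. The same behavior can be reproduced in isolation; the snippet below is a minimal sketch with made-up toy sentences, using the same contrib-era API as the script above:

import tensorflow as tf

# Two toy sentences of different lengths (5 and 8 tokens).
toy = tf.contrib.data.Dataset.from_tensor_slices(
    tf.constant(['It is a puzzle .', 'Let me tell you about some pieces .']))
toy = toy.map(lambda s: tf.string_split([s]).values)
# Pad the shorter row with the same visible '<%>' marker up to length 8.
toy = toy.padded_batch(
    2, padded_shapes=tf.TensorShape([None]), padding_values='<%>')

with tf.Session() as sess:
    print(sess.run(toy.make_one_shot_iterator().get_next()))
    # The first row ends with three '<%>' padding tokens.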
