使用Tensorflow Dataset读取数据

转载自:https://blog.csdn.net/foreseerwang/article/details/80572182

    注意,Dataset.from_generator在旧版Tensorflow中没有,在1.4版本以上才有tf.data.Dataset。

      tensorflow的基本原理是先构造一个计算图,最后再统一计算。为此,tf重写了几乎所有常见函数,用于构造计算图,而且tensorflow不支持循环、选择(if 跳转)等普通编程语言的常见操作。这就给编程使用带来比较大的麻烦。

      Dataset.from_generator可以使用普通编程语言编写的外部子函数生成Dataset,这样几乎不受tensorflow编程不便的影响。先举一个最简单的示例:

import numpy as np

import tensorflow as tf

def data_generator():

    dataset = np.array(range(5))

    for d in dataset:

        yield d

dataset = tf.data.Dataset.from_generator(data_generator, (tf.int32), (tf.TensorShape([])))

dataset = dataset.repeat(3)

dataset = dataset.batch(4)

iterator = dataset.make_one_shot_iterator() #one-shot iterator 是最简单的一种遍历器。这种遍历器只支持#遍历单一dataset,并且还不需要显式的初始化。

one_element = iterator.get_next()

with tf.Session() as sess:

    try:

        batch_num=0

        while True:

            one_batch = sess.run(one_element)

            print('Batch No. %d:' % batch_num)

            print(one_batch)

            print('')

            batch_num+=1

    except tf.errors.OutOfRangeError:

        print('end!')

很显然,这个的输出如下:

Batch No. 0:

[0 1 2 3]

Batch No. 1:

[4 0 1 2]

Batch No. 2:

[3 4 0 1]

Batch No. 3:

[2 3 4]

end!

下面给出一个复杂的问题。假设需要输入如下序列:A BA C BC…其中A/B/C分别代表一个文件,例如一张图片或是一个文本文件。每一行是一条记录,按行读入,并聚集多行形成batch,譬如每4行形成一个batch。这里有两个难点:1.每一行/每一条记录的元素长度不一样;2.读入元素A/B/C之后还要以之作为文件名读入文件内容。现有各种data feeding方式似乎很难同时解决这两个难点,除了Dataset.from_generator。

import io

import numpy as np

import tensorflow as tf

class DataFeeder:

    def __init__(self, filenames):

        self.filenames = filenames

    def file_readline(self):

        for filename in self.filenames:

            fr = io.open(filename, 'r', encoding='utf-8')

            while True:

                file_line = fr.readline()

                if not file_line:

                    break

                datalist = file_line.split()

                # if datalist is a list of filename, file contents can

                # be read and appendded here.

                yield np.asarray(datalist, dtype='int32')

            fr.close()

    def generate_batch(self, batch_size, num_epochs=None):

        dataset = tf.data.Dataset.from_generator(self.file_readline,

                                                tf.int32,

                                                tf.TensorShape([None]))

        dataset = dataset.repeat(num_epochs)

        dataset = dataset.padded_batch(

            batch_size,

            padded_shapes=tf.TensorShape([3]),

            padding_values=-1)

        iterator = dataset.make_one_shot_iterator()

        out_batch = iterator.get_next()

        return out_batch

filenames = ['a.txt', 'b.txt', 'c.txt']

data_feeder = DataFeeder(filenames)

one_batch = data_feeder.generate_batch(batch_size=2, num_epochs=1)

with tf.Session() as sess:

    try:

        batch_num = 0

        while True:

            data_batch = sess.run(one_batch)

            print('Batch No. %d:' % batch_num)

            print(data_batch)

            print('')

            batch_num+=1

    except tf.errors.OutOfRangeError:

        print('end!')

其中三个文本文件a.txt/b.txt/c.txt的内容分别如下:

a.txt:

1 2 3

2 3

3

b.txt:

4 5

6 7 8

9

c.txt:

10 11 12

13 14

15

运行以上代码的输出为:

Batch No. 0:

[[ 1  2  3]

[ 2  3 -1]]

Batch No. 1:

[[ 3 -1 -1]

[ 4  5 -1]]

Batch No. 2:

[[ 6  7  8]

[ 9 -1 -1]]

Batch No. 3:

[[10 11 12]

[13 14 -1]]

Batch No. 4:

[[15 -1 -1]]

end!

你可能感兴趣的:(使用Tensorflow Dataset读取数据)