tf读取数据的几种方式

1.最简单的方式

import tensorflow as tf

a = tf.zeros([2,3])
b = tf.ones([2,3])
c = tf.add(a, b)
with tf.Session() as sess:
    print(sess.run(b))

直接读取已经预加载在Graph中的数据,数据量大的时候,要把所有的数据都预加载,非常不合理

2.通过feed_dict

import numpy as np
import tensorflow as tf

x = np.reshape(np.arange(6), [2,3])
a = tf.zeros([2,3])
b = tf.placeholder(dtype=tf.float32, shape=[2,3])
c = tf.add(a, b)
with tf.Session() as sess:
    print(sess.run(b, feed_dict={b:x}))

也很简单,预先设置tf.placeholder即可

3.直接从文件中读取

主要是针对大数据,效率高

  • 通过tf.train.slice_input_producer,管理线程队列读取

原理讲解:https://zhuanlan.zhihu.com/p/27238630

import tensorflow as tf

x = [[1,2],[2,3],[4,5],[6,7],[7,8],[9,10],[11,12],[13,14]]
label = ["a","b","c","d","a","b","c","d"]
# 此处shuffle=True的话不需要tf.train.shuffle_batch,batch即可
input_queues = tf.train.slice_input_producer([x, label],shuffle=False,num_epochs=2) 
x, y = tf.train.batch(input_queues,
                          num_threads=8,
                          batch_size=3,
                          capacity= 128,
                          allow_smaller_final_batch=False)
with tf.Session() as sess:
    tf.local_variables_initializer().run()
    # 使用start_queue_runners之后,才会开始填充队列
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            print("---------")
            # 单独的分开run,因为tf的机制,它不是一起执行的,run了两次,所以x,y不对应
            print(sess.run([x, y]))
    # 如果读取到文件队列末尾会抛出此异常
    except tf.errors.OutOfRangeError:
        print("done! now lets kill all the threads……")
    finally:
        coord.request_stop()
        print('all threads are asked to stop!')
    coord.join(threads)  # 把开启的线程加入主线程,等待threads结束
    print('all threads are stopped!')

out:

---------
[array([[1, 2],
       [2, 3],
       [4, 5]]), array([b'a', b'b', b'c'], dtype=object)]
---------
[array([[ 6,  7],
       [ 7,  8],
       [ 9, 10]]), array([b'd', b'a', b'b'], dtype=object)]
---------
[array([[11, 12],
       [13, 14],
       [ 2,  3]]), array([b'c', b'd', b'b'], dtype=object)]
---------
[array([[ 1,  2],
       [13, 14],
       [11, 12]]), array([b'a', b'd', b'c'], dtype=object)]
---------
[array([[ 4,  5],
       [ 6,  7],
       [ 9, 10]]), array([b'c', b'd', b'b'], dtype=object)]
---------
done! now lets kill all the threads……
all threads are asked to stop!
all threads are stopped!

可以看出tf.train.slice_input_producer函数对于一个epoch剩下的data,并不全部输出,而是留在队列,等待下一个epoch将其取走,并且最后不足一个batch的数据直接被丢弃.

再给一个tf.train.string_input_producer+tf.WholeFileReader()+tf.train.Supervisor的用例:

# 导入tensorflow
import tensorflow as tf

filename = ['A.jpg', 'B.jpg', 'C.jpg']
# string_input_producer会产生一个文件名队列
filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=3)
reader = tf.WholeFileReader() # 创建WholeFileReader对象
#key保存的是filename_queue中的文件名,value则是文件本身
key, value = reader.read(filename_queue)
# image_resized, labels = _parse_function(value, key) # 可以对原始图片处理,这里就不演示了
x, y = tf.train.batch([value, key],
                          num_threads=8,
                          batch_size=2,
                          capacity= 128,
                          allow_smaller_final_batch=False)

sv = tf.train.Supervisor(logdir='./', save_model_secs=0)
# Supervisor不需要手动启动线程管理线程;不需要手动global_variable_initializer()等,非常强大
# 之后再单独讲一讲
with sv.managed_session() as sess:
    while 1:
        if sv.should_stop(): break
        print("----------")
        print(sess.run(y)) # 因为是图片就不print x了

out:

----------
[b'A.jpg' b'B.jpg']
----------
[b'C.jpg' b'A.jpg']
----------
[b'B.jpg' b'C.jpg']
----------
[b'A.jpg' b'B.jpg']
----------

tf.train.string_input_producertf.train.slice_input_producer的区别简单来说,前者是传入文件名列表[data_file_names],用tf.WholeFileReader()等专门处理输入的类去读;后者是传入已经加载到内存中的数据,传入[data, labels],其他都一样

  • tf.data.Datase处理

原理讲解链接:https://zhuanlan.zhihu.com/p/30751039

由简入深:

dataset = tf.data.Dataset.from_tensor_slices(np.zeros([4, 10])) # 根据可迭代对象划分
iterator = dataset.make_one_shot_iterator() # 构造迭代器
element = iterator.get_next() # get_next()迭代获取元素

with tf.Session() as sess:
	for i in range(3):
		print(sess.run(element))

out:

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

Dataset类中的各种操作介绍:https://www.imooc.com/article/68648

讲解最重要的几个:

def parse(x):
	return x+1
dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5])
dataset = dataset.map(parse) # 让dataset中的每条数据都经过parse函数的解析
dataset=dataset.batch(3, drop_remainder=False).repeat().shuffle(1000)
# Dataset对数据进行处理的函数,返回仍是Dataset类
iterator = dataset.make_one_shot_iterator() # 构造迭代器
element = iterator.get_next() # get_next()迭代获取元素

with tf.Session() as sess:
    for i in range(5):
        print(sess.run(element))

out:

[5 6]
[5 6]
[2 3 4]
[5 6]
[2 3 4]

可以看出来,drop_remainder=False时,一个epoch结束时,不足一个batch的数据仍然输出作为一个batch,每个epoch的batch顺序因为.shuffle(1000)打乱了.

当batch内每一条数据不一样长,可以用调用dataset.padded_batch()实现pad

dataset = dataset.padded_batch(batch_size,padded_shapes,..)
#函数表示原来的每一条data,pad成padded_shapes形状,再把batch_size个组合起来

一般我们将预处理好的data保存成TFRecord文件,再用tf.data.TFRecordDataset读取TFRecord文件
首先存成TFRecord文件(简单例子):

import tensorflow as tf
import numpy as np

x = 'i looove you'
x_ids = np.array([1, 2, 3])
label = 1

writer = tf.python_io.TFRecordWriter('./test.tfrecords') # 创建写类
for i in range(3):
	# 这两个features ||features = tf.train.Features|| 都要加s!!!否则报错 奇葩- -
    one_record = tf.train.Example(features = tf.train.Features(feature = {
                                "x_raw":tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes(x+' this is'+str(i), 'utf-8')])),
                                "x_ids":tf.train.Feature(bytes_list=tf.train.BytesList(value=[np.append(x_ids, np.array([i])).tostring()])),
                                "label":tf.train.Feature(int64_list=tf.train.Int64List(value=[label+i]))
        }))
    writer.write(one_record.SerializeToString())
writer.close()

运行完上面代码多出一个test.tfrecords,即为我们与处理完成的文件

接下来tf.data.TFRecordDataset读取TFRecord文件:

import tensorflow as tf

def parser(example):
    features = tf.parse_single_example(example,features={
                                                "x_raw": tf.FixedLenFeature([], tf.string),
                                                "x_ids": tf.FixedLenFeature([], tf.string),
                                                "label": tf.FixedLenFeature([], tf.int64)
                                            })
    x_ids = tf.decode_raw(features["x_ids"], tf.int32)
    x_raw = features["x_raw"]
    label = features["label"]
    return x_ids, x_raw, label

dataset = tf.data.TFRecordDataset('./test.tfrecords').map(parser).shuffle(10).repeat().batch(2)
iterator = dataset.make_one_shot_iterator()
x_ids, x_raw, label = iterator.get_next()
with tf.Session() as sess:
    for i in range(3):
        a, b, c = sess.run((x_ids, x_raw, label))
        print(a, b, c)

out:

[[1 2 3 1]
 [1 2 3 0]] [b'i looove you this is1' b'i looove you this is0'] [2 1]
[[1 2 3 2]
 [1 2 3 0]] [b'i looove you this is2' b'i looove you this is0'] [3 1]
[[1 2 3 2]
 [1 2 3 1]] [b'i looove you this is2' b'i looove you this is1'] [3 2]

可见可以按batch正常取出其中数据.

现在我们对于单次data可以处理了,如果我想要更加灵活,比如可以多次使用的iterator?

为什么说make_one_shot_iterator()是单次的迭代器,原因在于,它执行完.repeat(unm)迭代器将不能重新给别的dataset使用,只能重新调用make_one_shot_iterator(),重新创建迭代器。
而接下来讲的将可以不需要重新创建,可复用的办法。
有更加详细的讲解,给新手朋友:https://blog.csdn.net/briblue/article/details/80962728

其实说可复用,不过是把多个make_one_shot_iterator()tf.data.Iterator.from_string_handle绑起来,并且用handle控制iterator迭代器控制相应dataset!

看示例:

import tensorflow as tf

def parser(example):
    features = tf.parse_single_example(example,features={
                                                "x_raw": tf.FixedLenFeature([], tf.string),
                                                "x_ids": tf.FixedLenFeature([], tf.string),
                                                "label": tf.FixedLenFeature([], tf.int64)
                                            })
    x_ids = tf.decode_raw(features["x_ids"], tf.int32)
    x_raw = features["x_raw"]
    label = features["label"]
    return x_ids, x_raw, label

dataset1 = tf.data.TFRecordDataset('./test.tfrecords').map(parser).shuffle(10).repeat().batch(2)
dataset2 = tf.data.TFRecordDataset('./test.tfrecords').map(parser).shuffle(10).repeat().batch(3)
handle = tf.placeholder(tf.string, shape=[])
# .from_string_handle()里的dataset1和dataset2要是相同格式的数据,不然不能'绑'起来
iterator = tf.data.Iterator.from_string_handle(handle, dataset1.output_types, dataset2.output_shapes)
element = iterator.get_next()

iterator1 = dataset1.make_one_shot_iterator()
iterator2 = dataset2.make_one_shot_iterator()

with tf.Session() as sess:
	for i in range(3):
        dataset1_handle = sess.run(iterator1.string_handle())
        a1, b1, c1 = sess.run(element, feed_dict={handle: dataset1_handle})
        print(a1, b1, c1)
        dataset2_handle = sess.run(iterator2.string_handle())
        a2, b2, c2 = sess.run(element, feed_dict={handle: dataset2_handle})
        print(a2, b2, c2)

out:

[[1 2 3 1]
 [1 2 3 0]] [b'i looove you this is1' b'i looove you this is0'] [2 1]
[[1 2 3 1]
 [1 2 3 2]
 [1 2 3 0]] [b'i looove you this is1' b'i looove you this is2'
 b'i looove you this is0'] [2 3 1]
[[1 2 3 2]
 [1 2 3 1]] [b'i looove you this is2' b'i looove you this is1'] [3 2]
[[1 2 3 1]
 [1 2 3 0]
 [1 2 3 2]] [b'i looove you this is1' b'i looove you this is0'
 b'i looove you this is2'] [2 1 3]
[[1 2 3 0]
 [1 2 3 2]] [b'i looove you this is0' b'i looove you this is2'] [1 3]
[[1 2 3 2]
 [1 2 3 0]
 [1 2 3 1]] [b'i looove you this is2' b'i looove you this is0'
 b'i looove you this is1'] [3 1 2]

可以看见batch_size为2/3的交替输出;
需要哪个dataset,则sess.run(iterator.string_handle())对应dataset的iterator的string_handle,iterator就切换到哪个dataset处理返回一个batch

可以用在train,valid交替进行训练

上文所有的tf.data.Dataset方法,都没有检测迭代器是否为空的情况,原因在于.repeat(),并且通过 sess中的for循环控制何时停止迭代,正常使用建议写上,规范代码.

with tf.Session() as sess:
    try:
        while True:
            print(sess.run(element))
    except tf.errors.OutOfRangeError:
        print("DONE")

本博文主要讲解了两种直接从文件读取方式的异同:

  • 总的来说,两种方式都是通过多线程队列来实现高效的组织输入,tf.data.Dataset不需要手动管理和启动线程队列;tf.train.slice_input_producer结合sv也可以实现这个目的。
  • tf.train.slice_input_producer对于不是最后一个epoch的情况,最后不满一个batch的数据仍然留在队列中,供下一次使用,对于最后一个epoch的情况,最后不满一个batch的数据直接丢弃;
  • tf.data.Dataset通过.batch(drop_remainder=False/True)来控制是否丢弃,要实现tf.train.slice_input_producer的效果,需要把.repeat()放在.batch()之前;如果顺序倒过来,像博主在本文中举的所有例子,则在每个epoch都会存在一个小尾巴!可以说很有意思了!!!
  • 都很方便使用,看个人习惯使用吧

.#

你可能感兴趣的:(tf读取数据的几种方式)