Reading Data in TensorFlow

TensorFlow offers several ways to read data. The most common is to load the data into memory with plain Python/NumPy and then feed it to the model, as shown below:

import numpy as np
import tensorflow as tf

# Load data
data = np.load('example/example.npz')
_x, _y = data["_x"], data["_y"]

# Inputs and targets: a placeholder for x of dtype=int32, shape=(None, 9)
x_pl = tf.placeholder(tf.int32, shape=(None, 9))
y_hat = 45 - tf.reduce_sum(x_pl, axis=1)  # Find the digit each row of x_pl doesn't contain (1+2+...+9 = 45)

# Session
with tf.Session() as sess:
    _y_hat = sess.run(y_hat, {x_pl: _x})
    print("y_hat =", _y_hat[:30])
    print("true y =", _y[:30])

However, when the dataset is large, loading everything into memory consumes too much memory and slows things down. In that case it is better to use the input interfaces TensorFlow provides to read the training data.

Using TFRecord

TFRecord files can be copied, moved, read, and stored efficiently in TensorFlow. As I understand it, the content of a TFRecord file is a protocol buffer defined by TensorFlow. TensorFlow provides the tf.train.Example interface: you fill the data into an Example, serialize it to a string, and write it to a local file with tf.python_io.TFRecordWriter.

1) Serialization

# Serialize
with tf.python_io.TFRecordWriter("example/tfrecord") as fout:
    for _xx, _yy in zip(_x, _y):
        ex = tf.train.Example()
        # Note: the values must be a list (or 1-D array) of ints; map entries are
        # protobuf messages, so we extend them rather than assign Feature objects
        ex.features.feature['x'].int64_list.value.extend(_xx)
        ex.features.feature['y'].int64_list.value.extend(_yy)
        fout.write(ex.SerializeToString())
        
Alternatively, an Example can be built in one call. Here `i` is an integer label and `img_raw` holds the encoded bytes of an image:

example = tf.train.Example(features=tf.train.Features(
    feature={
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[i])),
        'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
    }))
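When such an Example is read back, the bytes_list feature is parsed as a tf.string and then decoded. A minimal sketch, assuming the image was written as raw uint8 bytes (e.g. via numpy's tobytes()) and was originally 10x10x4:

def parse_image_example(serialized_example):
    # Parse a serialized Example containing a bytes feature (shape is assumed)
    features = tf.parse_single_example(
        serialized_example,
        features={'label': tf.FixedLenFeature([], tf.int64),
                  'img_raw': tf.FixedLenFeature([], tf.string)})
    img = tf.decode_raw(features['img_raw'], tf.uint8)  # raw bytes -> uint8 tensor
    img = tf.reshape(img, [10, 10, 4])                  # assumed original shape
    return img, features['label']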

2) Reading a TFRecord file

This takes three main steps:
1) create a filename queue with tf.train.string_input_producer;
2) read from the queue with tf.TFRecordReader, which returns a serialized_example object;
3) parse the Example protocol buffer into tensors with tf.parse_single_example.

The flow for reading a TFRecord file looks like this:

def read_and_decode_single_example(fname):
    # Create a string queue of filenames
    fname_q = tf.train.string_input_producer([fname], num_epochs=1, shuffle=True)

    # Create a TFRecordReader
    reader = tf.TFRecordReader()

    # Read one serialized example from the queue
    _, serialized_example = reader.read(fname_q)

    # Describe the parsing syntax: shape and dtype of each feature
    features = tf.parse_single_example(
        serialized_example,
        features={'x': tf.FixedLenFeature([9], tf.int64),
                  'y': tf.FixedLenFeature([1], tf.int64)})

    # Output
    x = features['x']
    y = features['y']

    return x, y

# Ops
x, y = read_and_decode_single_example('example/tfrecord')
y_hat = 45 - tf.reduce_sum(x)

# Session
with tf.Session() as sess:
    # Initialize local variables (num_epochs=1 above created a local variable)
    sess.run(tf.local_variables_initializer())
    
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
    try:
        while not coord.should_stop():
            _y, _y_hat = sess.run([y, y_hat])
            print(_y[0],"==", _y_hat, end="; ")
    
    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        # When done, ask the threads to stop.
        coord.request_stop()
    
    # Wait for threads to finish.
    coord.join(threads)

As shown above, tf.parse_single_example parses the records read from the local file, while tf.train.start_queue_runners starts the threads of the input pipeline: several queue threads read data and push it into the queue. If they are not started, the queue stays empty and the program blocks, waiting forever, so the QueueRunners must be started to fill the queue. tf.FixedLenFeature() specifies the shape and dtype of each feature.
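The function above returns a single example at a time. To train on mini-batches, the single-example tensors can be passed to tf.train.batch or tf.train.shuffle_batch, which add another queue that assembles batches behind the scenes. A minimal sketch; the batch size and capacity numbers are illustrative:

x, y = read_and_decode_single_example('example/tfrecord')

# Assemble single examples into shuffled mini-batches
x_batch, y_batch = tf.train.shuffle_batch(
    [x, y],
    batch_size=32,
    capacity=2000,            # max elements held in the batching queue
    min_after_dequeue=1000)   # min elements kept in the queue for good mixing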

tf.train.start_queue_runners(sess=sess, coord=coord)

Starts all queue runners collected in the graph.
This is a companion method to add_queue_runner(). It just starts threads for all queue runners collected in the graph. It returns the list of all threads.

Each thread should check coord.should_stop() before doing work. Once coord.request_stop() has been called, coord.should_stop() returns True. At the end of the program, coord.join(threads) waits for all threads to finish.

tf.train.Coordinator()

A coordinator for threads.

This class implements a simple mechanism to coordinate the termination of a set of threads.
Any of the threads can call coord.request_stop() to ask for all the threads to stop.
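The Coordinator is not tied to queues; it can coordinate plain Python threads as well. A minimal sketch with a hypothetical worker function:

import threading
import tensorflow as tf

def worker(coord, steps=5):
    # Hypothetical worker: loop until any thread requests a stop
    step = 0
    while not coord.should_stop():
        step += 1
        if step >= steps:
            coord.request_stop()  # ask all threads to stop

coord = tf.train.Coordinator()
threads = [threading.Thread(target=worker, args=(coord,)) for _ in range(4)]
for t in threads:
    t.start()
coord.join(threads)  # wait for all threads to finish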

Note that if num_epochs=None in tf.train.string_input_producer, the files are read in an endless loop and reading never stops. If num_epochs is an integer, a local variable is created, so it must be initialized with tf.local_variables_initializer(), as in the code above.
tf.train.slice_input_producer is used much like tf.train.string_input_producer, except that it slices a list of tensors directly and produces the data for downstream ops, as sketched below.
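A minimal sketch of tf.train.slice_input_producer, feeding the in-memory arrays _x and _y loaded earlier without a feed_dict:

# Each dequeue yields one (x, y) slice from the in-memory tensors
x_slice, y_slice = tf.train.slice_input_producer([_x, _y], num_epochs=1, shuffle=True)
x_batch, y_batch = tf.train.batch([x_slice, y_slice], batch_size=10)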

Reading CSV files

# Dump the data to a CSV file. The exact serialization is an assumption here:
# each line holds the 9 digits of a row of _x followed by its label from _y.
_x_str = "\n".join(",".join(str(v) for v in np.append(xx, yy)) for xx, yy in zip(_x, _y))
with open('example/example.csv', 'w') as fout:
    fout.write(_x_str)
    
# Hyperparams
batch_size = 10

# Create a string queue of filenames
fname_q = tf.train.string_input_producer(["example/example.csv"])

# Create a TextLineReader; it reads one line at a time
reader = tf.TextLineReader()

# Read the string queue
_, value = reader.read(fname_q)

# Decode value: record_defaults fixes the type (int) and default of each of the 10 columns
record_defaults = [[0]] * 10
col1, col2, col3, col4, col5, col6, col7, col8, col9, col10 = tf.decode_csv(
    value, record_defaults=record_defaults)
x = tf.stack([col1, col2, col3, col4, col5, col6, col7, col8, col9])
y = col10
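As with TFRecords, nothing flows until the queue runners are started. A minimal sketch of evaluating the CSV pipeline:

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        _x_row, _y_row = sess.run([x, y])  # one decoded CSV row
        print(_x_row, "->", _y_row)
    finally:
        coord.request_stop()
        coord.join(threads)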

Reading image files:

import matplotlib.pyplot as plt

# Hyperparams
num_epochs = 1

# Make fake RGBA images and save them
# (JPEG cannot store an alpha channel, so PNG is used for these 4-channel images)
for i in range(100):
    _x = np.random.randint(0, 256, size=(10, 10, 4)).astype(np.uint8)
    plt.imsave("example/image_{}.png".format(i), _x)

# Import png files
images = tf.train.match_filenames_once('example/*.png')

# Create a string queue of filenames
fname_q = tf.train.string_input_producer(images, num_epochs=num_epochs, shuffle=True)

# Create a WholeFileReader; it returns an entire file as one record
reader = tf.WholeFileReader()

# Read the string queue
_, value = reader.read(fname_q)

# Decode value into an image tensor
img = tf.image.decode_image(value, channels=4)

# Batching; decode_image sets no static shape, so shapes must be given explicitly
img_batch = tf.train.batch([img], shapes=([10, 10, 4]), batch_size=batch_size)

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    
    num_samples = 0
    try:
        while not coord.should_stop():
            sess.run(img_batch)
            num_samples += batch_size
            print(num_samples, "samples have been seen")

    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        coord.request_stop()

    coord.join(threads)

The tf.data Dataset and Iterator mechanism
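The queue-based pipeline above is the older input API. TensorFlow also provides the tf.data module: a Dataset describes the source and its transformations, and an Iterator yields the elements. A minimal sketch reading the TFRecord file written earlier; the parse function mirrors the features defined above:

def _parse(serialized_example):
    features = tf.parse_single_example(
        serialized_example,
        features={'x': tf.FixedLenFeature([9], tf.int64),
                  'y': tf.FixedLenFeature([1], tf.int64)})
    return features['x'], features['y']

dataset = tf.data.TFRecordDataset('example/tfrecord')
dataset = dataset.map(_parse).shuffle(1000).batch(10)

iterator = dataset.make_one_shot_iterator()
x_batch, y_batch = iterator.get_next()

with tf.Session() as sess:
    _x_batch, _y_batch = sess.run([x_batch, y_batch])  # one mini-batch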

