本文主要是对tensorflow的官方教程reading data的总结。
1. 在python程序中使用的run函数中使用feed_dict提供数据
2. 使用pipeline从文件中读取数据
3. 将数据预先载入到tensorflow的graph中,这种方式只适合数据集小的情况
一个feeding方法的程序如下:
with tf.Session():
input = tf.placeholder(tf.float32)
classifier = ...
print(classifier.eval(feed_dict={input: my_python_preprocessing_fn()}))
在使用feed方法时,对于要feed的数据最好使用placeholder定义,因为这样定义的如果不feed数据,程序会报错,从而避免出现忘记feed的情况。
一个典型的从文件中读取数据的pipeline过程如下所示:
在tensorflow中主要有三种reader,分别是tf.TextLineReader()和tf.fixedLengthRecordReader()以及读取TFRecords File的tf.python_io.TFRecordReader。前者是将文件当做文本文件进行读取,中间的是将文件当做二进制文件每次读取指定长度的文件。中间的通过record_bytes=record_bytes这个命名参数来指定每次读取多长的文件。后者是使用tf.parse_single_example来解码数据的。
在该教程中,对于文本文件的读取是通过读取一个csv文件进行测试的,对于二进制文件的读取在另一个教程里的数据读取函数使用了该方法。
在读取csv文件的教程中使用的decoder是tf.decode_csv(),该函数返回的每一个数据是一个标量。所以教程中用tf.stack()将多个标量结合成一个向量。
对于读取和写入TFRecordFile的github仓库位置
写TFrecordFile的过程如下:
1.通过tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))或者
tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
2.然后将上面构建feature组成一个字典作为tf.train.Features()构造函数的feature参数的实参构
一个features
3.然后使用上面的features作为tf.train.Example的features的实参构建一个example
4.最后调用TFRecordWriter的wirte函数将example.SerializeToString()写入文件
读取TFRecordFile的过程如下:
1.先用tf.train.string_input_producer()构建一个文件名队列filequeue
2.将上面的filequeue作为参数TFRecordReader的read方法的参数,得到一个key和一个value
3.然后使用tf.parse_single_example()来得到一个features的字典,该函数的第一个参数上一步得到
的value,features参数为一个字典,字典的键值为要取出的值在之前写入该文件的过程中构造的features
所对应的键值。值为tf.FixedLenFeature([], 对应的类型名)
4.然后通过使用features[键值]访问到对应的数据。
可以通过tf.train.shuffle_batch()在前面获得到训练单个样例的符号表示的基础构建一个batch,该函数的第一个参数为一个列表,列表元素即为要batch的数据在graph中的表示, batch_size用来指定batch的大小,capacity用来指定容量,min_after_dequeue用来指定用于shuffle时随机挑选的大小。一个例子如下 :
def read_my_file_format(filename_queue):
reader = tf.SomeReader()
key, record_string = reader.read(filename_queue)
example, label = tf.some_decoder(record_string)
processed_example = some_processing(example)
return processed_example, label
def input_pipeline(filenames, batch_size, num_epochs=None):
filename_queue = tf.train.string_input_producer(
filenames, num_epochs=num_epochs, shuffle=True)
example, label = read_my_file_format(filename_queue)
# min_after_dequeue defines how big a buffer we will randomly sample
# from -- bigger means better shuffling but slower start up and more
# memory used.
# capacity must be larger than min_after_dequeue and the amount larger
# determines the maximum we will prefetch. Recommendation:
# min_after_dequeue + (num_threads + a small safety margin) * batch_size
min_after_dequeue = 10000
capacity = min_after_dequeue + 3 * batch_size
example_batch, label_batch = tf.train.shuffle_batch(
[example, label], batch_size=batch_size, capacity=capacity,
min_after_dequeue=min_after_dequeue)
return example_batch, label_batch
为了提高从多个文件中同时读取数据的并行可以使用tf.train.shuffle_batch_join()。一个使用的例子如下:
def read_my_file_format(filename_queue):
# Same as above
def input_pipeline(filenames, batch_size, read_threads, num_epochs=None):
filename_queue = tf.train.string_input_producer(
filenames, num_epochs=num_epochs, shuffle=True)
example_list = [read_my_file_format(filename_queue)
for _ in range(read_threads)]
min_after_dequeue = 10000
capacity = min_after_dequeue + 3 * batch_size
example_batch, label_batch = tf.train.shuffle_batch_join(
example_list, batch_size=batch_size, capacity=capacity,
min_after_dequeue=min_after_dequeue)
return example_batch, label_batch
tensorflow对于的queue对于稀疏数据的支持不是很好。解码稀疏数据是调用tf.parse_example(),并且是在batch操作之后调用,而前面的tf.parse_single_example是在batch之前调用
该方法输入数据只是适合数据量小的情况,主要分成两个方式,一个是存放在constant中一个是存放在variable中。存放在constant中的方式所占用的内存较大,因为它是直接将数据存放在图定义的数据结构中,可能出现重复,所以占用的内存较多。一个定义在variable中的例子如下:
training_data = ...
training_labels = ...
with tf.Session() as sess:
data_initializer = tf.placeholder(dtype=training_data.dtype,
shape=training_data.shape)
label_initializer = tf.placeholder(dtype=training_labels.dtype,
shape=training_labels.shape)
input_data = tf.Variable(data_initializer, trainable=False, collections=[])
input_labels = tf.Variable(label_initializer, trainable=False, collections=[])
...
sess.run(input_data.initializer,
feed_dict={data_initializer: training_data})
sess.run(input_labels.initializer,
feed_dict={label_initializer: training_labels})
collections=[]的作用是避免将该变量加入GraphKeys.GLOBAL_VARIABLES的collections中,从而避免在使用checkpoint文件时保存和恢复该变量。
tf.train.slice_input_producer是将它的tensor_list参数中的tensor,去掉最高维得到的结果。去掉最高维即每次指定最高维的索引,从而该维就没有了。