TensorFlow提供了一种内置的API—ataset,使得我们可以很容易地就利用输入管道的方式输入数据。可以像队列读取数据那样,生产batch、数据增强等等。
官方介绍
tf.data.Dataset
可以表示为一些元素的序列,该元素序列可以是列表、元组甚至是字典。比如对于图像通道,元素可以是单独的数据样本,也可以是成对的(样本+label),这里提供了两种不同的创建dataset的方式:
Dataset.from_tensor_slices():从数据中返回一个切片,也就是单个数据信息
Dataset.batch():对数据应用变换,使其返回一个batch
tf.data.Iterator是从数据集中提取元素的主要方法,通过Iterator.get_next()产生Dataset下一个元素。最简单的迭代器是"one-shot iterator",它可以对Dataset迭代一次;对于复杂的情况,Iterator.initializer可以让你重新启动和参数化一个迭代器,这样就可以在一个程序中多次加载训练集和验证集。
# 创建一个Dataset,(此例是from tensor创建)
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10])) # tensor
print(dataset1.output_types) # ==> "tf.float32"
print(dataset1.output_shapes) # ==> "(10,)"
dataset2 = tf.data.Dataset.from_tensor_slices(
(tf.random_uniform([4]),
tf.random_uniform([4, 100], maxval=100, dtype=tf.int32))) # 元组
print(dataset2.output_types) # ==> "(tf.float32, tf.int32)"
print(dataset2.output_shapes) # ==> "((), (100,))"
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
print(dataset3.output_types) # ==> (tf.float32, (tf.float32, tf.int32))
print(dataset3.output_shapes) # ==> "(10, ((), (100,)))"
dataset = tf.data.Dataset.from_tensor_slices(
{"a": tf.random_uniform([4]),
"b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)}) # 字典
print(dataset.output_types) # ==> "{'a': tf.float32, 'b': tf.int32}"
print(dataset.output_shapes) # ==> "{'a': (), 'b': (100,)}"
一旦创建了Dataset后,就可以利用迭代器获取元素,目前有四种迭代形式:
one-shot 是最简单的一种迭代器,它仅仅支持迭代Dataset一遍,举例如下:
该迭代器最为简单,简要来说就是创建数据源————迭代器————取数据
dataset = tf.data.Dataset.range(5) # Dataset.range(5) == [0, 1, 2, 3, 4]
# 创建一个迭代器
# 该迭代器默认是已经初始化过的,并且不支持重新初始化(不支持数据源的改动)
iterator = dataset.make_one_shot_iterator()
# 获取数据
next_element = iterator.get_next()
sess = tf.InteractiveSession()
for i in range(5):
value = sess.run(next_element)
print(value)
# 输出
0
1
2
3
4
initializable迭代器,它可以让你将Dataset用参数表示,结合placeholder就可以动态的对数据源进行调整,举例如下:
该迭代形式可以和placeholder结合,动态调整数据,比起第一种灵活性提高
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
# 创建一个迭代器,不同的是该迭代器是未被初始化的
# 在取数据之前,必须执行iterator.initializer节点进行初始化
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# 获取数据之前要,初始化一个迭代器,并动态的传入你想传入的值
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
# 获取数据
value = sess.run(next_element)
reinitializable 可以用多个不同的Dataset对迭代器进行重新初始化,例如训练模型的时候,一般都会有训练集和验证集,该迭代方式能使你通过传入不同数据源——对应不同数据初始化——就可以迭代取出不同数据源的元素数据。举例如下:
相比前两种,复杂度更高,可以在多个数据上使用迭代器
# 定义两个数据源
training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)
# 依据给定属性创建一个迭代器(未被初始化的)
# 该迭代器可以被重用,也就是可以在多个Dataset上使用,它不受特别的Dataset所限制
# 但此迭代器没有initializer,为了进行初始化,可以使用Iterator.make_initializer(dataset)
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
training_dataset.output_shapes)
next_element = iterator.get_next()
# 对传入的不同数据源定义迭代器初始化节点
training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)
for _ in range(20):
# Initialize an iterator over the training dataset.
sess.run(training_init_op)
for _ in range(100):
sess.run(next_element)
# Initialize an iterator over the validation dataset.
sess.run(validation_init_op)
for _ in range(50):
sess.run(next_element)
feedable迭代器
该类迭代器可以和placeholder结合使用,同时
最后一种,也是最复杂最灵活的一种,可以和placeholder结合使用,并且可以传入多种迭代器
# Define training and validation datasets with the same structure.
# map()表示一个映射函数,比如x的平方也是一个映射,用于对数据进行调整
# repeat()表示将Dataset进行重复数次,如果默认参数,则表示无限重复(不做限制)
training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)
handle = tf.placeholder(tf.string, shape=[])
# 创建迭代器,该迭代器也是未初始化的
# feedable iterator通过给定的句柄和数据源结构进行定义
iterator = tf.data.Iterator.from_string_handle(
handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()
# 可以传入多种迭代器(通过执行迭代器句柄,如下所示)
# 可以看到,定义了两种迭代器
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()
# 通过run `Iterator.string_handle()`方法,返回一个字符串tensor,代表对应迭代器
# 返回值用于feed_dict
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())
# Loop forever, alternating between training and validation.
while True:
# Run 200 steps using the training dataset. Note that the training dataset is
# infinite, and we resume from where we left off in the previous `while` loop
# iteration.
for _ in range(200):
sess.run(next_element, feed_dict={handle: training_handle})
# Run one pass over the validation dataset.
sess.run(validation_iterator.initializer)
for _ in range(50):
sess.run(next_element, feed_dict={handle: validation_handle})
还有一点需要注意,当迭代完到底时,如果在进行Iterator.get_next()的话会报错,因为迭代器已经处于不可用的状态,解决办法是重新初始化迭代器。
dataset = tf.data.Dataset.range(5)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Typically `result` will be the output of a model, or an optimizer's
# training operation.
result = tf.add(next_element, next_element)
sess = tf.InteractiveSession()
sess.run(iterator.initializer)
print(sess.run(result)) # ==> "0"
print(sess.run(result)) # ==> "2"
print(sess.run(result)) # ==> "4"
print(sess.run(result)) # ==> "6"
print(sess.run(result)) # ==> "8"
try:
sess.run(result)
except tf.errors.OutOfRangeError:
print("End of dataset") # ==> "End of dataset"
以numpy为数据源
# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
features = data["features"]
labels = data["labels"]
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
# numpy数据传入后,会进行tf.constant()操作————也就是转换成tensor
# 但这只适用于小数据集使用,因为数组会被多次复制,会导致内存不够、内存浪费
一个解决办法是使用placeholder。
以TFRecord为数据源
当读取大数据的时候,我们通常会把数据集制作成TFRecord,当然tf.data API也支持读取TFRecord格式。
# 创建数据
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
filenames可以是一个字符串类型、字符串列表或者一个tf.Tensor类型的字符串
-------------------------------------------------------------------------------
# 也可以使用placeholder的形式
filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
# 创建迭代器
iterator = dataset.make_initializable_iterator()
# 传入placeholder
training_filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})
# 同上
validation_filenames = ["/var/data/validation1.tfrecord", ...]
sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})
生产batch
inc_dataset = tf.data.Dataset.range(9)
dec_dataset = tf.data.Dataset.range(0, -9, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
# 这样就生成batch了。。。直接用,多的不说
batched_dataset = dataset.batch(3)
iterator = batched_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.InteractiveSession()
print(sess.run(next_element)) # (array([0, 1, 2], dtype=int64), array([ 0, -1, -2], dtype=int64))
print(sess.run(next_element)) # (array([3, 4, 5], dtype=int64), array([-3, -4, -5], dtype=int64))
print(sess.run(next_element)) # (array([6, 7, 8], dtype=int64), array([-6, -7, -8], dtype=int64))
Randomly shuffling
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)