1. Tensorflow高效流水线Pipeline
2. Tensorflow的数据处理中的Dataset和Iterator
3. Tensorflow生成TFRecord
4. Tensorflow的Estimator实践原理
回到顶部
我们在训练模型的时候,必须经过的第一个步骤是数据处理。在机器学习领域有一个说法,数据处理的好坏直接影响了模型结果的好坏。数据处理是至关重要的一步。
我们今天关注数据处理的另一个问题:假设我们做深度学习,数据的量随随便便就到GB的级别,那数据处理的速度对于模型的训练也很重要。经常遇到的一个情况是,数据处理的时间占了训练整个模型的大部分。
今天介绍的是Tensorflow官方推荐的数据处理方式是用Dataset API同时支持从内存和硬盘的读取,相比之前的两种方法在语法上更加简洁易懂
回到顶部
Google官方给出的Dataset API中的类图如下所示:
Dataset API还提供了四种创建Dataset的方式:
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
#实例化make_one_shot_iterator对象,该对象只能读取一次
iterator = dataset.make_one_shot_iterator()
# 从iterator里取出一个元素
one_element = iterator.get_next()
with tf.Session() as sess:
for i in range(5):
print(sess.run(one_element))
# Creates a dataset that reads all of the examples from two files.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filenames)
一个Dataset通过Transformation变成一个新的Dataset。通常我们可以通过Transformation完成数据变换,打乱,组成batch,生成epoch等一系列操作,常用的Transformation有:
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
dataset = dataset.map(lambda x: x + 1) # 2.0, 3.0, 4.0, 5.0, 6.0
dataset = dataset.apply(group_by_window(key_func, reduce_func, window_size))
dataset = dataset.batch(32)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.repeat(5)
# 如果repeat没有参数,则一直重复循环数据
dataset = dataset.repeat()
dataset.padded_batch(
batch_size,
padded_shapes=(
tf.TensorShape([None]), # src
tf.TensorShape([]), # tgt_output
tf.TensorShape([]),
tf.TensorShape([src_max_len])), # src_len
padding_values=(
src_eos_id, # src
0, # tgt_len -- unused
0, # src_len -- unused
0)) # mask
dataset.shard(num_shards, shard_index)
比较完整的生成dataset的代码。
def parse_fn(example):
"Parse TFExample records and perform simple data augmentation."
example_fmt = {
"image": tf.FixedLengthFeature((), tf.string, ""),
"label": tf.FixedLengthFeature((), tf.int64, -1)
}
parsed = tf.parse_single_example(example, example_fmt)
image = tf.image.decode_image(parsed["image"])
image = _augment_helper(image) # augments image using slice, reshape, resize_bilinear
return image, parsed["label"]
#简单的生成input_fn
def input_fn():
files = tf.data.Dataset.list_files("/path/to/dataset/train-*.tfrecord")
dataset = files.interleave(tf.data.TFRecordDataset)
dataset = dataset.shuffle(buffer_size=FLAGS.shuffle_buffer_size)
dataset = dataset.map(map_func=parse_fn)
dataset = dataset.batch(batch_size=FLAGS.batch_size)
return dataset
回到顶部
生成Iterator一共有4种,复杂程度递增,个人觉得掌握前两种应该够用了,Iterator还有一个优势,目前,单次迭代器是唯一易于与 Estimator 搭配使用的类型。
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
for i in range(100):
value = sess.run(next_element)
assert i == value
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
value = sess.run(next_element)
assert i == value
Iterator.get_next() 方法tf.Tensor 对象,每次tf.Session.run(Iterator.get_next())都会获取底层数据集中下一个元素的值。
如果迭代器到达数据集的末尾,则执行 Iterator.get_next() 操作会产生 tf.errors.OutOfRangeError。在此之后,迭代器将处于不可用状态;如果需要继续使用,则必须对其重新初始化。
sess.run(iterator.initializer)
while True:
try:
sess.run(getNextTensor)
except tf.errors.OutOfRangeError:
sess.run(iterator.initializer)
tf.contrib.data.make_saveable_from_iterator 函数通过迭代器创建一个 SaveableObject,该对象可用于保存和恢复迭代器(实际上是整个输入管道)的当前状态。
# Create saveable object from iterator.
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
# Save the iterator state by adding it to the saveable objects collection.
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()
with tf.Session() as sess:
if should_checkpoint:
saver.save(path_to_checkpoint)
# Restore the iterator state.
with tf.Session() as sess:
saver.restore(sess, path_to_checkpoint)
回到顶部
本文介绍了创建不同种类的Dataset和Iterator对象的基础知识,熟悉这个数据处理的步骤后,不仅复用性比较强,而且效率也能成倍的提升。