Dataset API是TensorFlow 1.3版本中引入的一个新的模块,主要服务于数据读取,构建输入数据的pipeline。
Google官方给出的Dataset API中的类图:
我们本文只关注Dataset的一类特殊的操作:Transformation,即map,shuffle,repeat,batch等。
在正式介绍之前,我们再回忆一下深度学习中的一些基本概念。
最简单的情况如下:
# 创建0-10的数据集,每个batch取个数6。
dataset = tf.data.Dataset.range(10).batch(6)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(2):
value = sess.run(next_element)
print(value)
结果为:
[0 1 2 3 4 5]
[6 7 8 9]
但是如果我们把循环次数设置成3(即for i in range(2)),那么就会报错。
repeat方法可以解决上述问题,repeat的功能就是将整个数据重复多次,主要用来处理机器学习中的epoch,假设原先的数据是一个epoch,使用repeat(2)就可以将之变成2个epoch:
dataset = tf.data.Dataset.range(10).batch(6)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(4):
value = sess.run(next_element)
print(value)
结果如下:
[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]
当然,如果觉得每次都需要设置repeat的次数麻烦,我们也可以不设置repeat,代码如下:
dataset = tf.data.Dataset.range(10).batch(6)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(6):
value = sess.run(next_element)
print(value)
结果:
[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]
仔细看可以知道上面所有输出结果都是有序的,在机器学习中训练模型需要将数据打乱,这样可以保证每批次训练的时候所用到的数据集是不一样的,可以提高模型训练效果。
shuffle的功能为打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小,不设置会报错,
dataset = tf.data.Dataset.range(10).shuffle(2).batch(6)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(4):
value = sess.run(next_element)
print(value)
结果如下:
[1 0 2 4 3 5]
[7 8 9 6]
[1 2 3 4 0 6]
[7 8 9 5]
注意:shuffle的顺序很重要,应该先shuffle再batch,如果先batch后shuffle的话,那么此时就只是对batch进行shuffle,而batch里面的数据顺序依旧是有序的,那么随机程度会减弱。
ataset = tf.data.Dataset.range(10).batch(6).shuffle(10)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(4):
value = sess.run(next_element)
print(value)
结果:
[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]
可以看到实际并没有shuffle
map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset,如我们可以对dataset中每个元素的值加10
dataset = tf.data.Dataset.range(10).batch(6).shuffle(10)
dataset = dataset.map(lambda x: x + 10)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(2):
value = sess.run(next_element)
print(value)
结果
[16 17 18 19]
[10 11 12 13 14 15]
【1】Tensorflow datasets.shuffle repeat batch方法
【2】TensorFlow全新的数据读取方式:Dataset API入门教程
【3】Module: tf.data