本文旨在介绍tf.data.Dataset中batch, repeat, shuffle以及三者的顺序问题。首先介绍了这三个函数单独作用的结果,而后给出了相互作用下的影响。
shuffle(buffsize) 用于将数据打乱,其中buffsize的大小越大,数据的混乱程度越高,因为shuffle的实现思路为:** 开辟一可容纳buffsize个数据的缓冲区,初始时将数据的前buffsize个读入缓冲区,而后随机在缓冲区里选择一个输出,同时将数据的第buffsize+1个读入缓冲区**。由此不难理解,如果buffsize很小,比如为1时就根本没有打乱。给出示例代码和结果如下:
data=tf.range(0,10)
data=tf.data.Dataset.from_tensor_slices(data)
data1=data.shuffle(5)
for i in data1:
print(i.numpy())
'''
#结果
4
0
5
2
3
1
6
8
7
9
'''
可以测试,如果多次执行代码中的输出程序,每次的打乱结果都会发生变化,但是第一个输出的值永远都是在0~4的范围之内,这是因为我们设置的buffsize=5。
repeat(count) 用于将数据重复count次,相当于我们训练时的epoch,示例代码及结果如下:
data1=data.repeat(2)
for i in data:
print(i.numpy())
'''
#结果
0 1 2 3 4 5 6 7 8 9
0 1 2 3 4 5 6 7 8 9
'''
batch(batch_size)用于将数据划分为多个batch,同时tensorflow中有着很好的调整功能,当最后一个batch不满足batchsize时就以当前长度输出。
data=tf.range(0,10.)[:,None]
data=tf.data.Dataset.from_tensor_slices(data)
data1=data.batch(4)
for i in data1:
print(i.numpy())
'''
#结果
[[0.]
[1.]
[2.]
[3.]]
[[4.]
[5.]
[6.]
[7.]]
[[8.]
[9.]]
可以看到最后一个batch只含有2个数据。
其结果就是对repeat后的数据进行打乱,这会使得不同epcoh间的数据被打乱,即前一个epcoh中数据未加载完,下一个epoch中数据可能插入,导致一个epoch中可能数据重复多次。
temp1=data.repeat(2).shuffle(5)
for i in temp1:
print(i.numpy())
# 结果
# [4.] [0.] [2.] [7.] [8.] [5.] [3.] [6.] [1.]
#[1.] [0.] [2.] [5.] [4.] [9.] [9.] [3.] [6.] [7.] [8.]
可以看出,在第一个epoch还未结束(未出现9时)就已经出现了下个epoch的0
先shuffle再repeat,epoch内部打乱,一定先输出完一个epoch内所有值。
temp2=data.shuffle(5).repeat(2)
for i in temp2:
print(i.numpy())
#结果
# [2.] [4.] [5.] [6.] [0.] [1.] [3.] [9.] [8.] [7.]
#[3.] [1.] [4.] [7.] [0.] [6.] [9.] [2.] [8.] [5.]
先batch再repeat,为对于batch的复制。
temp3=data.batch(4).repeat(2)
for i in temp3:
print(i.numpy())
#结果
# [[0.] [1.] [2.] [3.]] [[4.] [5.] [6.] [7.]] [[8.] [9.]]
#[[0.] [1.] [2.] [3.]] [[4.] [5.] [6.] [7.]] [[8.] [9.]
先repeat再batch是对于重复后数组的分组。
temp4=data.repeat(2).batch(4)
for i in temp4:
print(i.numpy())
# [[0.] [1.] [2.] [3.]] [[4.] [5.] [6.] [7.]] [[8.]
#[9.] [0.] [1.]] [[2.] [3.] [4.] [5.]] [[6.] [7.]
#[8.] [9.]]
先batch再shuffle是对不同组间的打乱
temp5=data.batch(4).shuffle(5)
for i in temp5:
print(i.numpy())
# [[4.] [5.] [6.] [7.]]
#[[0.] [1.] [2.] [3.]]
#[[8.] [9.]]
先shuffle再batch是在对打乱后的数据分组
temp6=data.shuffle(5).batch(4)
for i in temp6:
print(i.numpy())
# [[3.] [2.] [1.] [7.]]
#[[6.] [5.] [4.] [0.]]
[[9.] [8.]]
【Tensorflow 2.0 正式版教程】tf.data.Dataset的基本使用方法
tf.data.Dataset关于batch,repeat,shuffle的讲解
#深入探究# Tensorflow.Data.shuffle 方法的实现原理和 buffer_size 参数的作用
加载自定义图片数据集到Dataset