Dataset封装了很好的关于数据集的一些基本操作,在这里做一下总结。该对象的路径是:tensorflow.data.Dataset
(这是1.4版本之后的)很大程度上参考了这篇博客
同时再推荐一个特别好的博客:https://towardsdatascience.com/how-to-use-dataset-in-tensorflow-c758ef9e4428
Tensoflow的核心数据就是tensor。可以这么理解,想要让tensoflow的数据正确的“流动”起来,那么就需要正确的匹配张量的维数。再给计算图模型填充数据的时候,使用字典的方式是最慢的,应该避免这种操作;正确的做法是把数据组合成张量,进行操作,这也是tf.data.Dataset
的核心功能之一。
该函数是核心函数之一,它的作用是把给定的元组、列表和张量等数据进行特征切片。切片的范围是从最外层维度开始的。如果有多个特征进行组合,那么一次切片是把每个组合的最外层的数据切出。。直接看代码理解。
假设我们现在有两组数据,分别是特征和标签,为了简化说明问题,我们假设每两个特征对应一个标签。之后把特征和标签组合成一个tuple
,那么我们的想法是让每个标签都恰好对应2个特征,而且像直接切片,比如:[f11, f12] [t1]
。f11
表示第一个数据的第一个特征,f12
表示第1个数据的第二个特征,t1
表示第一个数据标签。那么tf.data.Dataset.from_tensor_slices
就是做了这件事情:
import tensorflow as tf
import numpy as np
features, labels = (np.random.sample((5, 2)), # 模拟5组数据,每组数据2个标签
np.random.sample((5, 1))) # 模拟5组特征,注意两者的维数必须匹配
print((features, labels)) # 输出下组合的数据
data = tf.data.Dataset.from_tensor_slices((features, labels))
print(data) # 输出张量的信息
结果输出:
(array([[0.94509483, 0.19160528],
[0.49125608, 0.93146317],
[0.19331899, 0.59950161],
[0.8338232 , 0.71606446],
[0.23264883, 0.71179252]]), array([[0.48340206],
[0.55842171],
[0.30450086],
[0.45078316],
[0.40497981]]))
生成一个迭代器,用于便利所有的数据。一般用法如下:
tf.data.Dataset.make_one_shot_iterator.get_next()
每次列举出下一个数据集。
实例:
import tensorflow as tf
import numpy as np
data = tf.data.Dataset.from_tensor_slices(
np.array([1, 2, 3, 4, 5]))
element = data.make_one_shot_iterator().get_next() # 建立迭代器,并进行迭代操作
with tf.Session() as sess:
try:
while True:
print(sess.run(element))
except tf.errors.OutOfRangeError:
print("Out range !")
import tensorflow as tf
import numpy as np
a = np.array(['a', 'b', 'c', 'd', 'e'])
b = np.array([1, 2, 3, 4, 5])
# 分别切分数据,以字典的形式存储
data = tf.data.Dataset.from_tensor_slices(
{
"label1": a,
"label2": b
}
)
it=data.make_one_shot_iterator().get_next()
with tf.Session() as sess:
try:
while True:
print(sess.run(it))
except tf.errors.OutOfRangeError:
print("out of range")
输出结果
{'label2': 1, 'label1': b'a'}
{'label2': 2, 'label1': b'b'}
{'label2': 3, 'label1': b'c'}
{'label2': 4, 'label1': b'd'}
{'label2': 5, 'label1': b'e'}
与python中的map作用类似,对输入的数据进行预处理操作。
import tensorflow as tf
import numpy as np
a = np.array([1, 2, 3, 4, 5])
data = tf.data.Dataset.from_tensor_slices(a)
# 注意在这里是返回的集合,原来的集合不变
data = data.map(lambda x: x ** 2)
it = data.make_one_shot_iterator().get_next()
with tf.Session() as sess:
try:
while True:
print(sess.run(it))
except tf.errors.OutOfRangeError:
print("out of range")
batch就是将多个元素组合成batch,如下面的程序将dataset中的每个元素组成了大小为32的batch:
dataset = dataset.batch(32)
shuffle的功能为打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小:
dataset = dataset.shuffle(buffer_size=10000)
repeat的功能就是将整个序列重复多次,主要用来处理机器学习中的epoch,假设原先的数据是一个epoch,使用repeat(5)就可以将之变成5个epoch:
dataset = dataset.repeat(5) # 重复5次数据
注意,必须指明重复的次数,否则会无限期的重复下去。
dataset.shuffle(1000).repeat(10).batch(32)
把数据进行1000个为单位的乱序,重复10次,生成批次为32的batch
这个函数的输入是一个文件的列表,输出是一个dataset。dataset中的每一个元素就对应了文件中的一行。可以使用这个函数来读入CSV文件。一般操作方式:
tf.data.TextLineDataset(file_path).skip(n)
读取文件,同时跳过前n行。