tf.1.4及以后新出的tf.data.Dataset API 中,使用的数据读取方式有点类似于pytorch中的Dataloader,大大简化了数据读取。下面是代码实例。
# coding=utf-8
import os
import numpy as np
import glob
import tensorflow as tf
import tensorflow.contrib.eager as tfe
"""数据读取: Dataset API的介绍"""
"""
1. Dataset API 支持tensorflow新出的Eager模式
Eager模式:迭代时可直接取值,而不是tensor。但在tf 1.4的标准版中,没有eager模式,而是在nightly version
2. 通过Dataset类可以实例化出一个Iterator
3. Dataset 可以看成是相同类型元素的有序列表。这里的元素可以是向量,字符串,图片,或者tuple,dict等
4. 从Dataset中取出元素:
需要实例化一个Interator,然后对Iterator进行迭代
5. Dataset支持一类特殊的操作: Transformation. 一个Dataset通过Transformation变成一个新的Dataset。
我们可以通过Transformation完成 数据变换, 打乱, 组成batch, 生成epoch 等操作
常用的Transformation:
(1) map
(2) batch
(3) shuffle
(4) repeat
6. dataset的创建方法:
(1) tf.data.Dataset.from_tensor_slices
(2) tf.data.TextLineDataset(): 输入是一个文件列表,输出是一个dataset。dataset中的每一个元素就对应了文件中的一行。
可以用这个函数来读取csv文件
(3) tf.data.FixedLengthRecordDataset(): 通常用来读取以二进制形式保存的文件,如CIFAR10数据集
(4) tf.data.TFRecordDataset(): 用来读取tfrecord文件,dataset中的每一个元素就是一个TFExample
"""
def eager_dataset():
"""
以eager模式读取数据集
:return:
"""
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
iterator = tfe.Iterator(dataset)
for one_element in iterator:
print(one_element)
def non_eager_dataset():
"""
以非eager的方式读取数据集
:return:
"""
# from_tensor_slices: 切分传入Tensor的第一个维度,生成相应的dataset
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
"""非eager模式"""
# 创建一个iterator,且是一个one shot iterator,即只能从头到尾读取一次
iterator = dataset.make_one_shot_iterator()
# 非Eager模式:one_element是一个tensor,而不是个实际的值
one_element = iterator.get_next()
# with tf.Session() as sess:
# for i in range(5):
# # 如果一个dataset中的元素被读取完了,再尝试运行sess.run(one_element),会报tf.errors.OutOfRangeError的异常
# print(sess.run(one_element))
with tf.Session() as sess:
try:
while True:
print(sess.run(one_element))
except tf.errors.OutOfRangeError:
print('End')
def non_eager_dataset_v2():
dataset = tf.data.Dataset.from_tensor_slices(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
try:
while True:
print(sess.run(one_element))
except tf.errors.OutOfRangeError:
print('End')
def non_eager_dataset_dict_classical():
"""
经典的图像处理类问题中,image 和 label 的组织形式:
{'image': image_tensor, 'label': label_tensor}
:return:
"""
# from_tensor_slices 会分别切分'a','b'中的数值,最终dataset中的一个元素类似于{'a': 1.0, 'b': dog}的形式
dataset = tf.data.Dataset.from_tensor_slices(
{'a': np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 'b': ['dog', 'cat', 'pig', 'monkey', 'bear']})
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
try:
while True:
print(sess.run(one_element))
except tf.errors.OutOfRangeError:
print('End')
"""Transformation 相关操作"""
def map_fun():
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
dataset = dataset.map(lambda x: x + 1)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
try:
while True:
print(sess.run(one_element))
except tf.errors.OutOfRangeError:
print('End')
def batch_fun():
dataset = tf.data.Dataset.from_tensor_slices(np.array(range(32)))
# 注: batch 也支持不整除的操作
dataset = dataset.batch(5)
dataset = dataset.shuffle(1000)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
cnt = 0
with tf.Session() as sess:
try:
while True:
print('batch: {}, {}'.format(cnt, sess.run(one_element)))
cnt += 1
except tf.errors.OutOfRangeError:
print('End')
def repeat_fun():
dataset = tf.data.Dataset.from_tensor_slices(np.array(range(10)))
dataset = dataset.shuffle(1000)
# repeat 的功能就是将整个数据集重复多次,主要用来处理机器学习中的epoch.
dataset = dataset.repeat(3)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
try:
while True:
print(sess.run(one_element))
except tf.errors.OutOfRangeError:
print('End')
"""一个经典的读取image和label的列子"""
def parse_function(filename, label):
image_string = tf.read_file(filename)
# image_decoded = tf.image.decode_image(image_string, channels=3)
image_decoded = tf.image.decode_jpeg(image_string)
image_resized = tf.image.resize_images(image_decoded, size=(100, 100))
return image_resized, label
def dataset_classical_example():
batch_size = 4
filenames_tmp = glob.glob(os.path.join('./data_samples', '*.{}'.format('jpg')))
filenames = tf.constant(filenames_tmp)
labels = tf.constant(range(len(filenames_tmp)))
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(parse_function)
dataset = dataset.shuffle(buffer_size=1000).batch(batch_size).repeat(3)
iterator = dataset.make_one_shot_iterator()
one_batch = iterator.get_next()
with tf.Session() as sess:
try:
while True:
batch_images, batch_labels = sess.run(one_batch)
except tf.errors.OutOfRangeError:
print('End')
if __name__ == '__main__':
# non_eager_dataset_dict_classical()
# map_fun()
# batch_fun()
# repeat_fun()
dataset_classical_example()
tf.data.Dataset介绍
tensorflow 导入数据 官网教程