Tensorflow 数据读取 tf.data.Dataset API 相关介绍

介绍

tf.1.4及以后新出的tf.data.Dataset API 中,使用的数据读取方式有点类似于pytorch中的Dataloader,大大简化了数据读取。下面是代码实例。


# coding=utf-8
import os
import numpy as np
import glob

import tensorflow as tf
import tensorflow.contrib.eager as tfe

"""数据读取: Dataset API的介绍"""
"""
1. Dataset API 支持tensorflow新出的Eager模式
            Eager模式:迭代时可直接取值,而不是tensor。但在tf 1.4的标准版中,没有eager模式,而是在nightly version
2. 通过Dataset类可以实例化出一个Iterator
3. Dataset 可以看成是相同类型元素的有序列表。这里的元素可以是向量,字符串,图片,或者tuple,dict等
4. 从Dataset中取出元素:
            需要实例化一个Interator,然后对Iterator进行迭代
5. Dataset支持一类特殊的操作: Transformation. 一个Dataset通过Transformation变成一个新的Dataset。
    我们可以通过Transformation完成 数据变换, 打乱, 组成batch, 生成epoch 等操作
    常用的Transformation:
                (1) map
                (2) batch
                (3) shuffle
                (4) repeat
6. dataset的创建方法:
    (1) tf.data.Dataset.from_tensor_slices
    (2) tf.data.TextLineDataset(): 输入是一个文件列表,输出是一个dataset。dataset中的每一个元素就对应了文件中的一行。
                                    可以用这个函数来读取csv文件
    (3) tf.data.FixedLengthRecordDataset(): 通常用来读取以二进制形式保存的文件,如CIFAR10数据集
    (4) tf.data.TFRecordDataset(): 用来读取tfrecord文件,dataset中的每一个元素就是一个TFExample
"""


def eager_dataset():
    """
    以eager模式读取数据集
    :return: 
    """
    dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
    iterator = tfe.Iterator(dataset)
    for one_element in iterator:
        print(one_element)


def non_eager_dataset():
    """
    以非eager的方式读取数据集
    :return: 
    """
    # from_tensor_slices: 切分传入Tensor的第一个维度,生成相应的dataset
    dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))

    """非eager模式"""
    # 创建一个iterator,且是一个one shot iterator,即只能从头到尾读取一次
    iterator = dataset.make_one_shot_iterator()
    # 非Eager模式:one_element是一个tensor,而不是个实际的值
    one_element = iterator.get_next()

    # with tf.Session() as sess:
    #     for i in range(5):
    #         # 如果一个dataset中的元素被读取完了,再尝试运行sess.run(one_element),会报tf.errors.OutOfRangeError的异常
    #         print(sess.run(one_element))

    with tf.Session() as sess:
        try:
            while True:
                print(sess.run(one_element))
        except tf.errors.OutOfRangeError:
            print('End')


def non_eager_dataset_v2():
    dataset = tf.data.Dataset.from_tensor_slices(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
    iterator = dataset.make_one_shot_iterator()
    one_element = iterator.get_next()

    with tf.Session() as sess:
        try:
            while True:
                print(sess.run(one_element))
        except tf.errors.OutOfRangeError:
            print('End')


def non_eager_dataset_dict_classical():
    """
    经典的图像处理类问题中,image 和 label 的组织形式: 
                    {'image': image_tensor, 'label': label_tensor}
    :return: 
    """
    # from_tensor_slices 会分别切分'a','b'中的数值,最终dataset中的一个元素类似于{'a': 1.0, 'b': dog}的形式
    dataset = tf.data.Dataset.from_tensor_slices(
        {'a': np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 'b': ['dog', 'cat', 'pig', 'monkey', 'bear']})
    iterator = dataset.make_one_shot_iterator()
    one_element = iterator.get_next()
    with tf.Session() as sess:
        try:
            while True:
                print(sess.run(one_element))
        except tf.errors.OutOfRangeError:
            print('End')


"""Transformation 相关操作"""
def map_fun():
    dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
    dataset = dataset.map(lambda x: x + 1)
    iterator = dataset.make_one_shot_iterator()
    one_element = iterator.get_next()

    with tf.Session() as sess:
        try:
            while True:
                print(sess.run(one_element))
        except tf.errors.OutOfRangeError:
            print('End')


def batch_fun():
    dataset = tf.data.Dataset.from_tensor_slices(np.array(range(32)))
    # 注: batch 也支持不整除的操作
    dataset = dataset.batch(5)
    dataset = dataset.shuffle(1000)
    iterator = dataset.make_one_shot_iterator()
    one_element = iterator.get_next()
    cnt = 0
    with tf.Session() as sess:
        try:
            while True:
                print('batch: {}, {}'.format(cnt, sess.run(one_element)))
                cnt += 1
        except tf.errors.OutOfRangeError:
            print('End')


def repeat_fun():
    dataset = tf.data.Dataset.from_tensor_slices(np.array(range(10)))
    dataset = dataset.shuffle(1000)
    # repeat 的功能就是将整个数据集重复多次,主要用来处理机器学习中的epoch.
    dataset = dataset.repeat(3)
    iterator = dataset.make_one_shot_iterator()
    one_element = iterator.get_next()
    with tf.Session() as sess:
        try:
            while True:
                print(sess.run(one_element))
        except tf.errors.OutOfRangeError:
            print('End')


"""一个经典的读取image和label的列子"""
def parse_function(filename, label):
    image_string = tf.read_file(filename)
    # image_decoded = tf.image.decode_image(image_string, channels=3)
    image_decoded = tf.image.decode_jpeg(image_string)
    image_resized = tf.image.resize_images(image_decoded, size=(100, 100))

    return image_resized, label


def dataset_classical_example():
    batch_size = 4

    filenames_tmp = glob.glob(os.path.join('./data_samples', '*.{}'.format('jpg')))
    filenames = tf.constant(filenames_tmp)
    labels = tf.constant(range(len(filenames_tmp)))

    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    dataset = dataset.map(parse_function)
    dataset = dataset.shuffle(buffer_size=1000).batch(batch_size).repeat(3)

    iterator = dataset.make_one_shot_iterator()
    one_batch = iterator.get_next()

    with tf.Session() as sess:
        try:
            while True:
                batch_images, batch_labels = sess.run(one_batch)
        except tf.errors.OutOfRangeError:
            print('End')


if __name__ == '__main__':
    # non_eager_dataset_dict_classical()
    # map_fun()
    # batch_fun()

    # repeat_fun()
    dataset_classical_example()




参考链接:

tf.data.Dataset介绍
tensorflow 导入数据 官网教程


你可能感兴趣的:(深度学习)