TensorFlow2.0图片数据集的读取

本文系转载,原文链接:TensorFlow2.0图片数据集的读取

TensorFlow在1.7后引入了较高层的api如tf.data.Dataset,用于数据的读取与预处理

因此,我们在2.0里使用之前较为老的api读取图片数据集时,很不方便,而且有的方法也被弃用了,于是还是老老实实的使用 2.0里高层api罢了。
话不多说先看代码


import  os
import  tensorflow as tf
import  numpy as np
 
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')
 
PATH = os.path.join(os.path.dirname('./young2old-dataset/train'), './young2old-dataset/train')
#列出所以的图片的路径
train_dataset = tf.data.Dataset.list_files(PATH+'/train/A/*.jpg')
#定义加载图片的函数,用于后面的for循环进行复用
def load_image(image_file, is_train):
    image = tf.io.read_file(image_file)
    image = tf.image.decode_jpeg(image)
    return image
#定义一个迭代器,用于for循环
train_iter = iter(train_dataset)
train_data = []
for x in train_iter:
    train_data.append(load_image(x, True))
train_data = tf.stack(train_data, axis=0)
 
print('train:', train_data.shape)
#使用data.Dataset.from_tensor_slices定义一系列的tensor,用于后面的操作,比如洗牌(shuffle),分批(batch)
train_dataset = tf.data.Dataset.from_tensor_slices(train_data)
train_dataset = train_dataset.shuffle(400).batch(1)

我们使用from_tensor_slices这个方法,实际上后面需要处理图片的时候是大有裨益的,比如我们可以使用map()里面带函数,进行图片的处理等等。tf.data.Dataset下面还有很多api,当然读取数据的方式也就有很多,这里只是一种可读性不错的方式。

你可能感兴趣的:(TensorFlow2.0图片数据集的读取)