tf.data.Dataset.from_tensors()
tf.data.Dataset.from_tensor_slices()
需要的是列表或者其他作为输入tf.data.TFRecordDataset()
tf.data.Dataset.from_generator
用于将Python的生成器转变为Dataset#定义一个count生成器
def count(stop):
i = 0
while i<stop:
yield i
i += 1
for n in count(5):
print(n)
>>>0
>>>1
>>>2
>>>3
>>>4
ds_counter = tf.data.Dataset.from_generator(count, args=[25],
output_types=tf.int32, output_shapes = (), )
#这样就把生成器转为了数据集
for count_batch in ds_counter.repeat().batch(10).take(10):
print(count_batch.numpy())
3、利用 tf.image.ImageDataGenerator
进行数据增强 :
首先创建一个ImageDataGenerator
img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)
然后使用从路径读取图片不断产生数据,形成的生成器用法与ImageDataGenerator类的方法一致。比如使用flow_from_directory()
images, labels = next(img_gen.flow_from_directory(flowers))
#这里的flow_from_directory就和from keras.preprocessing.image import ImageDataGenerator中的用法一致了,直接就是利用不同的文件生成标签等
>>>Found 3670 images belonging to 5 classes.
最后调用from_genrtor()形成Dataset:
ds = tf.data.Dataset.from_generator(
lambda: img_gen.flow_from_directory(flowers),
output_types=(tf.float32, tf.float32),
output_shapes=([32,256,256,3], [32,5])
)
ds.element_spec
>>>(TensorSpec(shape=(32, 256, 256, 3), dtype=tf.float32, name=None),
TensorSpec(shape=(32, 5), dtype=tf.float32, name=None))
4、使用Dataset.map(f)进行数据预处理
重点理解:Dataset.map(f)
这个函数通过对输入数据集的每个元素应用给定的函数来生成一个新的数据集。它是基于python中的map()函数的,但一定要注意的是与python中的有很大不同。
一定注意的是这个函数的输入要求是tf.Tensor
,返回的输出也是张量对象tf.Tensor
。这里很容易出错,经常直接利用array来作为输入。
最重要的是!!!: 这个函数他是对整个Dataset进行映射,相当于无论原始的Dataset有多少个张量在其中都是直接一对一直接应用函数转换为另一个Dataset。所以映射函数f
一般都是TensorFlow的操作,就是一般都是tf.XX
这样的函数。
它的实现是使用标准的TensorFlow操作将一个元素转换为另一个元素。
例子:
#利用文件名称列表创建一个Dataset
list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))
#定义一个预处理函数,从文件中读取图像,将其解码为稠密张量,并调整其大小
def parse_image(filename):
parts = tf.strings.split(filename, os.sep)
label = parts[-2]
image = tf.io.read_file(filename)
image = tf.io.decode_jpeg(image)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, [128, 128])
return image, label
#使用map进行处理
images_ds = list_ds.map(parse_image)
for image, label in images_ds.take(2):
show(image, label)
参考:tf.data: Build TensorFlow input pipelines