tensorflow2数据读取P2: tf.data.Dataset.from_generator通过自定义的python生成器构造Dataset

上一篇文章:tensorflow2数据读取P1: numpy array中,提到model.fit支持输入5类数据。分别是numpy array, tensor, 字典(值为numpy array或tensor), tf.data创建的Dataset, 生成器类型。

其中tensor, 字典与numpy array区别不是很大。在必须使用这2中数据类型时,做一些类型转换即可。这篇文章想要说说tf.data创建的Dataset类型作为输入。学会这种方法,可以处理数据量很大的任务。因为,一些方法创建的Dataset类型,类似一个生成器,可以一点一点地读取数据。而不是一次性将所有的数据都读取到内存中。在数据量很大的任务中,一次性将所有数据都读入内存是不太可能的。

这里主要讲解以下3中创建Dataset的方法

  • tf.data.Dataset.from_generator通过自定义的python生成器构造Dataset
  • tf.data.Dataset.from_generator通过preprocessing.image.ImageDataGenerator构造Dataset, 此方法比较适合图像任务
  • tf.data.TFRecordDataset创建Dataset

tf.data.Dataset.from_generator通过自定义的python生成器构造Dataset

python生成器可以简单理解成,一种用yield关键字返回值的一种函数。生成器有点像挤牙膏,你调用它一次,它的任务就往前推进一点。这就很适合处理大量数据时,每次只读取需要的那部分数据的情景。
假设有一批图片数据,我们可以自定义一个生成器,yield图片路径,然后再读取图片

import os
import tensorflow as tf
import glob
# 自定义一个读取路径的生成器
def filepath_generator(inroot,filetype):
    inroot = inroot.decode('utf-8') # 这里用decode是因为这个生成器的参数传入tf.data.Dataset.from_generator之后,字符串会变成bytes类型
    filetype = filetype.decode('utf-8')
    inroot = os.path.join(inroot,'*.'+filetype)
    filelist = glob.glob(inroot)
    for filepath in range(len(filelist)):
        yield filelist[filepath]

然后使用tf.data.Dataset.from_generator构建dataset

inroot = r'imgs'
filetype = 'jpeg'           
dataset = tf.data.Dataset.from_generator(filepath_generator,(tf.string),args=[inroot,filetype])   

在这里插入图片描述
tf.data.Dataset.from_generator函数接收4个参数,分别是generator(生成器),output_types(生成器输出的数据类型),output_shapes(生成器输出的数据形状,可选参数),args(调用生成器需要的参数)
注意
送入的generator(生成器)必须是可调用的,例如这个例子里面,传入filepath_generator函数名,这个就是可调用的。但是,如果生成一个mygenerator = filepath_generator(inroot,filetype), 再将mygenerator传入tf.data.Dataset.from_generator(mygenerator, (tf.string)), 就会报generator必须是可调用的错误。

为了使得传入的generator是可调用的,将函数改成这种方式,看起来很别扭。另一种方式是使用lambda函数,示例如下。注意这段代码lambda: filepath_generator(inroot,filetype)

import os
import tensorflow as tf
import glob

def filepath_generator(inroot,filetype):
    inroot = os.path.join(inroot,'*.'+filetype)
    filelist = glob.glob(inroot)
    for filepath in range(len(filelist)):
        yield filelist[filepath]

inroot = r'imgs'
filetype = 'jpeg'           

dataset = tf.data.Dataset.from_generator(
    lambda: filepath_generator(inroot,filetype), (tf.string)
)   

for one_batch in dataset.batch(3).repeat(5).shuffle(buffer_size=100):
    imgs = []
    for i in one_batch.numpy():
        print(i.decode('utf-8'))
        img = cv2.imread(i.decode('utf-8'))
        imgs.append(img)

使用的代码在https://gitee.com/xxjdxmt/learning

你可能感兴趣的:(tensorflow2使用,python,tensorflow,机器学习)