Tensorflow2.0——“tf.data”API

“tf.data”API

  • 前言
  • 代码示例
    • 加载数据集
    • 使用数据集
  • 数据获取
    • 其他接收其他集合类
      • from_tensors()
      • from_tensor_slices()
    • 自带方法创建Dataset对象
      • range()
      • from_generator()
    • 通过读取磁盘中的文件(文本、图片等等)来创建Dataset。
  • Dataset对象常用功能函数
    • take()
    • batch()
    • padded_batch()
    • map()
    • filter()
    • shuffle()
    • repeat()
  • tf.data.Dataset常用功能函数
    • zip()

前言

高效的数据输入管道可以很大程度地提升模型性能,减少模型训练所需要的时间。数据输入管道本质是一个ELT(Extract、Transform和Load)过程:

  • Extract:从硬盘中读取数据(可以是本地的,也可以是云端的)
  • Transform:数据的预处理(如数据清洗、格式转换等)
  • Load:将处理好的数据加载到计算设备(例如CPU、GPU及TPU等)

数据输入管道一般使用CPU来执行ELT过程,GPU等其他硬件加速设备则负责模型的训练,ELT过程和模型的训练并行执行,从而提高模型训练的效率。另外ELT过程的各个步骤也都可以进行相应的优化,例如并行地读取和处理数据等。

代码示例

这里使用的是一个花朵图片的数据集,如图2-13所示,除一个License文件外,主要是五个分别存放着对应类别花朵图片的文件夹,其中“daisy(雏菊)”文件夹中有633张图片,“dandelion(蒲公英)”文件夹中有898张图片,“roses(玫瑰)”文件夹中有641张图片,“sunflowers(向日葵)”文件夹中有699张图片,“tulips(郁金香)”文件夹中有799张图片。
Tensorflow2.0——“tf.data”API_第1张图片

加载数据集

  • 思路
  1. 准备要加载的图片路径,每个图片对应的类别,类标(将类别名映射成数字)
import tensorflow as tf
import pathlib
print(tf.__version__)   #2.0.0
# 获取当前路径
data_root= pathlib.Path.cwd()   # PosixPath('/home/bd/zhaoping/tensorflows')
# 获取指定目录下的的文件路径,返回一个生成器对象,利用list转为列表对象,列表每个元素是PosixPath对象
all_image_path = list(data_root.glob('flower*/*/*'))  #返回五个文件夹内所有图片的路径对象
#all_image_path[0]   PosixPath('/home/bd/zhaoping/tensorflows/flower_photos/sunflowers/14460081668_eda8795693_m.jpg')
# 将每个PosixPath对象转为str对象
all_image_path = [str(i) for i in all_image_path]
# all_image_path[0]  '/home/bd/zhaoping/tensorflows/flower_photos/sunflowers/14460081668_eda8795693_m.jpg'
# 获取图片的类别名字
label_name  = [i.name for i in data_root.glob('flower*/*/') if i.is_dir()] 
#label_name  ['sunflowers', 'dandelion', 'tulips', 'daisy', 'roses']
#将类标签转为数值型的字典
label_to_index = dict((j,i) for i,j in enumerate(label_name))
# label_to_index  {'daisy': 3, 'dandelion': 1, 'roses': 4, 'sunflowers': 0, 'tulips': 2}
#获取所有图片的所属类别
all_image_label = [label_to_index[pathlib.Path(i).parent.name] for i in all_image_path]
# all_image_label[:2]   [0, 0]

  1. 使用 tf.data.Dataset.from_tensor_slices() 函数进行加载,将图片路径转为tensor对象。
#构建图片路径的数据集
path_ds = tf.data.Dataset.from_tensor_slices(all_image_path)  #
  1. 使用 map() 函数进行图片预处理。
def load_and_process_image(path):
    """
    path:一张图片的路径
    功能:读取图片,将图片大小统一,并对像素值做归一化处理
    return:图片对象,即矩阵
    """
    #读取图片
    image = tf.io.read_file(path)  #返回tensorflow.python.framework.ops.EagerTensor对象,里面的值是二进制,需要解析
    #解析图片,即转成numpy矩阵(三维矩阵)
    image = tf.image.decode_jpeg(image,channels=3)  #返回还是EagerTensor对象,但其值转为图像对应的值  dtype=uint8
    #统一图片大小
    image = tf.image.resize(image,[192,192])# dtype=float32
    #对像素值做归一化处理
    image = image / 255.0    # dtype=float32
    return image
#使用AUTOTUNE自动调节管道参数
AUTOTUNE = tf.data.experimental.AUTOTUNE
#构建图片的数据集,使用map函数进行预处理
image_ds = path_ds.map(load_and_process_image,num_parallel_calls = AUTOTUNE)
  1. 建立类标数据集,并与图片数据集zip成对
#构建类标的数据集
label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(all_image_label,tf.int64)) #
#tf.cast 类型转换功能,且这里直接将列表变成了Tensor对象
#将图片数据集核类标数据集压缩成对
image_label_ds = tf.data.Dataset.zip((image_ds,label_ds)) #
  1. 数据集部分数据可视化
import matplotlib.pyplot  as plt
plt.figure(figsize=(8,8))
for i,label_image in enumerate(image_label_ds.take(4)):
    plt.subplot(2,2,i+1)
    plt.imshow(label_image[0])
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.title(str(label_image[1]))

Tensorflow2.0——“tf.data”API_第2张图片

使用数据集

接下来用创建的数据集训练一个分类模型,简单起见,直接使用“tf.keras.applications”包中训练好的模型,并将其迁移到我们的花朵分类任务上来。这里使用的是“MobileNetV2”模型。

  • 思路:
  1. 使用 shuffle() 打乱数据
#统计图片的数量
image_count = len(all_image_path)  #3670
#shuffle打乱数据集
ds = image_label_ds.shuffle(buffer_size=image_count)
#
  1. 根据需要 使用 repeat() 设置是否循环迭代数据集
#让数据集重复多次
ds = ds.repeat() #
  1. 使用 batch() 函数设置 batch size 值
#设置每批次的大小
BATCH_SIZE = 32
ds = ds.batch(BATCH_SIZE) #
#通过prefetch方法使得模型的训练和数据的加载并行
ds = ds.prefetch(buffer_size = AUTOTUNE)
#
  1. 模型导入
#将训练好的模型MobileNetV2迁移到花朵分类问题中
mobile_net = tf.keras.applications.MobileNetV2(input_shape = (192,192,3),include_top = False)
#禁止训练更新该模型的参数
mobile_net.trainable = False
  1. 根据模型调整数据集

由于“MobileNetV2”模型接收的输入数据是归一化在[-1,1]之间的数据,而之前数据预处理是将范围设为[0,1],所以需要将数据映射到[-1,1]。

def change_range(image,label):
    return image*2-1,label
kears_ds = ds.map(change_range) #
  1. 根据数据集调整模型

由于预训练好的“MobileNetV2”返回的数据维度为“(32,6,61280)”,其中“32”是一个批次(Batch)数据的大小,“6,6”代表输出的特征图的大小为6×6,“1280”代表该层使用了1280个卷积核。为了适应花朵分类任务,需要在“MobileNetV2”返回数据的基础上再增加两层网络层。全局平均池化(Global Average Pooling,GAP)是对每一个特征图求平均值,将该平均值作为该特征图池化后的结果,因此经过该操作后数据的维度变为(32,1280)。由于花朵分类任务是一个5分类的任务,因此需要再使用一个全连接(Dense),将维度变为(32,5)。

model = tf.keras.Sequential([
    mobile_net,
    tf.keras.layers.GlobalAveragePooling2D(), #全局平均池化层
    tf.keras.layers.Dense(len(label_name)) #全连接层
])
#编译一下模型,同时指定使用的优化器和损失函数
model.compile(optimizer = tf.keras.optimizers.Adam(),loss = 'sparse_categorical_crossentropy',metrics = ["accuracy"])
#可以输出模型各层的参数概况
model.summary()

Tensorflow2.0——“tf.data”API_第3张图片
7. 训练模型

#最后使用“model.fit”训练模型:
model.fit(ds,epochs=1,steps_per_epoch=10)

这里参数“epochs”指定需要训练的回合数,“steps_per_epoch”代表每个回合要取多少个批次数据,通常“steps_per_epoch”的大小等于我们数据集的大小除以批次的大小后上取整。

数据获取

其他接收其他集合类

通过接收其他类型的集合类对象创建Dataset对象( Dataset对象属于可迭代对象, 可通过循环进行遍历)。
这里所说的集合类型对象包含Python内置的list、tuple,numpy中的ndarray等等。这种创建Dataset对象的方法大多通过from_tensors()和from_tensor_slices()两个方法实现。

from_tensors()

from_tensors()方法接受一个集合类型对象作为参数,返回值为一个TensorDataset类型对象,对象内容、shape因传入参数类型而异。

  • 列表作为直接参数

当接收参数为list或Tensor对象时,返回的情况是一样的,因为TensorFlow内部会将list先转为Tensor对象,然后实例化一个Dataset对象。
注意:如果传入的是元组,一定要先转为array、list或tensor对象,不然效果是另一种情况

a = [0,1,2,3]  #列表
dataset = tf.data.Dataset.from_tensors(a)
dataset_array = tf.data.Dataset.from_tensors(np.array(a))
dataset_tensor =  tf.data.Dataset.from_tensors(tf.constant(a))
next(iter(dataset))
#
  • 元组作为参数

元组作为参数,Dataset内部内容为一个tuple,tuple的元素是原来tuple元素转换为的Tensor对象。

a = (0,1,2,3)  #元组
dataset = tf.data.Dataset.from_tensors(a)
next(iter(dataset))

(,
,
,
)

from_tensor_slices()

from_tensor_slices()方法返回一个TensorSliceDataset类对象,TensorSliceDataset对象比from_tensors()方法返回的TensorDataset对象支持更加丰富的操作,例如batch操作等,因此在实际应用中更加广泛。

  • 列表作为参数

当传入一个list时,时将list中元素逐个转换为Tensor对象然后依次放入Dataset中,所以Dataset中有多个Tensor对象。

a = [0,1,2,3]  #列表
dataset = tf.data.Dataset.from_tensor_slices(a)
dataset_array = tf.data.Dataset.from_tensor_slices(np.array(a))
dataset_tensor =  tf.data.Dataset.from_tensor_slices(tf.constant(a))
for n,i in enumerate(dataset):
    print(n ,'-->' ,i)

0 --> tf.Tensor(0, shape=(), dtype=int32)
1 --> tf.Tensor(1, shape=(), dtype=int32)
2 --> tf.Tensor(2, shape=(), dtype=int32)
3 --> tf.Tensor(3, shape=(), dtype=int32)

  • 元组作为参数

当传入参数为tuple时,会将tuple中各元素转换为Tensor对象,然后将第一维度对应位置的切片进行重新组合成一个tuple依次放入到Dataset中,所以在返回的Dataset中有多个tuple。

a = [0,1,2,3]  #列表
b = [4,5,6,7]
dataset = tf.data.Dataset.from_tensor_slices((a,b))
for n,i in enumerate(dataset):
    print(n ,'-->' ,i)

0 --> (, )
1 --> (, )
2 --> (, )
3 --> (, )

自带方法创建Dataset对象

range()

range()方法是Dataset内部定义的一个的静态方法,可以直接通过类名调用。另外,Dataset中的range()方法与Python本身内置的range()方法接受参数形式是一致的,可以接受range(begin)、range(begin, end)、range(begin, end, step)等多种方式传参。

dataset1 = tf.range(0,10,2)
#
type(dataset1)
#tensorflow.python.data.ops.dataset_ops.RangeDataset

range()方法创建的Dataset对象内部每一个元素都以Tensor对象的形式存在,可以通过numpy()方法访问真实值。

for i in dataset1:
    print(i.numpy())

0
2
4
6
8

from_generator()

如果你觉得range()方法不够灵活,功能不够强大,那么你可以尝试使用from_generator()方法。from_generator()方法接收一个可调用的生成器函数最为参数,在遍历from_generator()方法返回的Dataset对象过程中不断生成新的数据,减少内存占用,这在大数据集中很有用。

dataset1 = tf.range(0,10,2)
def count(stop):
    i = 0
    while i < stop:
        print("第{}次调用".format(i+1))
        yield i
        i += 1
dataset1 = tf.data.Dataset.from_generator(count,args = [3],output_types = tf.int8,output_shapes=())
a = iter(dataset1)
next(a)
#第1次调用
#
for i in dataset1:
    print(i)
    print(i.numpy())

第1次调用
tf.Tensor(0, shape=(), dtype=int8)
0
第2次调用
tf.Tensor(1, shape=(), dtype=int8)
1
第3次调用
tf.Tensor(2, shape=(), dtype=int8)
2

通过读取磁盘中的文件(文本、图片等等)来创建Dataset。

tf.data中提供了TextLineDataset、TFRecordDataset等对象来实现此功能。后续用到会专门讲一篇博客。

Dataset对象常用功能函数

take()

功能:随机采样,用于返回一个新的Dataset对象,新的Dataset对象包含的数据是原Dataset对象的子集。
参数:
1、count:整型,用于指定前count条数据用于创建新的Dataset对象,如果count为-1或大于原Dataset对象的size,则用原Dataset对象的全部数据创建新的对象。

batch()

功能:将Dataset中连续的数据分割成批。
参数:
1、batch_size:在单个批次中合并的此数据集的数量。
2、drop_remainder:如果最后一批的数据量少于指定的batch_size,是否抛弃最后一批,默认为False,表示不抛弃。

为什么在训练模型时要将Dataset分割成一个个batch呢?

  • 对于小数据集是否使用batch关系不大,但是对于大数据集如果不分割成batch意味着将这个数据集一次性输入模型中,容易造成内存爆炸。
  • 通过并行化提高内存的利用率。就是尽量让你的GPU满载运行,提高训练速度。
  • 单个epoch的迭代次数减少了,参数的调整也慢了,假如要达到相同的识别精度,需要更多的epoch。
  • 适当Batch Size使得梯度下降方向更加准确。

padded_batch()

功能: batch()的进阶版,可以对shape不一致的连续元素进行分批。
参数:
1、batch_size:在单个批次中合并的此数据集的数量。
2、drop_remainder:如果最后一批的数据量少于指定的batch_size,是否抛弃最后一批,默认为False,表示不抛弃。
3、padded_shapes:tf.TensorShape或其他描述tf.int64矢量张量对象,表示在批处理之前每个输入元素的各个组件应填充到的形状。如果参数中有None,则表示将填充为每个批次中该尺寸的最大尺寸。
4、padding_values:要用于各个组件的填充值。默认值0用于数字类型,字符串类型则默认为空字符。

map()

功能: 以dataset中每一位元素为参数执行pap_func()方法,这一功能在数据预处理中修改dataset中元素是很实用。
参数:
1、map_func:回调方法。

filter()

功能:对Dataset中每一个执行指定过滤方法进行过滤,返回过滤后的Dataset对象.。
参数:
1、predicate:过滤方法,返回值必须为True或False

dataset = tf.data.Dataset.range(10)
def filt(a):
    if a % 2 == 0:
        return True
    else:
        return False
data = dataset.filter(filt)  #0,2,4,6,8

shuffle()

功能:随机打乱数据
参数:
1、buffer_size:缓冲区大小,姑且认为是混乱程度吧,当值为1时,完全不打乱,当值为整个Dataset元素总数时,完全打乱。
2、seed:将用于创建分布的随机种子。
3、reshuffle_each_iteration:如果为true,则表示每次迭代数据集时都应进行伪随机重排,默认为True。

repeat()

功能:对Dataset中的数据进行重复,以创建新的Dataset。
参数:
1、count:重复次数,默认为None,表示不重复,当值为-1时,表示无限重复。

dataset = tf.data.Dataset.range(3)
data = dataset.repeat(2).shuffle(6)
for i in data:
    print(i.numpy())
# 0 2 2 1 0 1

tf.data.Dataset常用功能函数

zip()

a = tf.data.Dataset.from_tensors([1,2,3])
b = tf.data.Dataset.from_tensors(['a','b','c'])
c = tf.data.Dataset.zip((a,b))
#

你可能感兴趣的:(Tensorflow)