TensorFlow2学习——tf.data模块

tf.data API的使用

tf.data.Dataset:表示一系列元素,其中每个元素包含一个或多个 Tensor 对象。例如,在图片管道中,一个元素可能是单个训练样本,具有一对表示图片数据和标签的张量。可以通过两种不同的方式来创建数据集。
直接从 Tensor 创建 Dataset(例如 Dataset.from_tensor_slices());当然 Numpy 也是可以的,TensorFlow 会自动将其转换为 Tensor。
通过对一个或多个 tf.data.Dataset 对象来使用变换(例如 Dataset.batch())来创建 Dataset
已知文件名称和标签,用data保存每一个文件的地址,用label保存每一文件对应的标签。data和label都是列表,形式如 data = [‘xxxx.jpg’,‘qqqq.jpg’,…]; label = [0,2,3,4,1,…]

import tensorflow as tf
import os
file_path = r'E:\dataset\DAVIS\JPEGImages\480p\bear'
data= [os.path.join(file_path,i) for i in os.listdir(file_path)]
label = [0]*len(data)
print(data)
print(len(label))

for i in os.listdir(file_path) 代表的是迭代出这个路径下的所有文件
os.path.join(file_path,i)代表的是这两个路径做拼接
以上代码实现了图片对应的label(这些label全是0)

tf.data.Dataset.from_tensor_slices

dataset = tf.data.Dataset.from_tensor_slices((data,label))
print(datset)
#  <DatasetV1Adapter shapes: ((), ()), types: (tf.string, tf.int32)>

在这里将数据和标签相对应起来,构建了一个dataset
有几个特定的函数需要注意:

batch():用一个整型数字作为参数,描述了一个batch的batch size。图片太多可能一次放不下进行训练,分batchsize个批次进行训练。
repeat():参数同样是一个整型数字,描述了整个dataset需要重复几次(epoch),如果没有参数,则重复无限次,设置参数counts=3等等。
shuffle():顾名思义,数据的乱序
map():常常用作预处理,图像解码等操作,参数是一个函数句柄,dataset的每一个元素都会经过这个函数的到新的tensor代替原来的元素。

具体关于tensorflow的数据读取机制可以参考知乎这篇文章:
链接: 十图详解tensorflow数据读取机制(附代码).

你可能感兴趣的:(深度学习)