tensorflow dataset使用

tensorflow 中的dataset是一个基类,可以用来处理数据的输入。

数据生成

可以用以下方法生成dataset:

  • from_tensor_slices()
    将输入的张量、元组、字典进行切片处理。切片发生在最外层的维度。比如将(5,2)形状的张量输入,得到的是5个(2)形状的张量:
import tensorflow as tf
s = tf.constant(range(1,11))
s = tf.reshape(s,(5,2))
t = tf.data.Dataset.from_tensor_slices(s)
print(list(t.as_numpy_iterator()))

结果:

[array([1, 2]), array([3, 4]), array([5, 6]), array([7, 8]), array([ 9, 10])]

对于有监督的学习,该函数还可以生成(特征,标签)形式的数据对:

import tensorflow as tf
feature = tf.constant(range(1,11))
feature = tf.reshape(s,(5,2))
label = tf.constant(range(1,6))

t = tf.data.Dataset.from_tensor_slices((feature,label))
print(list(t.as_numpy_iterator()))

结果:

[(array([1, 2]), 1), (array([3, 4]), 2), (array([5, 6]), 3), (array([7, 8]), 4), (array([ 9, 10]), 5)]
  • from_tensors()
    与from_tensor_slices()类似,区别是没有切分最外层。from_tensors()没有对张量进行任何处理,而from_tensor_slices()进行了切分。
  • from_generator()
    使用生成器的方法生成数据。不过一般很少使用。

数据读取

数据的读取会使用tensorflow_datasets中的鸢尾花数据。tensorflow_datasets需要单独安装。

  • list读入全部数据
import tensorflow_datasets as ds
import tensorflow as tf

data = ds.load('iris',split="train",as_supervised=True)
print(list(data.as_numpy_iterator()))

结果:(部分)

[(array([5.1, 3.4, 1.5, 0.2], dtype=float32), 0), (array([7.7, 3. , 6.1, 2.3], dtype=float32), 2),

此方法将数据一次性读入,并输出为Python原生数组。

  • take/skip
    take()是从数据源中,一次读入制定的数据。这种读取是按照顺序读取的。skip()是跳过指定数量的记录。skip()的用途是在生成测试数据集时,跳过训练数据集。比如有10条数据,使用take(7)生成了训练数据集,使用skip(7).take(3)生成测试数据集。
    使用take()和skip()时,在真正读入数据的时候,会产生以下告警信息:
2021-01-03 06:56:58.863745: W tensorflow/core/kernels/data/cache_dataset_ops.cc:794] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

按照提示的方法,并不能解决问题。猜测产生告警的原因是TensorFlow不知道要数据源有多大,无法确定cache内容。一个解决方法是在调用take前,一次性读入全部的数据:

import tensorflow_datasets as ds
import tensorflow as tf

data = ds.load('iris',split="train",as_supervised=True)
l = list(data.as_numpy_iterator())
t = data.take(1)

for i in t:
    print(i)

结果:

(<tf.Tensor: shape=(4,), dtype=float32, numpy=array([5.1, 3.4, 1.5, 0.2], dtype=float32)>, <tf.Tensor: shape=(), dtype=int64, numpy=0>)
  • prefetch()
    prefetch()的作用是提高数据读取的效率。在程序执行中,当 CPU 为计算准备数据时,计算设备处于闲置状态;当计算设备执行训练步骤时,CPU 处于闲置状态。因此,单个训练步骤 的时间等于 CPU 准备数据的时间 + 计算设备执行训练 step 的时间。
    并行的方法可以提高效率。做法是将训练步骤中的数据准备和模型执行 “并行”。当计算设备在执行第 N 个训练步骤时,CPU 为第 N+1 个训练步骤准备数据。通过这种方法,可以减少程序执行的时间。
    prefetch()提供了软数据管道机制。可以在处理当前数据时准备后面的数据。会提高延迟和吞吐量,代价是使用额外的内存来存储预取的元素。
    关于高性能数据输入管道的文章
    优化训练数据

数据转换

Dataset提供了多种数据转换方法:

  • map: 将转换函数映射到数据集每一个元素
  • batch : 构建批次,每次读入批次。比原始数据增加一个维度。
  • shuffle: 打乱数据顺序。可以避免输入重复数据到神经网络中,避免过拟合。
  • filter: 过滤掉某些元素
  • concatenate: 将两个Dataset纵向连接
  • zip: 将两个长度相同的Dataset横向铰合。

你可能感兴趣的:(Python,TensorFlow,python,tensorflow)