In TensorFlow, tf.data.Dataset is the base class for building data input pipelines.
A Dataset can be created from an in-memory tensor as follows:
import tensorflow as tf
s = tf.constant(range(1, 11))               # 1-D tensor [1, 2, ..., 10]
s = tf.reshape(s, (5, 2))                   # reshape to 5 rows of 2
t = tf.data.Dataset.from_tensor_slices(s)   # one dataset element per row
print(list(t.as_numpy_iterator()))
Result:
[array([1, 2]), array([3, 4]), array([5, 6]), array([7, 8]), array([ 9, 10])]
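from_tensor_slices slices its input along the first dimension, which is why the (5, 2) tensor above becomes five vectors of shape (2,). The element structure can be checked via element_spec:
print(t.element_spec)  # should print TensorSpec(shape=(2,), dtype=tf.int32, name=None)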
For supervised learning, from_tensor_slices can also produce (feature, label) pairs:
import tensorflow as tf
feature = tf.constant(range(1, 11))
feature = tf.reshape(feature, (5, 2))  # five feature rows of two values each
label = tf.constant(range(1, 6))       # one label per feature row
t = tf.data.Dataset.from_tensor_slices((feature, label))
print(list(t.as_numpy_iterator()))
Result:
[(array([1, 2]), 1), (array([3, 4]), 2), (array([5, 6]), 3), (array([7, 8]), 4), (array([ 9, 10]), 5)]
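from_tensor_slices also accepts nested structures such as dicts, which is convenient for named features. A minimal sketch (the key "x" is an arbitrary name chosen for illustration):
import tensorflow as tf
feature = tf.reshape(tf.constant(range(1, 11)), (5, 2))
label = tf.constant(range(1, 6))
t = tf.data.Dataset.from_tensor_slices(({"x": feature}, label))
for x, y in t.take(2):  # each element is ({'x': <(2,) tensor>}, <scalar label>)
    print(x["x"].numpy(), y.numpy())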
To demonstrate reading data, we will use the iris dataset from tensorflow_datasets, which must be installed separately (pip install tensorflow-datasets).
import tensorflow_datasets as ds
import tensorflow as tf
data = ds.load('iris', split="train", as_supervised=True)  # as_supervised=True yields (feature, label) pairs
print(list(data.as_numpy_iterator()))
Result (truncated):
[(array([5.1, 3.4, 1.5, 0.2], dtype=float32), 0), (array([7.7, 3. , 6.1, 2.3], dtype=float32), 2), ...]
This approach reads the whole dataset in at once and returns it as a Python list of NumPy arrays. However, if only part of the dataset is read (for example via take(1)) before the iterator is dropped, TensorFlow emits a warning like this:
2021-01-03 06:56:58.863745: W tensorflow/core/kernels/data/cache_dataset_ops.cc:794] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Restructuring the pipeline as the warning suggests does not make it go away. My guess is that the warning occurs because TensorFlow does not know how large the data source is and therefore cannot determine the cache contents. One workaround is to read all of the data once before calling take:
import tensorflow_datasets as ds
import tensorflow as tf
data = ds.load('iris', split="train", as_supervised=True)
l = list(data.as_numpy_iterator())  # read everything once, fully populating the cache
t = data.take(1)                    # take() no longer leaves the cache partially read
for i in t:
    print(i)
Result:
(<tf.Tensor: shape=(4,), dtype=float32, numpy=array([5.1, 3.4, 1.5, 0.2], dtype=float32)>, <tf.Tensor: shape=(), dtype=int64, numpy=0>)
Dataset provides a variety of data transformation methods, such as map, shuffle, and batch.
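A minimal sketch chaining these three transformations (the doubling function is purely illustrative):
import tensorflow as tf
t = tf.data.Dataset.from_tensor_slices(tf.range(10))
t = t.map(lambda x: x * 2)     # apply a function to every element
t = t.shuffle(buffer_size=10)  # shuffle with a buffer covering the whole dataset
t = t.batch(4)                 # group elements into batches of 4 (the last batch has 2)
for batch in t:
    print(batch.numpy())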