我们可以从以下七种数据源构建数据管道:

- Numpy array
- Pandas DataFrame
- Python generator
- csv 文件
- 文本文件
- 文件路径
- tfrecords 文件

由于从 tfrecords 文件构建数据管道比较复杂,下面只介绍前六种情况。介绍时也会指出每种情况所使用的 tf.data API,共涉及五种 API,其中 Numpy array 和 Pandas DataFrame 使用的是同一种 API(tf.data.Dataset.from_tensor_slices)。
1. 从 Numpy array 构建数据管道 —— 使用 tf.data.Dataset.from_tensor_slices
# Build a data pipeline from NumPy arrays with tf.data.Dataset.from_tensor_slices.
import tensorflow as tf
import numpy as np
from sklearn import datasets

iris = datasets.load_iris()
# from_tensor_slices slices the (features, target) tuple along the first
# axis, so every dataset element is one (feature row, label) pair. The
# printed output shows float64 features and int64 labels, i.e. the NumPy
# dtypes are carried over unchanged.
ds1 = tf.data.Dataset.from_tensor_slices((iris["data"], iris["target"]))
for features, label in ds1.take(5):
    print(features, label)
tf.Tensor([5.1 3.5 1.4 0.2], shape=(4,), dtype=float64) tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor([4.9 3. 1.4 0.2], shape=(4,), dtype=float64) tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor([4.7 3.2 1.3 0.2], shape=(4,), dtype=float64) tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor([4.6 3.1 1.5 0.2], shape=(4,), dtype=float64) tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor([5. 3.6 1.4 0.2], shape=(4,), dtype=float64) tf.Tensor(0, shape=(), dtype=int64)
2. 从 Pandas DataFrame 构建数据管道 —— 同样使用 tf.data.Dataset.from_tensor_slices
# Build a data pipeline from a pandas DataFrame with tf.data.Dataset.from_tensor_slices.
import tensorflow as tf
from sklearn import datasets
import pandas as pd

iris = datasets.load_iris()
dfiris = pd.DataFrame(iris["data"], columns=iris.feature_names)
# dfiris.to_dict("list") converts the DataFrame into a
# {column name: list of values} dict, so each element of the resulting
# dataset carries its features as a dict keyed by column name.
# (The printed output shows these lists of Python floats become float32.)
ds2 = tf.data.Dataset.from_tensor_slices((dfiris.to_dict("list"), iris["target"]))
for features, label in ds2.take(1):
    print(features, label)
{'sepal length (cm)': <tf.Tensor: shape=(), dtype=float32, numpy=5.1>,
'sepal width (cm)': <tf.Tensor: shape=(), dtype=float32, numpy=3.5>,
'petal length (cm)': <tf.Tensor: shape=(), dtype=float32, numpy=1.4>,
'petal width (cm)': <tf.Tensor: shape=(), dtype=float32, numpy=0.2>
}
tf.Tensor(0, shape=(), dtype=int64)
3. 从 Python generator 构建数据管道 —— 使用 tf.data.Dataset.from_generator
# Build a data pipeline from a Python generator.
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# A Keras iterator that reads images from "./data/cifar2/test/", resizes them
# to 32x32, multiplies pixel values by 1/255 and yields batches of 20
# (images, labels). The chained call is wrapped in parentheses because a
# line starting with ".flow_from_directory(" is a SyntaxError in Python.
image_generator = (
    ImageDataGenerator(rescale=1.0 / 255)
    .flow_from_directory(
        "./data/cifar2/test/",
        target_size=(32, 32),
        batch_size=20,
        class_mode='binary',
    )
)
# Mapping from class sub-directory name to integer label.
classdict = image_generator.class_indices
print(classdict)


def generator():
    """Yield ``(features, label)`` batches taken from ``image_generator``.

    NOTE(review): Keras' DirectoryIterator keeps cycling through the
    directory, so this generator is expected never to stop on its own;
    bound the resulting dataset (e.g. with ``.take(n)``) when consuming
    it. Confirm for the Keras version in use.
    """
    for features, label in image_generator:
        yield (features, label)


# output_types declares the dtypes of the yielded (images, labels) tuple;
# presumably the 'binary' labels arrive as floats and are converted to the
# declared tf.int32 by from_generator — verify if label dtype matters.
ds3 = tf.data.Dataset.from_generator(generator, output_types=(tf.float32, tf.int32))
Found 2000 images belonging to 2 classes.
{'airplane': 0, 'automobile': 1}
# The next two lines are IPython/Jupyter magics, not Python syntax; they are
# kept as comments so this snippet parses as Python. Uncomment them when
# running the code inside a notebook.
# %matplotlib inline
# %config InlineBackend.figure_format = 'svg'
plt.figure(figsize=(6, 6))
# unbatch() turns the batched (images, labels) elements back into single
# (image, label) pairs; show the first 9 in a 3x3 grid.
for i, (img, label) in enumerate(ds3.unbatch().take(9)):
    ax = plt.subplot(3, 3, i + 1)
    ax.imshow(img.numpy())
    ax.set_title("label = %d" % label)
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()
4. 从 csv 文件构建数据管道 —— 使用 tf.data.experimental.make_csv_dataset
# Build a data pipeline from csv files.
# Each element is a (features, label) pair: features is an ordered dict of
# column name -> batch tensor, and label is the "Survived" column.
# NOTE: make_csv_dataset shuffles rows by default (shuffle=True), which is
# why the printed PassengerIds are not in file order; ignore_errors=True
# silently drops rows that fail to parse.
ds4 = tf.data.experimental.make_csv_dataset(
    file_pattern=["./data/titanic/train.csv",
                  "./data/titanic/test.csv"],
    batch_size=3,
    label_name="Survived",
    na_value="",
    num_epochs=1,
    ignore_errors=True)
for data, label in ds4.take(1):
    print(data, label)
OrderedDict([
('PassengerId', <tf.Tensor: shape=(3,), dtype=int32,
numpy=array([136, 48, 805], dtype=int32)>),
('Pclass', <tf.Tensor: shape=(3,), dtype=int32,
numpy=array([2, 3, 3], dtype=int32)>),
('Name', <tf.Tensor: shape=(3,), dtype=string,
numpy=array([b'Richard, Mr. Emile',
b"O'Driscoll, Miss. Bridget",
b'Hedman, Mr. Oskar Arvid'], dtype=object)>),
('Sex', <tf.Tensor: shape=(3,), dtype=string,
numpy=array([b'male', b'female', b'male'], dtype=object)>),
('Age', <tf.Tensor: shape=(3,), dtype=float32,
numpy=array([23., 0., 27.], dtype=float32)>),
('SibSp', <tf.Tensor: shape=(3,), dtype=int32,
numpy=array([0, 0, 0], dtype=int32)>),
('Parch', <tf.Tensor: shape=(3,), dtype=int32,
numpy=array([0, 0, 0], dtype=int32)>),
('Ticket', <tf.Tensor: shape=(3,), dtype=string,
numpy=array([b'SC/PARIS 2133', b'14311', b'347089'], dtype=object)>),
('Fare', <tf.Tensor: shape=(3,), dtype=float32,
numpy=array([15.0458, 7.75 , 6.975 ], dtype=float32)>),
('Cabin', <tf.Tensor: shape=(3,), dtype=string,
numpy=array([b'', b'', b''], dtype=object)>),
('Embarked', <tf.Tensor: shape=(3,), dtype=string,
numpy=array([b'C', b'Q', b'S'], dtype=object)>)
])
tf.Tensor([0 1 1], shape=(3,), dtype=int32)
5. 从文本文件构建数据管道 —— 使用 tf.data.TextLineDataset
# Build a data pipeline from text files: one scalar string element per line.
# The header row is skipped per file. Calling .skip(1) on a single
# TextLineDataset over both files would drop only train.csv's header, and
# test.csv's header line would then show up as an ordinary data row.
ds5 = tf.data.Dataset.from_tensor_slices(
    ["./data/titanic/train.csv",
     "./data/titanic/test.csv"]
).flat_map(lambda filename: tf.data.TextLineDataset(filename).skip(1))
for line in ds5.take(1):
    print(line)
tf.Tensor(b'493,0,1,"Molson, Mr. Harry Markland",
male,55.0,0,0,113787,30.5,C30,S', shape=(), dtype=string)
6. 从文件路径构建数据管道 —— 使用 tf.data.Dataset.list_files
# Build a dataset of file paths (scalar strings) matching a glob pattern.
# NOTE: list_files shuffles the matched paths by default, which is why the
# printed paths are not sorted; pass shuffle=False for a deterministic order.
ds6 = tf.data.Dataset.list_files("./data/cifar2/train/*/*.jpg")
for file in ds6.take(2):
    print(file)
tf.Tensor(b'./data/cifar2/train/airplane/4266.jpg',
shape=(), dtype=string)
tf.Tensor(b'./data/cifar2/train/airplane/4131.jpg',
shape=(), dtype=string)
参考链接:https://github.com/lyhue1991/eat_tensorflow2_in_30_days