目录
一、DataSet的创建:
二、DataSet的常用函数:
三、使用DataSet改写fashion_MNIST分类模型:
类似于numpy中的ndarray数据类型和数据操作,TensorFlow提供了tf.data.DataSet模块,方便地处理数据输入、输出,支持大量的数据计算和转换,tf.data.DataSet中是一个或者多个tensor对象。
直接从tensor创建tf.data.DataSet,使用tf.data.DataSet.from_tensor_slices()函数,函数参数可以是python自带数据类型list,或者numpy.ndarray:
# 可以从list,从numpy.ndarray创建 dataset
X= np.array([1.3,4.4,5.5,6.71,6.93,4.168,9.779,6.182,7.59,2.167,7.042,10.791,5.313,7.997,5.654,9.27,3.1])
Y= np.array([[1.3,4.4],[5.5,6.71]])
dataset1=tf.data.Dataset.from_tensor_slices([1,2,3,4]) # list 创建
dataset2=tf.data.Dataset.from_tensor_slices(X) # numpy 创建
dataset2
dataset的类型是tensorslicedataset,可以使用循环查看每个元素都是一个tensor,也可以用numpy方法;
for i in dataset2.take(2):
print(i)
print(i.numpy())
1、在建模之前可以对数据进行处理,比如:
① shuffle()函数,提供乱序操作;
② repeat()函数,提供数据重复操作;
③ batch() 函数,提供批量读取功能;
X= np.array([1.3,4.4,5.5,6.71,6.93,4.168,9.779,6.182,7.59,2.167,7.042,10.791,5.313,7.997,5.654,9.27,3.1])
dataset2=tf.data.Dataset.from_tensor_slices(X) # numpy 创建
data_shuffle=dataset2.shuffle(3) # 打乱数据
data_repeat=dataset2.repeat(count=2) # 数据重复
data_batch=dataset2.batch(2) # 数据批量读取
2、数据变换,包括map函数
dataset_sq=dataset2.map(tf.square)
与之前处理方式不同在于,建模之前对数据进行了一些变换,并且增加了模型训练过程中的验证数据;
import matplotlib.pyplot as plt
import tensorflow as tf
import pandas as pd
import numpy as np
%matplotlib inline
(train_image,train_lable),(test_image,test_lable)=tf.keras.datasets.fashion_mnist.load_data()
plt.imshow(train_image[11]) # image show
ds_train_image=tf.data.Dataset.from_tensor_slices(train_image) # 加载数据
ds_train_lable=tf.data.Dataset.from_tensor_slices(train_lable) # 加载数据
# 打乱数据,无线重复,成批读取
da_train=tf.data.Dataset.zip((ds_train_image,ds_train_lable)).shuffle(10000).repeat().batch(64)
# 测试数据集
ds_test_image=tf.data.Dataset.from_tensor_slices(test_image) # 加载数据
ds_test_lable=tf.data.Dataset.from_tensor_slices(test_lable) # 加载数据
ds_test=tf.data.Dataset.zip((ds_test_image,ds_test_lable)).batch(64)
model=tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=(28,28)),
tf.keras.layers.Dense(128,activation="relu"),
tf.keras.layers.Dense(10,activation="softmax")
])
model.compile(optimizer="adam",loss="sparse_categorical_crossentropy",metrics=["acc"])
train_image.shape[0]//64
history=model.fit(da_train,
epochs=5,
steps_per_epoch=train_image.shape[0]//64,
validation_data=ds_test,
validation_steps=test_image.shape[0]//64
)
model.evaluate(test_image,test_lable)
plt.plot(history.epoch,history.history.get('loss'))
plt.plot(history.epoch,history.history.get('acc'))