对于一些小型常用的数据集,TensorFlow有相关的API可以调用——keras.datasets
经典数据集:
1、boston housing - 波士顿房价
2、mnist/fasion mnist - 手写数字集/时髦品集
3、cifar10/100 - 物象分类
4、imdb - 电影评价
使用 tf.data.Dataset 的好处:
1.既能让后面有迭代的方式,又能直接对数据(tensor类型)进行预处理,还能支持batch和多线程的方式处理
2.提供了 .shuffle(打散), .map(预处理) 功能
(x, y), (x_val, y_val) = datasets.fashion_mnist.load_data()
用于打乱数据集但不影响映射关系
from_tensor_slices的方法见这篇博客
db = tf.data.Dataset.from_tensor_slices( (x_test, y_test) )
db = db.shuffle(10000) # x_test,y_test映射关系不变
用于使用预处理映射
tf.cast的用法见这篇博客
tf.one_hot的用法见这篇博客
def preprocess(x,y):
# 定义一个与处理函数 用于将numpy数据类型转化为Tensor的类型(dtype=float32)
x = tf.cast(x, dtype=tf.float32) / 255 # 将灰度级归一化
y = tf.cast(y, dtype=tf.float32)
y = tf.one_hot(y, depth=10) # 对数字编码 y 进行one_hot编码,10个0-1序列中只有一个1
return x, y
db2 = db.map(preprocess)
res = next(iter(db2)) # iter(db2):取得db2的迭代器,next(iter(db2)):迭代
批处理
db3 = db2.batch(32) # (32张图片,32个label)为一个batch
res = next(iter(db3)) # 进行迭代
res[0].shape, res[1].shape # 分别是一个batch中图片格式与label格式的shape
(TensorShape([32,32,32,3]), TensorShape([32,1,10]))
# 图片格式是(32张,32*32大小,3个通道) # (32张图片对应的label,1个label——通常会squeeze掉,10个one_hot深度)
整个数据集的循环次数
db4 = db3.repeat() # 这样就是一直repeat迭代,死循环
db4 = db3.repeat(2) # 这个是迭代2次
def prepare_mnist_features_and_labels(x,y):
x = tf.cast(x, tf.float32) / 255.0
y = tf.cast(y, tf.float64)
return x,y
def mnist_dataset():
(x, y),(x_val, y_val) = datasets.fashion_mnist.load_data() # 1.加载图像数据和通用数据(val指的是validation,测试数据集)
y = tf.one_hot(y, depth=10) # 2.数据 one_hot编码
y_val = tf.one_hot(y_val, depth=10) # label one_hot编码
ds = tf.data.Dataset.from_tensor_slices((x, y)) # 3.转换为Dataset类型
ds = ds.map(prepare_mnist_features_and_labels) # 4.预处理函数映射
ds = ds.shuffle(60000).batch(100) # 5.其他处理——如本处的前60000个打乱,100个为一个批次
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.map(prepare_mnist_features_and_labels)
ds_val = ds_val.shuffle(10000).batch(100)
return ds, ds_val