tensorflow2的数据加载

对于一些小型常用的数据集,TensorFlow有相关的API可以调用——keras.datasets

经典数据集:

1、boston housing - 波士顿房价

2、mnist/fasion mnist - 手写数字集/时髦品集

3、cifar10/100 - 物象分类

4、imdb - 电影评价

使用 tf.data.Dataset 的好处:

1.既能让后面有迭代的方式,又能直接对数据(tensor类型)进行预处理,还能支持batch和多线程的方式处理

2.提供了 .shuffle(打散), .map(预处理) 功能

导入fashion_mnist数据

(x, y), (x_val, y_val) = datasets.fashion_mnist.load_data()

.shuffle

用于打乱数据集但不影响映射关系

from_tensor_slices的方法见这篇博客

db = tf.data.Dataset.from_tensor_slices( (x_test, y_test) )
db = db.shuffle(10000)    # x_test,y_test映射关系不变

.map

用于使用预处理映射

tf.cast的用法见这篇博客

tf.one_hot的用法见这篇博客

def preprocess(x,y):
	# 定义一个与处理函数 用于将numpy数据类型转化为Tensor的类型(dtype=float32)
	x = tf.cast(x, dtype=tf.float32) / 255    # 将灰度级归一化
	y = tf.cast(y, dtype=tf.float32)
	y = tf.one_hot(y, depth=10)       		  # 对数字编码 y 进行one_hot编码,10个0-1序列中只有一个1
	return x, y

db2 = db.map(preprocess)

res = next(iter(db2))   # iter(db2):取得db2的迭代器,next(iter(db2)):迭代

.batch

批处理

db3 = db2.batch(32)   # (32张图片,32个label)为一个batch

res = next(iter(db3))  # 进行迭代

res[0].shape, res[1].shape   # 分别是一个batch中图片格式与label格式的shape
(TensorShape([32,32,32,3]), TensorShape([32,1,10]))
# 图片格式是(32张,32*32大小,3个通道)  # (32张图片对应的label,1个label——通常会squeeze掉,10个one_hot深度)

.repeat

整个数据集的循环次数

db4 = db3.repeat()   # 这样就是一直repeat迭代,死循环
db4 = db3.repeat(2)  # 这个是迭代2次

总结:整体的一个示例:

def prepare_mnist_features_and_labels(x,y):
    x = tf.cast(x, tf.float32) / 255.0
    y = tf.cast(y, tf.float64)
    return x,y

def mnist_dataset():
    (x, y),(x_val, y_val) = datasets.fashion_mnist.load_data()  # 1.加载图像数据和通用数据(val指的是validation,测试数据集)
    y = tf.one_hot(y, depth=10)                                 # 2.数据  one_hot编码
    y_val = tf.one_hot(y_val, depth=10)                         #   label one_hot编码

    ds = tf.data.Dataset.from_tensor_slices((x, y))             # 3.转换为Dataset类型
    ds = ds.map(prepare_mnist_features_and_labels)              # 4.预处理函数映射
    ds = ds.shuffle(60000).batch(100)                           # 5.其他处理——如本处的前60000个打乱,100个为一个批次
    ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
    ds_val = ds_val.map(prepare_mnist_features_and_labels)
    ds_val = ds_val.shuffle(10000).batch(100)
    return ds, ds_val

你可能感兴趣的:(TensorFlow2.×)