深度学习2.0-12.神经网络与全连接层之数据集的加载

文章目录

      • 数据集的加载
        • 1.tf.data.Dataset.from_tensor_slices
        • 2.keras.datasets
        • 3.tf.data.Dataset.from_tensor_slices
          • 1.shuffle-打散-注意:x和y的相对顺序不能打散
          • 2.map-可用于数据预处理
          • 3.batch-读取batch个(x,y)
          • 4.repeat-重复取

数据集的加载

深度学习2.0-12.神经网络与全连接层之数据集的加载_第1张图片

1.tf.data.Dataset.from_tensor_slices

它的作用是切分传入Tensor的第一个维度,生成相应的dataset。

将输入的张量的第一个维度看做样本的个数,沿其第一个维度将tensor切片,得到的每个切片是一个样本数据。实现了输入张量的自动切片。

可以是numpy格式,也可以是tensorflow的tensor的格式,函数会自动将numpy格式转为tensorflow
的tensor格式

输入可以是一个tensor 或 一个tensor字典(字典的每个key对应的value是一个tensor,要求各tensor的
第一个维度相等) 或 一个tensor tupletuple 的每个元素是一个tensor,要求各tensor的第一个维度
相等)
# from_tensor_slices 为输入张量的每一行创建一个带有单独元素的数据集
ts = tf.constant([[1, 2], [3, 4]])
ds = tf.data.Dataset.from_tensor_slices(ts)   # [1, 2], [3, 4]

2.keras.datasets

深度学习2.0-12.神经网络与全连接层之数据集的加载_第2张图片
mnist数据加载
深度学习2.0-12.神经网络与全连接层之数据集的加载_第3张图片
CIFAR10/100数据集的加载
深度学习2.0-12.神经网络与全连接层之数据集的加载_第4张图片

3.tf.data.Dataset.from_tensor_slices

深度学习2.0-12.神经网络与全连接层之数据集的加载_第5张图片

1.shuffle-打散-注意:x和y的相对顺序不能打散

可以利用idx来记录打散顺序,以确保x和y的相对顺序
深度学习2.0-12.神经网络与全连接层之数据集的加载_第6张图片

2.map-可用于数据预处理

深度学习2.0-12.神经网络与全连接层之数据集的加载_第7张图片

3.batch-读取batch个(x,y)

深度学习2.0-12.神经网络与全连接层之数据集的加载_第8张图片

4.repeat-重复取

深度学习2.0-12.神经网络与全连接层之数据集的加载_第9张图片

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets

# 数据预处理
def prepare_mnist_features_and_labels(x,y):
    x = tf.cast(x,dtype=tf.float32) / 255.0
    y = tf.cast(y,dtype=tf.int64)
    y = tf.one_hot(y,depth=10)
    return x,y

# 数据集
# 数据集加载-->dataset-->ont_hot等数据预处理-->shuffle-->batchm
def mnist_dataset():
    (x,y),(x_val,y_val) = datasets.fashion_mnist.load_data()

    ds = tf.data.Dataset.from_tensor_slices((x,y))
    ds = ds.map(prepare_mnist_features_and_labels)
    ds = ds.shuffle(60000).batch(100)

    ds_val = tf.data.Dataset.from_tensor_slices((x_val,y_val))
    ds_val = ds_val.map(prepare_mnist_features_and_labels)
    ds_val = ds_val.shuffle(10000).batch(100)

    return ds,ds_val

if __name__ == '__main__':
    ds,ds_val = mnist_dataset()

你可能感兴趣的:(深度学习2.0基础,tensorflow,深度学习,神经网络)