tensorflow2.0数据加载

前言

这个是我学习网易云课堂上对于tensorflow2.0数据加载内容的笔记

小型的经典数据通过keras.datasets就可以使用

keras.datasets可以加载的数据集

  • boston houseing
  • mnist/fashion mnist
  • cifar10/100
  • imdb

tf.data.Dataset.from_tensor_slices

切分传入的 Tensor 的第一个维度,生成相应的 dataset。可以用迭代器来取出dataset

a = tf.random.normal([28,28,28])
aa = tf.data.Dataset.from_tensor_slices(a)
print(next(iter(aa)).shape)

b = tf.random.normal([25,27,28,28])
bb = tf.data.Dataset.from_tensor_slices(b)
print(next(iter(bb)).shape)

c = tf.random.normal([28,2])
cc = tf.data.Dataset.from_tensor_slices(c)
print(next(iter(cc)).shape)

##########输出#################
(28, 28)
(27, 28, 28)
(2,)

如果是tf.data.Dataset.from_tensor_slices((x,y)),那么返回的会是元组,这样可以做到图片与label相对应(假设x是传入的图像,y是label),可以用next(iter(db))[0].shape来查看形状

db = tf.data.Dataset.from_tensor_slices((x,y))
print(next(iter(db))[0].shape)
print(next(iter(db))[1].shape)
###########输出#########
(28, 28)
()
shuffle

原理可以看这个嘞大佬的博客
打乱数据顺序,可以防止过拟合这样子,在训练数据时用。参数就给一个比较大的就好了。
buffer_size = 1 数据集不会被打乱

buffer_size = 数据集样本数量,随机打乱整个数据集

buffer_size > 数据 集样本数量,随机打乱整个数据集

db = db.shuffle(buffer_size )
map

对数据进行预处理

def preproess(x,y):
#tf.cast()函数的作用是执行 tensorflow 中张量数据类型转换,比如读入的图片如果是int8类型的,一般在要在训练前把图像的数据格式转换为float32。
    x = tf.cast(x,dtypt=float32) / 255.
    y = tf.cast(y,dtype=int32)
    y = tf.one_hot(y,depth=10)
    return x,y
db2 = db.map(preproess)

batch

#用于迭代器每次取出多少图
db3 = db2.batch(32)
res = next(iter(db3))

repeat

就是重复操作,防止用whie True时出错

#一直迭代
db4 = db3.repeat()
#迭代全部的数据4次退出
db4 = db3.repeat(4)

######暂时就这些了#######19.7.18

你可能感兴趣的:(网易云课堂,深度学习)