上面已经说过了怎么使用tf.data处理简单的数据输入,有了上面的基础之后,这部分使用tf.data来创建更加复杂更加贴近于现实的数据输入. 这里主要使用tfrecords来创建输入流。之后训练模型非常方便,要是想通过其他的方式进行输入操作的,可以参考官方文档。
这一节可以看做是TensorFlow学习(十五):使用tf.data来创建输入流(上)和TensorFlow学习(十一):保存TFRecord文件 这两节的后续。
要是对于怎么生成tfrecords不熟悉的话,可以参考这两节来复习。
这里给出了一些常见的使用案例,代码存放在:LearningTensorFlow/11.TFRecord/
还是老样子,这里先把最主要的API列在这里,后面会用到这些API,先混个脸熟.
__init__
(filenames,compression_type=None,buffer_size=None)
创建一个TFRecordDataset
参数:`
tf.string
类型的tensor里面包含一个或者多个TFRecord文件的文件名""
(没有压缩), "ZLIB"
, 或者 "GZIP"
.map(map_func,num_parallel_calls=None)
作用:在这整个dataset里面使用map_func
来映射,实际上我们用的时候,可以通过这个函数来装换为一般的dataset.也就是返回一个Dataset对象.
参数:
apply(transformation_func)
在dataset上面应用一个转换函数。
dataset = (dataset.map(lambda x: x ** 2)
.apply(group_by_window(key_func, reduce_func, window_size))
.map(lambda x: x ** 3))
参数:
transformation_func: 接受一个Dataset
作为参数并且返回另外一个Dataset
的函数
当然这里还有一些batch(),shuffle()等等函数,这里就不讲了,上面一节有,这里的用法和上面一节是一样的。后面的例子可以清楚的看到。
tf.parse_single_example(serialized,features,name=None,example_names=None)
作用:解析读入的单个Example proto.
Args:
serialized: 单个的序列化的Example
.
features: A dict mapping feature keys to FixedLenFeature or VarLenFeature values.
name: A name for this operation (optional).
example_names: (Optional) A scalar string Tensor, the associated name. See _parse_single_example_raw documentation for more details.
Returns:
A dict mapping feature keys to Tensor and SparseTensor values.
Raises:
ValueError: if any feature is invalid.
这里的操作是TensorFlow学习(十一):保存TFRecord文件 把.csv
文件转为tfrecord文件的读取操作.
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
#tfrecord 文件列表
file_list=["train.tfrecords"]
#创建dataset对象
dataset=tf.data.TFRecordDataset(filenames=file_list)
#定义解析和预处理函数
def _parse_data(example_proto):
parsed_features=tf.parse_single_example(
serialized=example_proto,
features={
"image_raw":tf.FixedLenFeature(shape=(),dtype=tf.string),
"label":tf.FixedLenFeature(shape=(),dtype=tf.int64)
}
)
# get single feature
raw = parsed_features["image_raw"]
label = parsed_features["label"]
# decode raw
image = tf.decode_raw(bytes=raw, out_type=tf.int64)
image=tf.reshape(tensor=image,shape=[28,28])
return image,label
#使用map处理得到新的dataset
dataset=dataset.map(map_func=_parse_data)
#使用batch_size为32生成mini-batch
#dataset = dataset.batch(32)
#创建迭代器
iterator=dataset.make_one_shot_iterator()
next_element=iterator.get_next()
with tf.Session() as sess:
for i in range(10):
image, label = sess.run(next_element)
print(label)
print(image.shape)
print(label.shape)
#plt.imshow(image)
#plt.show()