Tensorflow:实战Google深度学习框架第七章

一介绍

本章节主要介绍数据预处理的相关操作,包括统一输入数据的格式,图像数据的预处理,多线程的数据处理方式,以及Dataset数据集API。

二 内容

2.1 TFRecord输入数据格式

tensorflow提供一套统一的格式来存储数据,这个格式就是TFRecord, 该格式都是通过tf.train.Example Protocol Buffer的格式存储。

Example例子

Example格式

转换成Tfrecord格式
解析TFRecord

2.2 多线程输入数据处理框架

经典输入数据处理流程图

2.2.1 队列与多线程

在tensorflow中,队列和变量类似,都是计算图上有状态的节点。

FIFOQueue

上面提到的FIFOQueue是先进先出队列,在数据集中一般需要对数据进行随机,所以tensorflow提供了RandomShuffleQueue队列, 每次会将队列中的元素打乱,每次出队列操作得到的是从当前队列中所有元素中随机选择一个。                                                                                          在tensorflow中,队列不仅仅是一种数据结构,还是异步计算张量取值的一个重要机制。比如多个线程可以同时向一个队列中写元素,或者同时读取一个队列中的元素。tf.Coordinator和tf.QueueRunner俩个类来完成多线程的协同的功能。tf.Coordinator主要用于协同多个线程一起停止,并提供了should_stop,request_stop和join三个函数。在启动之前需要声明一个tf.Coordinator类,并将这个类传入每一个创建的线程中。启动的线程需要一起查询tf.Coordinator类中提供的should_stop函数,当这个函数返回true时,则当前的线程也需要退出,每一个启动的线程都可以通过request_stop函数来通知其他线程退出。当某一个线程调用request_stop函数之后,should_stop函数的返回值被设置成True,这样其他线程可以同步终止。

多线程

tf.QueueRunner主要用于启动多个线程来操作同一个队列,启动的这些线程可以通过tf.Coordinator类来统一管理。

多线程操作队列

2.2.2 输入文件队列

虽然一个TFRecord文件中可以存储多个训练样例, 但是当训练数据量较大时,可以将数据分成多个TFRecord文件来提高处理效率。Tensorflow提供了tf.train.match_filenames_once函数来获取符合一个正则表达式的所有文件,得到的文件列表可以通过tf.train.string_input_producer函数进行有效的管理。tf.train.string_input_producer函数会使用初始化时提供的文件列表创建一个输入队列,输入队列中原始元素为文件列表中的所有文件。当一个输入队列中的所有文件都被处理完后,它会将初始化时提供的文件列表中的文件全部重新加入队列,该函数可以设置num_epochs参数来限制加载初始文件列表的最大轮数。超过轮数报错。

写入测试数据
多文件读取

2.2.3 组合训练数据

已经介绍了如何从文件列表中读取单个样例,在得到多个样例后,将其组合成batch,提高模型的运行效率。tensorflow提供了tf.train.batch和tf.train.shuffle_batch函数将单个样例组织成batch的形式输出。

tf.train.batch

tf.train.shuffle_batch跟tf.train.batch是几乎相同的,只是多了一个min_after_dequeue,该参数限制了出队队列中元素的最少个数,当队列元素太少时,随机打乱样例顺序的作用不大。

输入数据处理框架

2.3 DataSet

上一节通过队列进行多线程输入的方法。除队列以外,tensorflow还提供了一套更高层的数据处理框架。在新的框架中,每一个数据来源都被抽象成一个“数据集”,开发者可以根据数据集为基本对象,方便进行batching,随机打乱(shuffle)等操作。                                                                在数据集框架中,每一个数据集代表一个数据来源:数据可能来自一个张量,一个TFRecord文件,一个文本文件,或着经过sharding的一系列文件,等等。由于训练数据通常无法全部写入内存中, 从数据集中读取数据时需要使用一个迭代器按顺序进行读取。

简单的Dataset

利用数据集读取数据三个基本步骤:                                                                                                   1.定义数据集的构造方法,上面例子中的tf.data.Dataset.from_tensor_slices                               2.定义遍历器,make_one_shot_iterator                                                                                       3.使用get_next()方法从遍历器中读取数据集。

从文件中构建数据集

数据大部分是TFRecord文件, 但是这种格式的文件还需要解析一遍,所以:

读取TFRecord

使用make_one_shot_iterator时,数据集的所有参数必须已经确定,因make_one_shot_iterator不需要特别的初始化过程。如果需要用到placeholder来初始化数据集,那就需要用到initializable_iterator。

整体流程

你可能感兴趣的:(Tensorflow:实战Google深度学习框架第七章)