本文代码来自王晓华著的深度学习与计算机视觉实战系列
Tensorflow队列
队列(queue)是一种最为常见的数据输入输出方式,它提供了一个先进先出的线性数据结构(如同人们排队一样),一端只负责增加队列中的数据元素,另一端则负责数据的输出和删除工作。通常我们将可以增加数据元素的队列成为队尾,而输出和删除的一端称为队首。
与python类似,tensorflow同样应用队列作为数据的一种基本输入输出方式,新数据自动插入队尾,队首的数据自动输出并删除。在tensorflow中队列处于一种有状态节点的地位,随着其他节点在图中状态的改变,队列中的这个‘节点’状态也随之改变。
tensorflow中队列的使用函数
操作 | 描述 |
---|---|
tf.dequeue(name = None) | 将元素从队列中移除。如果在执行该操作时队列已空,那么将会阻塞直到元素出列,返回出列的tensors的tuple |
tf.enqueue_many(vals, name = None) | 将零个或多个元素编入该队列中 |
class tf.QueueBase | 基本的队列应用类,队列(queue)是一种数据结构,该结构通过对个步骤纯粹tensors,并且对tensors进行入列(enqueue)与出列(dequeue)操作 |
tf.enqueue(vals, name = None | 将一个元素加入该队列。如果在执行该操作时队列已满,那么将会阻塞知道元素编入队列之中 |
tf.dequeue_many(n, name = None) | 将一个或多个元素从队列移除 |
tf.size(name = None) | 计算队列中的元素个数 |
tf.close | 关闭该队列 |
tf.dequeue_up_to(n, name = None) | 从该队列中移除n个元素并将之连接 |
tf.dtypes | 列出组成元素的数据类型 |
tf.from_list(index, queues) | 根据queues[index]的参考队列创建一个队列 |
tf.name | 返回队列最下面元素的名称 |
tf.names | 返回队列每个组成部分的名称 |
class tf.FIFOQueue | 在出列时依照先入先出顺序 |
class tf.PaddingFIFOQueue | 一个FIFOQueue,同时根据padding支持batching变长的tensor |
class tf.RandomShuffleQueue | 将队列随机元素列出 |
示例1
import tensorflow as tf
with tf.Session() as sess:
q = tf.FIFOQeue(5, 'float')#创建一个先入先出的数列,其中数据为5个,类型为浮点型
init = q.enqueue_many(([1.0, 2.0, 3.0, 4.0, 5.0],))#填充数列,注意最后的逗号不能省略
init2 = q.dequeue()#删除第一个数字
init3 = q.enqueue(3.5)#在最后一行添加数字3.5
#tensorflow中任何操作都是在‘会话’中进行的,因此上述所有的操作实际上并未执行,需要添加如下会话
sess.run(init)
sess.run(init2)
sess.run(init3)
quelen = sess.run(q.size())#size函数显示了队列中的元素数量
for i in range(quelen):
print(sess.run(q.dequeue()))#依次将队列中的元素打印出来
2.0
3.0
4.0
5.0
3.5
从结果不难发现,原来队列中的第一个元素被dequeue函数移除,并在队列最后添加了元素3.5.
示例一的例子表明了队列在tensorflow中是如何实现的,但是这样的操作会造成数据的读取和输入较慢,试想如果我们要录入大量的数据,必然不能这样操作。
tensorflow中提供了QueueRunner函数用来解决异步操作问题,它可以创建一系列的线程同时进入主线程内进行操作,数据的读取和操作时同步,即主线程在进行训练模型的工作的同时将数据从硬盘读入。
示例2
import tensorflow as tf
with tf.Session() as sess:
#创建一个先进先出的队列,队列包含1000个元素,采用浮点类型
q = tf.FIFOQueue(1000, 'float32')
counter = tf.Variable(0.0)#设置参数变量,初始值设为0.0,随着数据输入的进行,参数随时变化
add_op = tf.assign_add(counter, tf.constant(1.0))#实现一个自增加操作,每次counter数值+1
enqueueData_op = q.enqueue(counter)#新数据入列
#定义队列管理器op。指定多少个子线程,子线程该干什么事情(队列操作)
#这里实际创建了4个线程,两个增加计数,两个执行入队
qr = tf.train.QueueRunner(q, enqueue_ops = [add_op, enqueueData_op]*2)
sess.run(tf.global_variables_initializer())#变量执行初始化操作
enqueue_threads = qr.create_threads(sess, start = True) #启动入队线程,start = True代表立刻开始
for i in range(5):
print(sess.run(q.dequeue())#执行十个循环之后关闭会话(这是程序报错的原因)
上述程序首先创建了一个数据处理函数,add_op的操作时将整数1叠加到变量counter里面。为了执行这个操作,qr创建了一个队列管理器,调用多线程去完成此任务。create_threads函数用于启动线程。
6.0
24.0
44.0
54.0
60.0
E1013 17:36:15.699034 16972 queue_runner_impl.py:275] Exception in QueueRunner: Session has been closed.
E1013 17:36:15.705018 19704 queue_runner_impl.py:275] Exception in QueueRunner: Session has been closed.
可以看到在前五个循环程序可以正常运行,之后的错误提示为:队列管理器企图关闭会话,循环结束。
我们可以换一种代码形式:
示例3
import tensorflow as tf
q = tf.FIFOQueue(1000, 'float32')
counter = tf.Variable(0.0)#设置参数变量,初始值设为0.0,随着数据输入的进行,参数随时变化
add_op = tf.assign_add(counter, tf.constant(1.0))#实现一个自增加操作,每次counter数值+1
enqueueData_op = q.enqueue(counter)#新数据入列
sess = tf.Session()#执行会话的内容由第一句调整到这里,这样就会让前面的操作不同步
qr = tf.train.QueueRunner(q, enqueue_ops = [add_op, enqueueData_op]*2)
sess.run(tf.global_variables_initializer())#变量执行初始化操作
enqueue_threads = qr.create_threads(sess, start = True)
for i in range(5):
print(sess.run(q.dequeue()))
此时再次运行代码,系统不再报错。这是因为程序循环完成后没有结束,而是被挂起。
注意:Tensorflow中一般遇到程序挂起的情况指的是数据输入与预处理没有同步,即需要数据时却没有数据被输入到队列中,这样线程就会被整体挂起。此时tf不会报错而是会处于等待状态。
示例4
从上面例子可以看出,Tensorflow会话支持多线程操作,多个线程可以很方便的在一个会话下共同工作,并行的相互执行。但是通过程序演示可以发现,这种同步会造成某个线程想要关闭对话时,对话被强行关闭,未完成的工作的线程也被强行关闭。
Tensorflow为了解决多线程问题,提供了Coordinator和QueueRunner函数对线程进行控制和协调。在使用上,这两个函数类必须同时工作,共同协作来停止会话中所有线程,并想在等待中的所有工作线程终止的程序报告。
import tensorflow as tf
#with tf.Session() as sess:#标注1
q = tf.FIFOQueue(1000, 'float32')
counter = tf.Variable(0.0)
add_op = tf.assign_add(counter, tf.constant(1.0))
enqueueData_op = q.enqueue(counter)
sess = tf.Session()#标注2,此处添加标注
qr = tf.train.QueueRunner(q, enqueue_ops = [add_op, enqueueData_op] * 2)
sess.run(tf.global_variables_initializer())
enqueue_threads = qr.create_threads(sess, start = True) #启动入队线程
coord = tf.train.Coordinator()#标注3
enqueue_threads = qr.create_threads(sess, coord = coord, start = True)#标注4
for i in range(5):
print(sess.run(q.dequeue()))
coord.request_stop()#标注5
coord.join(enqueue_threads)#标注6
上述代码中QueueRunner称作队列管理器,Coordinator成为线程协调器。此处做一个思考:
1,当我们把1处的标注去除,在2处添加标注,其余不变时,结果依然会在返回5个元素后显示线程关闭错误,为何?
2,维持1处的标注不变,将3,5,6添加标注,将4处的coord=coord函数去掉,结果和现在一样,为何?
下一次我们会详细讲解队列的读取问题,针对以上两个问题,欢迎大家讨论~~~