本文涉及TensorFlow的三个组件:
它们协同完成了TensorFlow的多线程数据输入。
--------------------------------------------------------------------------------------------
例1:
import tensorflow as tf
with tf.Session() as sess:
q = tf.FIFOQueue(3, 'float')
init = q.enqueue_many(([0.1, 0.2, 0.3],))
init2 = q.dequeue()
init3 = q.enqueue(4.)
sess.run(init)
sess.run(init2)
sess.run(init3)
quelen = sess.run(q.size())
for i in range(quelen):
print(sess.run(q.dequeue()))
新方法 | 描述 |
tf.FIFOQueue(3,'float') | 创建一个长度为3的队列,数据类型为'float' |
q.enqueue(vals,name=None) | 将一个元素编入该队列,如果在执行该操作时队列已满,那么将会阻塞直到幸免于难编入队列之中 |
q.enqueue_many(vals,name=None) | 将零个或多个元素编入队列之中 |
q.dequeue(name=None) | 把元素从队列中移出,如果在执行该操作时队列已空,那么将会阻塞直到元素出列,返回出列的tensors的tuple |
q.dequeue_many(n,name=None) | 将一个或多个元素从队列中移出 注:本例中未使用到这个方法,但聪明的你应该已经联想到了。 |
过程: [0.1,0.2,0.3]入列->0.1出列->4.0入列;
for : 0.2出列并打印->0.3出列并打印->4.0出列并打印
效果:
0.2
0.3
4.0
Process finished with exit code 0
本例存在一个问题:队列的操作是在主线程的对话中依次完成的,这样的操作会造成数据的读取和输入较慢,处理相对困难。因此我们引入TF的队列管理器QueueRunner:
例2:
import tensorflow as tf
with tf.Session() as sess:
q = tf.FIFOQueue(1000, 'float32')
counter = tf.Variable(0.0)
add_op = tf.assign_add(counter, tf.constant(1.0))
enqueueData_op = q.enqueue(counter)
qr = tf.train.QueueRunner(q, enqueue_ops= [add_op, enqueueData_op] * 2)
sess.run(tf.global_variables_initializer())
enqueue_threads = qr.create_threads(sess, start = True)
for i in range(10):
print(sess.run(q.dequeue()))
新方法 | 描述 |
qr=tf.train.QueuRunner(q,enqueue_ops=[...]*2) | 创建队列管理器:指定管理的是队列q;指定入列操作;用两个线程去完成此项任务 |
qr.create_threads(sess,start=True) | 启动线程,并指定会话(TF所有的操作都是在会话中执行的) |
运行效果:
0.0
0.0
6.0
12.0
25.0
29.0
61.0
68.0
74.0
79.0
ERROR:tensorflow:Exception in QueueRunner: Session has been closed.
ERROR:tensorflow:Exception in QueueRunner: Session has been closed.
ERROR:tensorflow:Exception in QueueRunner: Session has been closed.
ERROR:tensorflow:Exception in QueueRunner: Session has been closed.
发现打印了10次输出后开始报错了。原因是主线程执行了10次出列操作,然后随着with上下文的退出而关闭了会话,而入列线程还在继续!我们改一下:
import tensorflow as tf
q = tf.FIFOQueue(1000, 'float32')
counter = tf.Variable(0.0)
add_op = tf.assign_add(counter, tf.constant(1.0))
enqueueData_op = q.enqueue(counter)
sess = tf.Session()
qr = tf.train.QueueRunner(q, enqueue_ops=[add_op, enqueueData_op] * 2)
sess.run(tf.global_variables_initializer())
enqueue_threads = qr.create_threads(sess, start=True)
for i in range(10):
print(sess.run(q.dequeue()))
只是去掉了with上下文管理,主线程在执行完10次后没有关闭会话。再次运行,可以看到此时的会话并没有看错,但是程序退出没有结束,而是被挂起。造成这种情况的原因是add操作和入队操作没有同步,即TensorFlow在队列设计时为优化IO系统,队列的操作一般使用批处理,这样入队线程没有发送结束的信息而程序主线程期望将程序结束,因此造成程序堵塞程序被挂起。
例4:
import tensorflow as tf
q = tf.FIFOQueue(1000, 'float32')
counter = tf.Variable(0.0)
add_op = tf.assign_add(counter, tf.constant(1.0))
enqueueData_op = q.enqueue(counter)
sess = tf.Session()
qr = tf.train.QueueRunner(q, enqueue_ops=[add_op, enqueueData_op] * 2)
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
enqueue_threads = qr.create_threads(sess, coord=coord, start=True)
for i in range(10):
print(sess.run(q.dequeue()))
coord.request_stop()
coord.join(enqueue_threads)
新方法 | 描述 |
tf.train.Coordinator() | 创建线程协调器 |
coord.request_stop() | 主线程请求结束 |
coord.join(enqueue_threads) | 等待线程结束 |
create_threads函数被添加了一个新的参数:线程协调器,用于协调线程之间的关系,之后启动线程以后,线程协调器在最后负责对所有线程的接受和处理,当一个纯种结束时,纯种协调器会对所有线程发出通知,协调其完毕。