tensorflow编程一些需要知道的 - 3


做训练时,我们往往要处理大批量的数据,这时如果有个可以异步读取的方式,那么处理程序会更加灵活和高效。FIFOQueue 、RandomShuffleQueue 便是tensorflow提供的一些通过队列来做异步数据存取的方法,并且它是多线程的(tf.Session对象就是多线程的)。这个框架如下图

tensorflow编程一些需要知道的 - 3_第1张图片



由于整个程序是多线程的,因此我们可以在同一个session里并行跑多个ops。为此,tf一共了tf.Coordinator 和 tf.QueueRunner来帮助我们编写这样的程序,使得我们可以方便地处理多个线程的启动、停止、异常捕获等。Coordinator类让我们的多个线程同时停止、把异常抛给调用的地方。QueueRunner类针对enqueue ops创建多个线程,而这些线程可以通过Coordinator类来同时停止,并当有异常发生时通过一个关闭线程来自动关闭这些线程。下面是一个对上述的示例

import tensorflow as tf

example = ...ops to create one example...
#Step 1. 为输入构建一个queue,并将这些输入enqueue进去
queue = tf.RandomShuffleQueue(...)
enqueue_op = queue.enqueue(example)

#Step 2. 从queue里dequeue样本进行训练 
inputs = queue.dequeue_many(batch_size)
train_op = ...use 'inputs' to build the training part of the graph...

#Note: 以上可以通过tf.train.string_input_producer来完成

#Step 3. 通过4个线程来并发enqueue_op样本
qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)

# Launch the graph.
sess = tf.Session()
#构建Coordinator, 启动QueueRunner
coord = tf.train.Coordinator()
enqueue_threads = qr.create_threads(sess, coord=coord, start=True)
# 跑训练迭代。这些线程除了跑enqueue_op, 而且捕获并处理异常, 通过coordinator来在主循环中处理。
try:
    for step in xrange(1000000):
        if coord.should_stop():
            break
        sess.run(train_op)
except Exception, e:
    # 将异常抛给coordinator,通知线程停止
    coord.request_stop(e)
finally:
    coord.request_stop()
    coord.join(enqueue_threads)





你可能感兴趣的:(深度学习)