【Tensorflow】多线程协同

一、tf.train.Coordinator

tf.train.Coordinator 主要用于协同多个线程一起停止，示例如下：
import tensorflow as tf
import numpy as np
import threading
import time

def Loop(coord, worker_id):
    """Worker body: report activity once per second until a stop is requested.

    Args:
        coord: Coordinator-like object providing ``should_stop()`` and
            ``request_stop()`` (a ``tf.train.Coordinator`` in this script).
        worker_id: Integer identifying this worker in the printed messages.

    On every pass the worker draws a uniform random number; when it falls
    below 0.1 the worker asks the coordinator to stop, otherwise it just
    reports that it is working. Either way it sleeps for one second before
    polling the coordinator again.
    """
    while True:
        # Poll the shared stop flag first so an already-stopped coordinator
        # makes the worker exit without printing anything.
        if coord.should_stop():
            return

        keep_working = not (np.random.rand() < 0.1)
        if keep_working:
            print("Working on id: %d\n" % (worker_id))
        else:
            print("Stoping from id: %d\n" % (worker_id))
            coord.request_stop()

        time.sleep(1)

# One coordinator instance is shared by every worker thread.
coord = tf.train.Coordinator()

# Build three worker threads, each running Loop with its own id (0, 1, 2).
threads = [
    threading.Thread(target=Loop, args=(coord, i))
    for i in range(3)
]
for t in threads:
    t.start()

# Hand the started threads to the coordinator and wait on them.
coord.join(threads)
Working on id: 0
Working on id: 1
Working on id: 2
Working on id: 0
Working on id: 1
Working on id: 2
Working on id: 1
Working on id: 2
Working on id: 0
Working on id: 1
Working on id: 2
Working on id: 0
Working on id: 1
Working on id: 0
Working on id: 2
Working on id: 2
Working on id: 0
Working on id: 1
Working on id: 0
Working on id: 1
Working on id: 2
Working on id: 2
Working on id: 1
Working on id: 0
Working on id: 0
Stoping from id: 1

二、tf.train.QueueRunner

tf.train.QueueRunner 主要用于启动多个线程来操作同一队列，示例如下：
import tensorflow as tf

# FIFO queue with capacity 100 whose elements are of dtype "float".
queue = tf.FIFOQueue(capacity=100, dtypes="float")

# Each run of this op pushes one freshly sampled random-normal value
# (a tensor of shape [1]) onto the queue.
enqueue_op = queue.enqueue([tf.random_normal([1])])

# The same enqueue op is listed five times for the runner
# (presumably one enqueue thread per listed op -- see the QueueRunner docs).
qr = tf.train.QueueRunner(
    queue=queue,
    enqueue_ops=[enqueue_op for _ in range(5)],
)
# Register the runner so start_queue_runners() below can find it.
tf.train.add_queue_runner(qr)

out_tensor = queue.dequeue()

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Dequeue three elements and print the single value each one holds.
    for _ in range(3):
        print(sess.run(out_tensor)[0])

    # Ask the runner threads to stop, then wait on them via the coordinator.
    coord.request_stop()
    coord.join(threads)

你可能感兴趣的:(tensorflow)