The functions covered in this article are:
tf.train.string_input_producer
tf.train.input_producer
tf.train.QueueRunner
tf.train.add_queue_runner
tf.train.Coordinator
tf.train.start_queue_runners
The tf.train.string_input_producer function creates an input queue from the provided list of files (string_tensor).
string_input_producer(
    string_tensor,
    num_epochs=None,
    shuffle=True,
    seed=None,
    capacity=32,
    shared_name=None,
    name=None,
    cancel_op=None
)
The string_tensor parameter is the list of files to produce.
The num_epochs parameter limits the maximum number of epochs for which the initial file list is loaded. When it is None, once every file in the input queue has been processed, all files from the initial list are added back into the queue. When it is set, a local counter is created, and after the given number of epochs has been produced an OutOfRange error is raised (see the usage sketch below).
The shuffle parameter, when True, shuffles the files before they are added to the queue; when solving real problems this reduces the influence of an irrelevant factor, namely the order of the files.
The capacity parameter sets the capacity of the queue.
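A minimal usage sketch (the .tfrecords filenames below are made up; because num_epochs creates a local counter, tf.local_variables_initializer() must be run, and the Coordinator / start_queue_runners calls used here are covered later in this article):

import tensorflow as tf

# Hypothetical file list; replace with real paths.
files = ["file_0.tfrecords", "file_1.tfrecords"]

# Each file is produced twice (num_epochs=2) before OutOfRangeError is raised.
filename_queue = tf.train.string_input_producer(files, shuffle=False, num_epochs=2)
filename = filename_queue.dequeue()

with tf.Session() as sess:
    # num_epochs adds a local counter, so local variables must be initialized.
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for _ in range(4):            # 2 files x 2 epochs
        print(sess.run(filename))
    coord.request_stop()
    coord.join(threads)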
tf.train.string_input_producer: https://www.tensorflow.org/api_docs/python/tf/train/string_input_producer
Source code location: tensorflow/python/training/input.py
def string_input_producer(string_tensor,
                          num_epochs=None,
                          shuffle=True,
                          seed=None,
                          capacity=32,
                          shared_name=None,
                          name=None,
                          cancel_op=None):
  """Output strings (e.g. filenames) to a queue for an input pipeline.

  Note: if `num_epochs` is not `None`, this function creates local counter
  `epochs`. Use `local_variables_initializer()` to initialize local variables.

  Args:
    string_tensor: A 1-D string tensor with the strings to produce.
    num_epochs: An integer (optional). If specified, `string_input_producer`
      produces each string from `string_tensor` `num_epochs` times before
      generating an `OutOfRange` error. If not specified,
      `string_input_producer` can cycle through the strings in `string_tensor`
      an unlimited number of times.
    shuffle: Boolean. If true, the strings are randomly shuffled within each
      epoch.
    seed: An integer (optional). Seed used if shuffle == True.
    capacity: An integer. Sets the queue capacity.
    shared_name: (optional). If set, this queue will be shared under the given
      name across multiple sessions. All sessions open to the device which has
      this queue will be able to access it via the shared_name. Using this in
      a distributed setting means each name will only be seen by one of the
      sessions which has access to this operation.
    name: A name for the operations (optional).
    cancel_op: Cancel op for the queue (optional).

  Returns:
    A queue with the output strings. A `QueueRunner` for the Queue
    is added to the current `Graph`'s `QUEUE_RUNNER` collection.

  Raises:
    ValueError: If the string_tensor is a null Python list. At runtime,
    will fail with an assertion if string_tensor becomes a null tensor.
  """
  not_null_err = "string_input_producer requires a non-null input tensor"
  if not isinstance(string_tensor, ops.Tensor) and not string_tensor:
    raise ValueError(not_null_err)

  with ops.name_scope(name, "input_producer", [string_tensor]) as name:
    string_tensor = ops.convert_to_tensor(string_tensor, dtype=dtypes.string)
    with ops.control_dependencies([
        control_flow_ops.Assert(
            math_ops.greater(array_ops.size(string_tensor), 0),
            [not_null_err])]):
      string_tensor = array_ops.identity(string_tensor)
    return input_producer(
        input_tensor=string_tensor,
        element_shape=[],
        num_epochs=num_epochs,
        shuffle=shuffle,
        seed=seed,
        capacity=capacity,
        shared_name=shared_name,
        name=name,
        summary_name="fraction_of_%d_full" % capacity,
        cancel_op=cancel_op)
From the definition of string_input_producer above, we can see that string_tensor (the list of file paths) is turned into input_tensor and then passed to the input_producer function; the remaining parameters are forwarded unchanged.
tf.train.input_producer: https://www.tensorflow.org/api_docs/python/tf/train/input_producer
def input_producer(input_tensor,
                   element_shape=None,
                   num_epochs=None,
                   shuffle=True,
                   seed=None,
                   capacity=32,
                   shared_name=None,
                   summary_name=None,
                   name=None,
                   cancel_op=None):
  """Output the rows of `input_tensor` to a queue for an input pipeline.

  Note: if `num_epochs` is not `None`, this function creates local counter
  `epochs`. Use `local_variables_initializer()` to initialize local variables.

  Args:
    input_tensor: A tensor with the rows to produce. Must be at least
      one-dimensional. Must either have a fully-defined shape, or
      `element_shape` must be defined.
    element_shape: (Optional.) A `TensorShape` representing the shape of a
      row of `input_tensor`, if it cannot be inferred.
    num_epochs: (Optional.) An integer. If specified `input_producer` produces
      each row of `input_tensor` `num_epochs` times before generating an
      `OutOfRange` error. If not specified, `input_producer` can cycle through
      the rows of `input_tensor` an unlimited number of times.
    shuffle: (Optional.) A boolean. If true, the rows are randomly shuffled
      within each epoch.
    seed: (Optional.) An integer. The seed to use if `shuffle` is true.
    capacity: (Optional.) The capacity of the queue to be used for buffering
      the input.
    shared_name: (Optional.) If set, this queue will be shared under the given
      name across multiple sessions.
    summary_name: (Optional.) If set, a scalar summary for the current queue
      size will be generated, using this name as part of the tag.
    name: (Optional.) A name for queue.
    cancel_op: (Optional.) Cancel op for the queue

  Returns:
    A queue with the output rows. A `QueueRunner` for the queue is
    added to the current `QUEUE_RUNNER` collection of the current
    graph.

  Raises:
    ValueError: If the shape of the input cannot be inferred from the arguments.
  """
  with ops.name_scope(name, "input_producer", [input_tensor]):
    input_tensor = ops.convert_to_tensor(input_tensor, name="input_tensor")
    element_shape = input_tensor.get_shape()[1:].merge_with(element_shape)
    if not element_shape.is_fully_defined():
      raise ValueError("Either `input_tensor` must have a fully defined shape "
                       "or `element_shape` must be specified")

    if shuffle:
      input_tensor = random_ops.random_shuffle(input_tensor, seed=seed)

    input_tensor = limit_epochs(input_tensor, num_epochs)

    q = data_flow_ops.FIFOQueue(capacity=capacity,
                                dtypes=[input_tensor.dtype.base_dtype],
                                shapes=[element_shape],
                                shared_name=shared_name, name=name)
    enq = q.enqueue_many([input_tensor])
    queue_runner.add_queue_runner(
        queue_runner.QueueRunner(
            q, [enq], cancel_op=cancel_op))
    if summary_name is not None:
      summary.scalar(summary_name,
                     math_ops.cast(q.size(), dtypes.float32) * (1. / capacity))
    return q
The parameters of input_producer carry the same meaning as those of string_input_producer. In the definition above, the handling of element_shape and shuffle matches the parameter descriptions at the beginning of this article, so it is not repeated here.
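To make the "rows" semantics concrete, here is a small sketch with made-up data: each row of a 2-D tensor becomes one queue element.

import tensorflow as tf

# A 3 x 2 tensor: input_producer enqueues its rows, so each dequeue
# returns one row of shape [2]; with shuffle=False the order is preserved.
data = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
queue = tf.train.input_producer(data, shuffle=False)
row = queue.dequeue()

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    print(sess.run(row))  # first row: [1. 2.]
    print(sess.run(row))  # second row: [3. 4.]
    coord.request_stop()
    coord.join(threads)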
The key part is the code that follows.
First, a first-in, first-out queue is created, whose capacity is given by the capacity argument:
q = data_flow_ops.FIFOQueue(capacity=capacity,
                            dtypes=[input_tensor.dtype.base_dtype],
                            shapes=[element_shape],
                            shared_name=shared_name, name=name)
Then the op that enqueues input_tensor into this queue is defined:
enq = q.enqueue_many([input_tensor])
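These two ops can also be used on their own; here is a small sketch of FIFOQueue plus enqueue_many outside of input_producer (the strings are made up):

import tensorflow as tf

q = tf.FIFOQueue(capacity=32, dtypes=[tf.string], shapes=[[]])
enq = q.enqueue_many([["a.csv", "b.csv", "c.csv"]])  # push all three at once
deq = q.dequeue()

with tf.Session() as sess:
    sess.run(enq)
    print(sess.run(deq))       # b'a.csv' -- first in, first out
    print(sess.run(q.size()))  # 2 elements left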
Next, a QueueRunner is used to create threads that run the enqueue op defined above, and add_queue_runner adds this QueueRunner to the tf.GraphKeys.QUEUE_RUNNERS collection:
queue_runner.add_queue_runner(queue_runner.QueueRunner(q, [enq], cancel_op=cancel_op))
tf.train.QueueRunner: https://www.tensorflow.org/api_docs/python/tf/train/QueueRunner
tf.train.QueueRunner is generally used to start multiple threads that operate on the same queue.
def __init__(self, queue=None, enqueue_ops=None, close_op=None,
             cancel_op=None, queue_closed_exception_types=None,
             queue_runner_def=None, import_scope=None):
  """Create a QueueRunner.

  On construction the `QueueRunner` adds an op to close the queue. That op
  will be run if the enqueue ops raise exceptions.

  When you later call the `create_threads()` method, the `QueueRunner` will
  create one thread for each op in `enqueue_ops`. Each thread will run its
  enqueue op in parallel with the other threads. The enqueue ops do not have
  to all be the same op, but it is expected that they all enqueue tensors in
  `queue`.

  Args:
    queue: A `Queue`.
    enqueue_ops: List of enqueue ops to run in threads later.
    close_op: Op to close the queue. Pending enqueue ops are preserved.
    cancel_op: Op to close the queue and cancel pending enqueue ops.
    queue_closed_exception_types: Optional tuple of Exception types that
      indicate that the queue has been closed when raised during an enqueue
      operation. Defaults to `(tf.errors.OutOfRangeError,)`. Another common
      case includes `(tf.errors.OutOfRangeError, tf.errors.CancelledError)`,
      when some of the enqueue ops may dequeue from other Queues.
    queue_runner_def: Optional `QueueRunnerDef` protocol buffer. If specified,
      recreates the QueueRunner from its contents. `queue_runner_def` and the
      other arguments are mutually exclusive.
    import_scope: Optional `string`. Name scope to add. Only used when
      initializing from protocol buffer.

  Raises:
    ValueError: If both `queue_runner_def` and `queue` are both specified.
    ValueError: If `queue` or `enqueue_ops` are not provided when not
      restoring from `queue_runner_def`.
  """
  if queue_runner_def:
    if queue or enqueue_ops:
      raise ValueError("queue_runner_def and queue are mutually exclusive.")
    self._init_from_proto(queue_runner_def,
                          import_scope=import_scope)
  else:
    self._init_from_args(
        queue=queue, enqueue_ops=enqueue_ops,
        close_op=close_op, cancel_op=cancel_op,
        queue_closed_exception_types=queue_closed_exception_types)
  # Protect the count of runs to wait for.
  self._lock = threading.Lock()
  # A map from a session object to the number of outstanding queue runner
  # threads for that session.
  self._runs_per_session = weakref.WeakKeyDictionary()
  # List of exceptions raised by the running threads.
  self._exceptions_raised = []
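A sketch of driving a QueueRunner by hand instead of through start_queue_runners (the queue contents are made up); because enqueue_ops contains two copies of the op, create_threads starts two enqueue threads:

import tensorflow as tf

q = tf.FIFOQueue(capacity=100, dtypes=[tf.float32])
enqueue_op = q.enqueue([tf.random_normal([])])
qr = tf.train.QueueRunner(q, [enqueue_op] * 2)  # one thread per enqueue op

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = qr.create_threads(sess, coord=coord, start=True)
    print(sess.run(q.dequeue()))  # a value produced by one of the threads
    coord.request_stop()
    coord.join(threads)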
tf.train.add_queue_runner: https://www.tensorflow.org/api_docs/python/tf/train/add_queue_runner
add_queue_runner(
    qr,
    collection=tf.GraphKeys.QUEUE_RUNNERS
)
Adds a QueueRunner to the specified collection, which defaults to the tf.GraphKeys.QUEUE_RUNNERS collection.
def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Adds a `QueueRunner` to a collection in the graph.

  When building a complex model that uses many queues it is often difficult to
  gather all the queue runners that need to be run. This convenience function
  allows you to add a queue runner to a well known collection in the graph.

  The companion method `start_queue_runners()` can be used to start threads for
  all the collected queue runners.

  Args:
    qr: A `QueueRunner`.
    collection: A `GraphKey` specifying the graph collection to add
      the queue runner to. Defaults to `GraphKeys.QUEUE_RUNNERS`.
  """
  ops.add_to_collection(collection, qr)
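Since this is just a collection add, the runner registered here is exactly what start_queue_runners later retrieves; a quick sketch:

import tensorflow as tf

q = tf.FIFOQueue(capacity=10, dtypes=[tf.int32])
qr = tf.train.QueueRunner(q, [q.enqueue([1])])
tf.train.add_queue_runner(qr)  # goes into tf.GraphKeys.QUEUE_RUNNERS

# The same QueueRunner object is now visible in the graph collection.
print(tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS))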
tf.train.Coordinator: https://www.tensorflow.org/api_docs/python/tf/train/Coordinator
Coordinator is a class used to coordinate stopping a group of threads together; it provides three methods: should_stop, request_stop, and join.
Source code location: tensorflow/python/training/coordinator.py
def __init__(self, clean_stop_exception_types=None):
  """Create a new Coordinator.

  Args:
    clean_stop_exception_types: Optional tuple of Exception types that should
      cause a clean stop of the coordinator. If an exception of one of these
      types is reported to `request_stop(ex)` the coordinator will behave as
      if `request_stop(None)` was called. Defaults to
      `(tf.errors.OutOfRangeError,)` which is used by input queues to signal
      the end of input. When feeding training data from a Python iterator it
      is common to add `StopIteration` to this list.
  """
  if clean_stop_exception_types is None:
    clean_stop_exception_types = (errors.OutOfRangeError,)
  self._clean_stop_exception_types = tuple(clean_stop_exception_types)
  # Protects all attributes.
  self._lock = threading.Lock()
  # Event set when threads must stop.
  self._stop_event = threading.Event()
  # Python exc_info to report.
  # If not None, it should hold the returned value of sys.exc_info(), which is
  # a tuple containing exception (type, value, traceback).
  self._exc_info_to_raise = None
  # True if we have called join() already.
  self._joined = False
  # Set of threads registered for joining when join() is called. These
  # threads will be joined in addition to the threads passed to the join()
  # call. It's ok if threads are both registered and passed to the join()
  # call.
  self._registered_threads = set()
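Coordinator is not tied to queues; here is a sketch of should_stop / request_stop / join with ordinary Python threads (the loop body is made up):

import threading
import time
import tensorflow as tf

def worker(coord, worker_id):
    # Keep working until some thread asks everyone to stop.
    while not coord.should_stop():
        print("worker", worker_id, "running")
        time.sleep(0.1)
        if worker_id == 0:
            coord.request_stop()  # signal all threads to stop

coord = tf.train.Coordinator()
threads = [threading.Thread(target=worker, args=(coord, i)) for i in range(3)]
for t in threads:
    t.start()
coord.join(threads)  # block until every thread has exited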
tf.train.start_queue_runners: https://www.tensorflow.org/api_docs/python/tf/train/start_queue_runners
By default, start_queue_runners starts every QueueRunner in the tf.GraphKeys.QUEUE_RUNNERS collection; it is the companion of add_queue_runner and operates on the same collection.
Source code location: tensorflow/python/training/queue_runner_impl.py
def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
                        collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Starts all queue runners collected in the graph.

  This is a companion method to `add_queue_runner()`. It just starts
  threads for all queue runners collected in the graph. It returns
  the list of all threads.

  Args:
    sess: `Session` used to run the queue ops. Defaults to the
      default session.
    coord: Optional `Coordinator` for coordinating the started threads.
    daemon: Whether the threads should be marked as `daemons`, meaning
      they don't block program exit.
    start: Set to `False` to only create the threads, not start them.
    collection: A `GraphKey` specifying the graph collection to
      get the queue runners from. Defaults to `GraphKeys.QUEUE_RUNNERS`.

  Raises:
    ValueError: if `sess` is None and there isn't any default session.
    TypeError: if `sess` is not a `tf.Session` object.

  Returns:
    A list of threads.
  """
  if sess is None:
    sess = ops.get_default_session()
    if not sess:
      raise ValueError("Cannot start queue runners: No default session is "
                       "registered. Use `with sess.as_default()` or pass an "
                       "explicit session to tf.start_queue_runners(sess=sess)")

  if not isinstance(sess, session.SessionInterface):
    # Following check is due to backward compatibility. (b/62061352)
    if sess.__class__.__name__ in [
        "MonitoredSession", "SingularMonitoredSession"]:
      return []
    raise TypeError("sess must be a `tf.Session` object. "
                    "Given class: {}".format(sess.__class__))

  with sess.graph.as_default():
    threads = []
    for qr in ops.get_collection(collection):
      threads.extend(qr.create_threads(sess, coord=coord, daemon=daemon,
                                       start=start))
  return threads
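Putting the pieces together, here is a sketch of the typical pattern: the QueueRunner that string_input_producer registered is started by start_queue_runners and stopped through the Coordinator (the filenames are hypothetical):

import tensorflow as tf

files = ["data_0.tfrecords", "data_1.tfrecords"]  # hypothetical paths
filename_queue = tf.train.string_input_producer(files, shuffle=False, num_epochs=1)
filename = filename_queue.dequeue()

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    # Starts the QueueRunner(s) collected in tf.GraphKeys.QUEUE_RUNNERS.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            print(sess.run(filename))
    except tf.errors.OutOfRangeError:
        print("num_epochs reached, queue is exhausted")
    finally:
        coord.request_stop()
        coord.join(threads)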