tf.train.string_input_producer & QueueRunner & add_queue_runner & Coordinator & start_queue_runners

____tz_zs


This article covers the following APIs:

tf.train.string_input_producer

tf.train.input_producer

tf.train.QueueRunner

tf.train.add_queue_runner

tf.train.Coordinator

tf.train.start_queue_runners


string_input_producer 

The tf.train.string_input_producer function creates an input queue from the supplied list of files (string_tensor).

string_input_producer(
    string_tensor,
    num_epochs=None,
    shuffle=True,
    seed=None,
    capacity=32,
    shared_name=None,
    name=None,
    cancel_op=None
)


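string_tensor parameter: the list of files to read (a 1-D string tensor or a Python list of strings).

num_epochs parameter: limits how many times the initial file list is loaded. When None, once every file in the input queue has been processed, the full file list supplied at initialization is added to the queue again, without limit. When set, a local counter variable (epochs) tracks the number of passes, so local variables must be initialized with local_variables_initializer(); once the limit is reached, an OutOfRange error is raised.

shuffle parameter: when True, the files are shuffled within each epoch before being added to the queue. Shuffling is usually left enabled for training so that file order does not influence the model; set it to False when a deterministic order is required.

capacity parameter: the capacity of the queue.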
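
The interaction between num_epochs and shuffle described above can be modelled in plain Python. The following is a minimal sketch, not TensorFlow code: the produce_strings generator and the OutOfRange exception class are hypothetical names introduced only for illustration, and no queue or session is involved.

```python
import random


class OutOfRange(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError."""


def produce_strings(strings, num_epochs=None, shuffle=True, seed=None):
    """Yield every string once per epoch; raise OutOfRange after num_epochs."""
    if not strings:
        raise ValueError("produce_strings requires a non-empty list")
    rng = random.Random(seed)
    epoch = 0
    # num_epochs=None means "cycle forever", mirroring the documented default.
    while num_epochs is None or epoch < num_epochs:
        batch = list(strings)
        if shuffle:
            rng.shuffle(batch)  # a fresh shuffle in every epoch
        for s in batch:
            yield s
        epoch += 1
    raise OutOfRange("epoch limit reached")


# Two epochs without shuffling: the list is replayed twice, then the
# generator signals the end of input.
items = []
try:
    for name in produce_strings(["a.tfrecord", "b.tfrecord"],
                                num_epochs=2, shuffle=False):
        items.append(name)
except OutOfRange:
    ended = True
```

In this model the end-of-input signal is an exception rather than a normal return, which is the same shape of contract the parameter description gives for the real queue.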


tf.train.string_input_producer function: https://www.tensorflow.org/api_docs/python/tf/train/string_input_producer

Source location: tensorflow/python/training/input.py

def string_input_producer(string_tensor,
                          num_epochs=None,
                          shuffle=True,
                          seed=None,
                          capacity=32,
                          shared_name=None,
                          name=None,
                          cancel_op=None):
  """Output strings (e.g. filenames) to a queue for an input pipeline.

  Note: if `num_epochs` is not `None`, this function creates local counter
  `epochs`. Use `local_variables_initializer()` to initialize local variables.

  Args:
    string_tensor: A 1-D string tensor with the strings to produce.
    num_epochs: An integer (optional). If specified, `string_input_producer`
      produces each string from `string_tensor` `num_epochs` times before
      generating an `OutOfRange` error. If not specified,
      `string_input_producer` can cycle through the strings in `string_tensor`
      an unlimited number of times.
    shuffle: Boolean. If true, the strings are randomly shuffled within each
      epoch.
    seed: An integer (optional). Seed used if shuffle == True.
    capacity: An integer. Sets the queue capacity.
    shared_name: (optional). If set, this queue will be shared under the given
      name across multiple sessions. All sessions open to the device which has
      this queue will be able to access it via the shared_name. Using this in
      a distributed setting means each name will only be seen by one of the
      sessions which has access to this operation.
    name: A name for the operations (optional).
    cancel_op: Cancel op for the queue (optional).

  Returns:
    A queue with the output strings.  A `QueueRunner` for the Queue
    is added to the current `Graph`'s `QUEUE_RUNNER` collection.

  Raises:
    ValueError: If the string_tensor is a null Python list.  At runtime,
    will fail with an assertion if string_tensor becomes a null tensor.
  """
  not_null_err = "string_input_producer requires a non-null input tensor"
  if not isinstance(string_tensor, ops.Tensor) and not string_tensor:
    raise ValueError(not_null_err)

  with ops.name_scope(name, "input_producer", [string_tensor]) as name:
    string_tensor = ops.convert_to_tensor(string_tensor, dtype=dtypes.string)
    with ops.control_dependencies([
        control_flow_ops.Assert(
            math_ops.greater(array_ops.size(string_tensor), 0),
            [not_null_err])]):
      string_tensor = array_ops.identity(string_tensor)
    return input_producer(
        input_tensor=string_tensor,
        element_shape=[],
        num_epochs=num_epochs,
        shuffle=shuffle,
        seed=seed,
        capacity=capacity,
        shared_name=shared_name,
        name=name,
        summary_name="fraction_of_%d_full" % capacity,
        cancel_op=cancel_op)


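As the definition above shows, string_input_producer converts string_tensor (the list of file paths) to a string tensor, asserts that it is non-empty, and passes it to input_producer as input_tensor. The remaining parameters are forwarded unchanged; in addition, element_shape is fixed to [] and summary_name is derived from capacity.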


input_producer 

tf.train.input_producer function: https://www.tensorflow.org/api_docs/python/tf/train/input_producer

def input_producer(input_tensor,
                   element_shape=None,
                   num_epochs=None,
                   shuffle=True,
                   seed=None,
                   capacity=32,
                   shared_name=None,
                   summary_name=None,
                   name=None,
                   cancel_op=None):
  """Output the rows of `input_tensor` to a queue for an input pipeline.

  Note: if `num_epochs` is not `None`, this function creates local counter
  `epochs`. Use `local_variables_initializer()` to initialize local variables.

  Args:
    input_tensor: A tensor with the rows to produce. Must be at least
      one-dimensional. Must either have a fully-defined shape, or
      `element_shape` must be defined.
    element_shape: (Optional.) A `TensorShape` representing the shape of a
      row of `input_tensor`, if it cannot be inferred.
    num_epochs: (Optional.) An integer. If specified `input_producer` produces
      each row of `input_tensor` `num_epochs` times before generating an
      `OutOfRange` error. If not specified, `input_producer` can cycle through
      the rows of `input_tensor` an unlimited number of times.
    shuffle: (Optional.) A boolean. If true, the rows are randomly shuffled
      within each epoch.
    seed: (Optional.) An integer. The seed to use if `shuffle` is true.
    capacity: (Optional.) The capacity of the queue to be used for buffering
      the input.
    shared_name: (Optional.) If set, this queue will be shared under the given
      name across multiple sessions.
    summary_name: (Optional.) If set, a scalar summary for the current queue
      size will be generated, using this name as part of the tag.
    name: (Optional.) A name for queue.
    cancel_op: (Optional.) Cancel op for the queue

  Returns:
    A queue with the output rows.  A `QueueRunner` for the queue is
    added to the current `QUEUE_RUNNER` collection of the current
    graph.

  Raises:
    ValueError: If the shape of the input cannot be inferred from the arguments.
  """
  with ops.name_scope(name, "input_producer", [input_tensor]):
    input_tensor = ops.convert_to_tensor(input_tensor, name="input_tensor")
    element_shape = input_tensor.get_shape()[1:].merge_with(element_shape)
    if not element_shape.is_fully_defined():
      raise ValueError("Either `input_tensor` must have a fully defined shape "
                       "or `element_shape` must be specified")

    if shuffle:
      input_tensor = random_ops.random_shuffle(input_tensor, seed=seed)

    input_tensor = limit_epochs(input_tensor, num_epochs)

    q = data_flow_ops.FIFOQueue(capacity=capacity,
                                dtypes=[input_tensor.dtype.base_dtype],
                                shapes=[element_shape],
                                shared_name=shared_name, name=name)
    enq = q.enqueue_many([input_tensor])
    queue_runner.add_queue_runner(
        queue_runner.QueueRunner(
            q, [enq], cancel_op=cancel_op))
    if summary_name is not None:
      summary.scalar(summary_name,
                     math_ops.cast(q.size(), dtypes.float32) * (1. / capacity))
    return q

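The parameters of input_producer have much the same meaning as those of string_input_producer. In the definition above, the handling of element_shape and shuffle follows the parameter descriptions already given, so it is not repeated here.

The key part is the code that comes after that.


First, a first-in, first-out queue with the given capacity is created: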

q = data_flow_ops.FIFOQueue(capacity=capacity,
                                dtypes=[input_tensor.dtype.base_dtype],
                                shapes=[element_shape],
                                shared_name=shared_name, name=name)


Next, the operation that enqueues input_tensor into the queue is defined:

enq = q.enqueue_many([input_tensor])


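Then a QueueRunner is created to run the enqueue operation defined above in a background thread (a QueueRunner creates one thread per enqueue op, and only one op is passed here), and add_queue_runner adds that QueueRunner to the tf.GraphKeys.QUEUE_RUNNERS collection: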

queue_runner.add_queue_runner(queue_runner.QueueRunner(q, [enq], cancel_op=cancel_op))
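
The three steps above (bounded FIFO queue, enqueue operation, background thread that runs it) have a rough analogue in the Python standard library. This is only a sketch under stated assumptions: queue.Queue stands in for FIFOQueue, the enqueue_epochs function is a hypothetical replacement for the enqueue op, and a None sentinel stands in for closing the queue.

```python
import queue
import threading


def enqueue_epochs(q, names, num_epochs):
    """Put every name into q once per epoch, then signal the end of input."""
    for _ in range(num_epochs):
        for name in names:
            q.put(name)  # blocks while the queue is full (capacity reached)
    q.put(None)  # sentinel: hypothetical stand-in for closing the queue


# Capacity 2 is smaller than the total number of items, so the producer
# thread repeatedly blocks until the consumer below makes room.
q = queue.Queue(maxsize=2)
producer = threading.Thread(
    target=enqueue_epochs, args=(q, ["f0", "f1", "f2"], 2), daemon=True)
producer.start()

out = []
while True:
    item = q.get()  # analogue of running a dequeue op
    if item is None:
        break
    out.append(item)
producer.join()
```

Because there is a single producer and the queue is FIFO, the consumer sees the items in exactly the order they were enqueued.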


QueueRunner

tf.train.QueueRunner class: https://www.tensorflow.org/api_docs/python/tf/train/QueueRunner

tf.train.QueueRunner is typically used to start multiple threads that operate on the same queue. Its constructor is defined as follows:

  def __init__(self, queue=None, enqueue_ops=None, close_op=None,
               cancel_op=None, queue_closed_exception_types=None,
               queue_runner_def=None, import_scope=None):
    """Create a QueueRunner.

    On construction the `QueueRunner` adds an op to close the queue.  That op
    will be run if the enqueue ops raise exceptions.

    When you later call the `create_threads()` method, the `QueueRunner` will
    create one thread for each op in `enqueue_ops`.  Each thread will run its
    enqueue op in parallel with the other threads.  The enqueue ops do not have
    to all be the same op, but it is expected that they all enqueue tensors in
    `queue`.

    Args:
      queue: A `Queue`.
      enqueue_ops: List of enqueue ops to run in threads later.
      close_op: Op to close the queue. Pending enqueue ops are preserved.
      cancel_op: Op to close the queue and cancel pending enqueue ops.
      queue_closed_exception_types: Optional tuple of Exception types that
        indicate that the queue has been closed when raised during an enqueue
        operation.  Defaults to `(tf.errors.OutOfRangeError,)`.  Another common
        case includes `(tf.errors.OutOfRangeError, tf.errors.CancelledError)`,
        when some of the enqueue ops may dequeue from other Queues.
      queue_runner_def: Optional `QueueRunnerDef` protocol buffer. If specified,
        recreates the QueueRunner from its contents. `queue_runner_def` and the
        other arguments are mutually exclusive.
      import_scope: Optional `string`. Name scope to add. Only used when
        initializing from protocol buffer.

    Raises:
      ValueError: If both `queue_runner_def` and `queue` are both specified.
      ValueError: If `queue` or `enqueue_ops` are not provided when not
        restoring from `queue_runner_def`.
    """
    if queue_runner_def:
      if queue or enqueue_ops:
        raise ValueError("queue_runner_def and queue are mutually exclusive.")
      self._init_from_proto(queue_runner_def,
                            import_scope=import_scope)
    else:
      self._init_from_args(
          queue=queue, enqueue_ops=enqueue_ops,
          close_op=close_op, cancel_op=cancel_op,
          queue_closed_exception_types=queue_closed_exception_types)
    # Protect the count of runs to wait for.
    self._lock = threading.Lock()
    # A map from a session object to the number of outstanding queue runner
    # threads for that session.
    self._runs_per_session = weakref.WeakKeyDictionary()
    # List of exceptions raised by the running threads.
    self._exceptions_raised = []

add_queue_runner 

tf.train.add_queue_runner function: https://www.tensorflow.org/api_docs/python/tf/train/add_queue_runner

add_queue_runner(
    qr,
    collection=tf.GraphKeys.QUEUE_RUNNERS
)

Adds a QueueRunner to the specified collection, which defaults to the tf.GraphKeys.QUEUE_RUNNERS collection.

def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Adds a `QueueRunner` to a collection in the graph.

  When building a complex model that uses many queues it is often difficult to
  gather all the queue runners that need to be run.  This convenience function
  allows you to add a queue runner to a well known collection in the graph.

  The companion method `start_queue_runners()` can be used to start threads for
  all the collected queue runners.

  Args:
    qr: A `QueueRunner`.
    collection: A `GraphKey` specifying the graph collection to add
      the queue runner to.  Defaults to `GraphKeys.QUEUE_RUNNERS`.
  """
  ops.add_to_collection(collection, qr)
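
Since add_queue_runner is a one-line wrapper around add_to_collection, the mechanism amounts to a named list stored on the graph. A minimal sketch of that idea follows; MiniGraph and the "queue_runners" key string are illustrative assumptions, not TensorFlow objects.

```python
QUEUE_RUNNERS = "queue_runners"  # assumed key, for illustration only


class MiniGraph:
    """Hypothetical graph that only knows how to store named collections."""

    def __init__(self):
        self._collections = {}

    def add_to_collection(self, name, value):
        self._collections.setdefault(name, []).append(value)

    def get_collection(self, name):
        # Return a copy so callers cannot mutate the stored list.
        return list(self._collections.get(name, []))


graph = MiniGraph()


def add_queue_runner(qr, collection=QUEUE_RUNNERS):
    """Register qr under the given collection name on the sketch graph."""
    graph.add_to_collection(collection, qr)


add_queue_runner("runner-A")
add_queue_runner("runner-B")
add_queue_runner("runner-C", collection="custom")
```

Anything registered under the default key can later be found by code that reads the same key, which is what lets a separate start function discover every runner without being told about each one.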

Coordinator

tf.train.Coordinator class: https://www.tensorflow.org/api_docs/python/tf/train/Coordinator

Coordinator is a class that coordinates stopping multiple threads together. Its three core methods are should_stop, request_stop, and join.

Source location: tensorflow/python/training/coordinator.py

  def __init__(self, clean_stop_exception_types=None):
    """Create a new Coordinator.

    Args:
      clean_stop_exception_types: Optional tuple of Exception types that should
        cause a clean stop of the coordinator. If an exception of one of these
        types is reported to `request_stop(ex)` the coordinator will behave as
        if `request_stop(None)` was called.  Defaults to
        `(tf.errors.OutOfRangeError,)` which is used by input queues to signal
        the end of input. When feeding training data from a Python iterator it
        is common to add `StopIteration` to this list.
    """
    if clean_stop_exception_types is None:
      clean_stop_exception_types = (errors.OutOfRangeError,)
    self._clean_stop_exception_types = tuple(clean_stop_exception_types)
    # Protects all attributes.
    self._lock = threading.Lock()
    # Event set when threads must stop.
    self._stop_event = threading.Event()
    # Python exc_info to report.
    # If not None, it should hold the returned value of sys.exc_info(), which is
    # a tuple containing exception (type, value, traceback).
    self._exc_info_to_raise = None
    # True if we have called join() already.
    self._joined = False
    # Set of threads registered for joining when join() is called.  These
    # threads will be joined in addition to the threads passed to the join()
    # call.  It's ok if threads are both registered and passed to the join()
    # call.
    self._registered_threads = set()
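
The constructor above stores a threading.Event that is "set when threads must stop". A stripped-down sketch of the should_stop / request_stop / join protocol built on that same primitive might look like the following. MiniCoordinator and worker are hypothetical names, and exception reporting, registered threads, and grace periods are deliberately left out.

```python
import threading
import time


class MiniCoordinator:
    """Hypothetical, minimal analogue of a thread coordinator."""

    def __init__(self):
        self._stop_event = threading.Event()

    def should_stop(self):
        return self._stop_event.is_set()

    def request_stop(self):
        self._stop_event.set()

    def join(self, threads):
        for t in threads:
            t.join()


def worker(coord, worker_id, counts):
    # Every worker polls the shared flag; worker 0 asks everyone to stop
    # after its fifth iteration.
    while not coord.should_stop():
        counts[worker_id] += 1
        if worker_id == 0 and counts[worker_id] >= 5:
            coord.request_stop()
        time.sleep(0.001)


coord = MiniCoordinator()
counts = [0, 0, 0]
threads = [threading.Thread(target=worker, args=(coord, i, counts))
           for i in range(3)]
for t in threads:
    t.start()
coord.join(threads)  # returns once every worker has seen the stop request
```

A single request_stop call from any thread is enough to end all of them, which is the coordination property the class is named for.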

start_queue_runners

tf.train.start_queue_runners function: https://www.tensorflow.org/api_docs/python/tf/train/start_queue_runners

By default, start_queue_runners starts every QueueRunner in the tf.GraphKeys.QUEUE_RUNNERS collection. It is the counterpart of add_queue_runner, and both operate on the same collection.

Source location: tensorflow/python/training/queue_runner_impl.py

def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
                        collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Starts all queue runners collected in the graph.

  This is a companion method to `add_queue_runner()`.  It just starts
  threads for all queue runners collected in the graph.  It returns
  the list of all threads.

  Args:
    sess: `Session` used to run the queue ops.  Defaults to the
      default session.
    coord: Optional `Coordinator` for coordinating the started threads.
    daemon: Whether the threads should be marked as `daemons`, meaning
      they don't block program exit.
    start: Set to `False` to only create the threads, not start them.
    collection: A `GraphKey` specifying the graph collection to
      get the queue runners from.  Defaults to `GraphKeys.QUEUE_RUNNERS`.

  Raises:
    ValueError: if `sess` is None and there isn't any default session.
    TypeError: if `sess` is not a `tf.Session` object.

  Returns:
    A list of threads.
  """
  if sess is None:
    sess = ops.get_default_session()
    if not sess:
      raise ValueError("Cannot start queue runners: No default session is "
                       "registered. Use `with sess.as_default()` or pass an "
                       "explicit session to tf.start_queue_runners(sess=sess)")

  if not isinstance(sess, session.SessionInterface):
    # Following check is due to backward compatibility. (b/62061352)
    if sess.__class__.__name__ in [
        "MonitoredSession", "SingularMonitoredSession"]:
      return []
    raise TypeError("sess must be a `tf.Session` object. "
                    "Given class: {}".format(sess.__class__))

  with sess.graph.as_default():
    threads = []
    for qr in ops.get_collection(collection):
      threads.extend(qr.create_threads(sess, coord=coord, daemon=daemon,
                                       start=start))
  return threads
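
The loop at the end of start_queue_runners (fetch every registered runner, create its threads, return the thread list) can be imitated with the standard library. This is a rough sketch under stated assumptions: registry, add_runner, start_runners, and fill_queue are hypothetical names, a plain list replaces the graph collection, and a threading.Event replaces the Coordinator.

```python
import queue
import threading

registry = []  # stand-in for the QUEUE_RUNNERS graph collection


def add_runner(enqueue_fn):
    registry.append(enqueue_fn)


def start_runners(stop_event, daemon=True, start=True):
    """Create (and optionally start) one thread per registered function."""
    threads = []
    for fn in registry:
        t = threading.Thread(target=fn, args=(stop_event,), daemon=daemon)
        if start:
            t.start()
        threads.append(t)
    return threads


q = queue.Queue(maxsize=3)


def fill_queue(stop_event):
    n = 0
    while not stop_event.is_set():
        try:
            q.put(n, timeout=0.01)  # short timeout so the flag is rechecked
            n += 1
        except queue.Full:
            pass


add_runner(fill_queue)
stop_event = threading.Event()
threads = start_runners(stop_event)

first_five = [q.get() for _ in range(5)]  # consume while the thread refills
stop_event.set()      # analogue of coord.request_stop()
for t in threads:
    t.join()          # analogue of coord.join(threads)
```

Marking the threads as daemons mirrors the daemon=True default in the signature above: a forgotten join does not keep the program alive.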

