slim.dataset_data_provider.DatasetDataProvider raises OutOfRangeError: FIFOQueue '_2_parallel_read/common_queue' is closed when num_epochs is set to 1

1: Error description

When using slim.dataset_data_provider.DatasetDataProvider, the provider is constructed as follows:

provider = slim.dataset_data_provider.DatasetDataProvider(dataset=dataset,
                                                              num_readers=1,
                                                              shuffle=False,
                                                              common_queue_capacity=20*FLAG.batch_size,
                                                              common_queue_min=10*FLAG.batch_size,
                                                              num_epochs=1,
                                                              seed=None)

With num_epochs left at its default of None this runs fine, but once num_epochs is set to 1 (as above) it fails with:

tensorflow.python.framework.errors_impl.OutOfRangeError: FIFOQueue '_2_parallel_read/common_queue' is closed and has insufficient elements (requested 1, current size 0)
         [[Node: parallel_read/common_queue_Dequeue = QueueDequeueV2[component_types=[DT_STRING, DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/device:CPU:0"](parallel_read/common_queue)]]

At first the message looked like a data-source error, but since it disappears when num_epochs is None, the data source itself is clearly not the problem.

2: Tracking the problem through the source code

https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/contrib/slim/python/slim/data/dataset_data_provider.py

 

    key, data = parallel_reader.parallel_read(
        dataset.data_sources,
        reader_class=dataset.reader,
        num_epochs=num_epochs,
        num_readers=num_readers,
        reader_kwargs=reader_kwargs,
        shuffle=shuffle,
        capacity=common_queue_capacity,
        min_after_dequeue=common_queue_min,
        seed=seed,
        scope=scope)

From this snippet we can see that DatasetDataProvider simply delegates to parallel_reader.parallel_read, so the next step is to look at the source of parallel_reader:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/data/parallel_reader.py

 

def parallel_read(data_sources,
                  reader_class,
                  num_epochs=None,
                  num_readers=4,
                  reader_kwargs=None,
                  shuffle=True,
                  dtypes=None,
                  capacity=256,
                  min_after_dequeue=128,
                  seed=None,
                  scope=None):
  """Reads multiple records in parallel from data_sources using n readers.
  It uses a ParallelReader to read from multiple files in parallel using
  multiple readers created using `reader_class` with `reader_kwargs'.
  If shuffle is True the common_queue would be a RandomShuffleQueue otherwise
  it would be a FIFOQueue.
  Usage:
      data_sources = ['path_to/train*']
      key, value = parallel_read(data_sources, tf.CSVReader, num_readers=4)
  Args:
    data_sources: a list/tuple of files or the location of the data, i.e.
      /path/to/train@128, /path/to/train* or /tmp/.../train*
    reader_class: one of the io_ops.ReaderBase subclasses ex: TFRecordReader
    num_epochs: The number of times each data source is read. If left as None,
      the data will be cycled through indefinitely.

From the docstring, num_epochs is just the number of times each data source should be read; 1, 2, 3 and so on are all legal values, so the parameter itself is not the problem. Continuing into the body of the function:

  data_files = get_data_files(data_sources)
  with ops.name_scope(scope, 'parallel_read'):
    filename_queue = tf_input.string_input_producer(
        data_files,
        num_epochs=num_epochs,
        shuffle=shuffle,
        seed=seed,
        name='filenames')

Based on the docstring I assumed data_sources had to be a list, but what I passed in was the path of a single record file, a plain string rather than a list, so I suspected that was where things went wrong. The first line of the function, data_files = get_data_files(data_sources), looked relevant, so I stepped into get_data_files to check its logic:

def get_data_files(data_sources):
  """Get data_files from data_sources.
  Args:
    data_sources: a list/tuple of files or the location of the data, i.e.
      /path/to/train@128, /path/to/train* or /tmp/.../train*
  Returns:
    a list of data_files.
  Raises:
    ValueError: if data files are not found
  """
  if isinstance(data_sources, (list, tuple)):
    data_files = []
    for source in data_sources:
      data_files += get_data_files(source)
  else:
    if '*' in data_sources or '?' in data_sources or '[' in data_sources:
      data_files = gfile.Glob(data_sources)
    else:
      data_files = [data_sources]
  if not data_files:
    raise ValueError('No data files found in %s' % (data_sources,))
  return data_files

So even though the data_sources I passed in is just a string, get_data_files wraps it into a single-element list. The function also accepts a string containing *: in that case it globs the path and returns every file that matches the pattern.
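For illustration, here is a minimal sketch of the three kinds of input get_data_files accepts; the paths are hypothetical and this assumes TensorFlow 1.x with contrib.slim available.

from tensorflow.contrib.slim.python.slim.data import parallel_reader

# A plain string is wrapped into a single-element list (no existence check).
print(parallel_reader.get_data_files('/tmp/data/train.tfrecord'))

# A string containing * / ? / [ is expanded with gfile.Glob;
# if nothing matches, a ValueError is raised.
print(parallel_reader.get_data_files('/tmp/data/train-*.tfrecord'))

# A list/tuple is flattened by recursing on each element.
print(parallel_reader.get_data_files(['/tmp/data/a.tfrecord', '/tmp/data/b.tfrecord']))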

tf.gfile.Glob(filename)

Glob finds all files matching the pattern and returns them as a list; filename can be a concrete file name or a path containing shell-style wildcards (not a full regular expression).

So when the dataset is very large, you can shard it into several record files and refer to all of them with a single wildcard pattern, as sketched below.
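A hedged sketch, with hypothetical shard names, of expanding such a pattern with tf.gfile.Glob; the same pattern can also be passed directly as dataset.data_sources, since get_data_files expands it internally:

import tensorflow as tf

# Suppose the dataset was written into four shards:
#   /tmp/data/train-00000-of-00004.tfrecord ... train-00003-of-00004.tfrecord
pattern = '/tmp/data/train-*-of-00004.tfrecord'

# Glob expands the wildcard into the list of matching shard files.
print(tf.gfile.Glob(pattern))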

Also, looking at the source, slim.dataset_data_provider.DatasetDataProvider ultimately ends up calling tf_input.string_input_producer anyway; it is only a wrapper around it, so using tf.train.string_input_producer directly may well be the most convenient option.
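If you do go that route, a minimal sketch of reading TFRecord shards with tf.train.string_input_producer directly might look like this; the file pattern and feature spec are hypothetical and assume TensorFlow 1.x:

import tensorflow as tf

# Filename queue with a bounded number of epochs; this is the same op
# that DatasetDataProvider uses internally.
filename_queue = tf.train.string_input_producer(
    tf.gfile.Glob('/tmp/data/train-*.tfrecord'),  # hypothetical shards
    num_epochs=1,
    shuffle=False)

reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)

# Parse whatever features your records actually contain (hypothetical spec).
features = tf.parse_single_example(
    serialized_example,
    features={'image/encoded': tf.FixedLenFeature([], tf.string),
              'image/class/label': tf.FixedLenFeature([], tf.int64)})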

3: How the problem was finally solved

In the end the cause was that I had not initialized the local variables: when num_epochs is not None, string_input_producer creates a local variable that counts epochs, and if it is never initialized the filename queue runner fails, the common queue is closed while still empty, and the dequeue then surfaces as the OutOfRangeError above. Running the initializers fixes it:

sess.run(tf.tables_initializer())
sess.run(tf.local_variables_initializer())   # the epoch counter created by num_epochs is a local variable
sess.run(tf.global_variables_initializer())
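For completeness, here is a sketch of how the session setup fits together when num_epochs=1; train_op is a hypothetical training step built from the provider's tensors, and with a finite number of epochs the OutOfRangeError becomes the normal end-of-data signal rather than a failure:

import tensorflow as tf

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())  # epoch counter lives here

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            sess.run(train_op)  # hypothetical training step
    except tf.errors.OutOfRangeError:
        # Raised once every record has been read num_epochs times.
        print('Finished one epoch.')
    finally:
        coord.request_stop()
        coord.join(threads)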

 
