1: Complete error description
When using slim.dataset_data_provider.DatasetDataProvider, the call below runs fine as long as the parameters are specified like this (note that num_epochs is left as None):
provider = slim.dataset_data_provider.DatasetDataProvider(
    dataset=dataset,
    num_readers=1,
    shuffle=False,
    common_queue_capacity=20 * FLAG.batch_size,
    common_queue_min=10 * FLAG.batch_size,
    num_epochs=None,
    seed=None)
But as soon as num_epochs is changed to 1, the following error is raised:
tensorflow.python.framework.errors_impl.OutOfRangeError: FIFOQueue '_2_parallel_read/common_queue' is closed and has insufficient elements (requested 1, current size 0)
[[Node: parallel_read/common_queue_Dequeue = QueueDequeueV2[component_types=[DT_STRING, DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/device:CPU:0"](parallel_read/common_queue)]]
At first sight the error looks like a problem with the data source, but since no error occurs when num_epochs is None, the data source itself is clearly not the issue.
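For context, downstream of the provider the tensors are pulled out and batched roughly like this (a sketch only: the item names 'image' and 'label' and the batch capacity are assumptions that depend on how the dataset's decoder was defined):

# Item names must match the decoder that was used to build `dataset`.
image, label = provider.get(['image', 'label'])
images, labels = tf.train.batch(
    [image, label],
    batch_size=FLAG.batch_size,
    capacity=5 * FLAG.batch_size)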
2: Working through the source code
https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/contrib/slim/python/slim/data/dataset_data_provider.py
key, data = parallel_reader.parallel_read(
    dataset.data_sources,
    reader_class=dataset.reader,
    num_epochs=num_epochs,
    num_readers=num_readers,
    reader_kwargs=reader_kwargs,
    shuffle=shuffle,
    capacity=common_queue_capacity,
    min_after_dequeue=common_queue_min,
    seed=seed,
    scope=scope)
The source shows that DatasetDataProvider simply delegates to parallel_reader.parallel_read, so the next step is to look at the source of parallel_read:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/data/parallel_reader.py
def parallel_read(data_sources,
                  reader_class,
                  num_epochs=None,
                  num_readers=4,
                  reader_kwargs=None,
                  shuffle=True,
                  dtypes=None,
                  capacity=256,
                  min_after_dequeue=128,
                  seed=None,
                  scope=None):
  """Reads multiple records in parallel from data_sources using n readers.

  It uses a ParallelReader to read from multiple files in parallel using
  multiple readers created using `reader_class` with `reader_kwargs'.

  If shuffle is True the common_queue would be a RandomShuffleQueue otherwise
  it would be a FIFOQueue.

  Usage:
      data_sources = ['path_to/train*']
      key, value = parallel_read(data_sources, tf.CSVReader, num_readers=4)

  Args:
    data_sources: a list/tuple of files or the location of the data, i.e.
      /path/to/train@128, /path/to/train* or /tmp/.../train*
    reader_class: one of the io_ops.ReaderBase subclasses ex: TFRecordReader
    num_epochs: The number of times each data source is read. If left as None,
      the data will be cycled through indefinitely.
The docstring makes it clear that num_epochs is simply the parameter that controls how many passes are made over each data source, so values like 1, 2 or 3 should not cause any error. Continuing further down the function body:
  data_files = get_data_files(data_sources)
  with ops.name_scope(scope, 'parallel_read'):
    filename_queue = tf_input.string_input_producer(
        data_files,
        num_epochs=num_epochs,
        shuffle=shuffle,
        seed=seed,
        name='filenames')
Going by the docstring, data_sources is expected to be a list, but what I had passed in was the path of a single record file, not a list, so I suspected that was where things went wrong. The very first line of the function is data_files = get_data_files(data_sources); stepping into get_data_files shows the following:
def get_data_files(data_sources):
  """Get data_files from data_sources.

  Args:
    data_sources: a list/tuple of files or the location of the data, i.e.
      /path/to/train@128, /path/to/train* or /tmp/.../train*

  Returns:
    a list of data_files.

  Raises:
    ValueError: if data files are not found
  """
  if isinstance(data_sources, (list, tuple)):
    data_files = []
    for source in data_sources:
      data_files += get_data_files(source)
  else:
    if '*' in data_sources or '?' in data_sources or '[' in data_sources:
      data_files = gfile.Glob(data_sources)
    else:
      data_files = [data_sources]
  if not data_files:
    raise ValueError('No data files found in %s' % (data_sources,))
  return data_files
So although the data_sources I passed in is just a string, it gets wrapped into a one-element list; and as the code shows, the function also accepts a string containing *, in which case it picks up every file in that directory that matches the pattern.
The globbing is done by tf.gfile.Glob(pattern), which finds the files matching the pattern and returns them as a list; the argument can be a concrete filename or a pattern containing wildcards. So when the amount of data is very large, you can split it across several record files and point data_sources at them with a wildcard, as in the sketch below.
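A quick way to see both behaviors (a sketch; the paths are placeholders, and get_data_files is imported from the contrib module linked above):

from tensorflow.contrib.slim.python.slim.data.parallel_reader import get_data_files

# A plain string is wrapped into a one-element list as-is.
print(get_data_files('/data/train.tfrecord'))
# -> ['/data/train.tfrecord']

# A wildcard pattern is expanded via gfile.Glob into every matching shard;
# a ValueError is raised if nothing matches.
print(get_data_files('/data/train-*.tfrecord'))
# -> ['/data/train-00000.tfrecord', '/data/train-00001.tfrecord', ...]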
Also, judging from the source, slim.dataset_data_provider.DatasetDataProvider ultimately still calls tf_input.string_input_producer; it is only a wrapper around it, so calling string_input_producer directly may well be the most convenient route.
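A minimal sketch of that direct route, using the public tf.train.string_input_producer alias (the record paths, the feature keys 'image_raw'/'label' and their dtypes are placeholders for whatever your records actually contain):

# Build a filename queue straight from the sharded record files.
filename_queue = tf.train.string_input_producer(
    tf.gfile.Glob('/data/train-*.tfrecord'),
    num_epochs=1,
    shuffle=False)

# One reader pulls serialized examples off the filename queue.
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)

# Parse each example; the feature spec must match how the records were written.
features = tf.parse_single_example(
    serialized_example,
    features={
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    })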
3: How the problem was finally solved
In the end it turned out the error was caused by my not initializing the local variables: when num_epochs is not None, string_input_producer keeps its epoch counter in a local variable, and if that counter is never initialized the queue runner filling the queue fails, the queue is closed, and the dequeue surfaces as the OutOfRangeError above. Adding the initializers fixed it:
sess.run(tf.tables_initializer())
sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())
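Putting it together, a sketch of where these initializers sit in a typical queue-based input loop (images/labels are the hypothetical batched tensors from the earlier sketch):

with tf.Session() as sess:
    # num_epochs makes string_input_producer keep an epoch counter in a
    # *local* variable, hence the local initializer is required as well.
    sess.run(tf.tables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run(tf.global_variables_initializer())

    # Queue runners feed the filename/common queues in background threads.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            img_batch, lbl_batch = sess.run([images, labels])
    except tf.errors.OutOfRangeError:
        # With num_epochs=1 this is raised once the data has been read through once.
        pass
    finally:
        coord.request_stop()
        coord.join(threads)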