Contents
I. Steps for reading TFRecord-format files with slim
1. Set up the decoder
2. Define the dataset
3. Define the dataset's data provider
4. Call the provider's get method to fetch the tensors for the requested items from items_to_tensors
1. Set up the decoder

The decoder is usually created as slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers). The keys_to_features argument is a dict that must match the feature dict that was used when the TFRecord file was written. The keys of items_to_handlers can be arbitrary, but the keys each handler is initialized with must come from keys_to_features. The two dicts are shown below, followed by the decoder construction itself.
keys_to_features = {
    'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
    'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
    'image/height': tf.FixedLenFeature([1], tf.int64),
    'image/width': tf.FixedLenFeature([1], tf.int64),
    'image/channels': tf.FixedLenFeature([1], tf.int64),
    'image/shape': tf.FixedLenFeature([3], tf.int64),
    'image/filename': tf.FixedLenFeature((), tf.string, default_value=''),  # a filename is a string, not int64
    'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
    'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
    'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
    'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
    'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
    'image/object/bbox/difficult': tf.VarLenFeature(dtype=tf.int64),
    'image/object/bbox/truncated': tf.VarLenFeature(dtype=tf.int64),
}
items_to_handlers = {
    'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
    'shape': slim.tfexample_decoder.Tensor('image/shape'),
    'object/bbox': slim.tfexample_decoder.BoundingBox(
        ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
    'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
    'image/filename': slim.tfexample_decoder.Tensor('image/filename'),
    'object/difficult': slim.tfexample_decoder.Tensor('image/object/bbox/difficult'),
    'object/truncated': slim.tfexample_decoder.Tensor('image/object/bbox/truncated'),
}
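With the two dicts in place, constructing the decoder is a single call:

decoder = slim.tfexample_decoder.TFExampleDecoder(
    keys_to_features, items_to_handlers)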
2. Define the dataset

The dataset is defined with slim.dataset.Dataset:

dataset = slim.dataset.Dataset(
    data_sources=file_pattern,
    reader=reader,
    decoder=decoder,
    num_samples=split_to_sizes[split_name],
    items_to_descriptions=items_to_descriptions,
    num_classes=num_classes,
    labels_to_names=labels_to_names)

It simply bundles the data_sources, reader, decoder, num_samples, and related parameters into one object.
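For concreteness, here is a minimal sketch of how these parameters might be filled in; the file pattern, split sizes, and descriptions below are illustrative placeholders in the style of the Pascal VOC dataset factories, not values dictated by the API:

import tensorflow as tf
slim = tf.contrib.slim

file_pattern = './tfrecords/voc_2007_train_*.tfrecord'  # hypothetical path
reader = tf.TFRecordReader
split_name = 'train'
split_to_sizes = {'train': 5011, 'test': 4952}  # Pascal VOC 2007 split sizes
num_classes = 20
items_to_descriptions = {
    'image': 'A color image of varying height and width.',
    'shape': 'Shape of the image.',
    'object/bbox': 'A list of bounding boxes, one per object.',
    'object/label': 'A list of labels, one per object.',
}
labels_to_names = None  # optionally a {label_id: class_name} dict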
3. Define the dataset's data provider

The provider is usually created as follows:

provider = slim.dataset_data_provider.DatasetDataProvider(
    dataset,
    num_readers=FLAGS.num_readers,
    common_queue_capacity=20 * FLAGS.batch_size,
    common_queue_min=10 * FLAGS.batch_size,
    shuffle=True)

Internally, DatasetDataProvider does four things.
1) The class first calls _, data = parallel_reader.parallel_read(...). This function uses tf.train.string_input_producer() to build a queue of TFRecord file names (filename_queue), then creates a common queue: a RandomShuffleQueue if shuffle is True, a FIFOQueue otherwise. It initializes a ParallelReader with reader_class, common_queue, num_readers, and reader_kwargs and calls its read(filename_queue) method. read() has the readers pull records from filename_queue and enqueue them onto the common queue, then dequeues from the common queue, yielding (key, data) pairs, where the key identifies the source file and record and data is one serialized example. The call site and the core of parallel_read are shown below, followed by a standalone usage sketch.
key, data = parallel_reader.parallel_read(
    dataset.data_sources,
    reader_class=dataset.reader,
    num_epochs=num_epochs,
    num_readers=num_readers,
    reader_kwargs=reader_kwargs,
    shuffle=shuffle,
    capacity=common_queue_capacity,
    min_after_dequeue=common_queue_min,
    seed=seed,
    scope=scope)
data_files = get_data_files(data_sources)
with ops.name_scope(scope, 'parallel_read'):
  filename_queue = tf_input.string_input_producer(
      data_files, num_epochs=num_epochs, shuffle=shuffle, seed=seed,
      name='filenames')
  dtypes = dtypes or [tf_dtypes.string, tf_dtypes.string]
  if shuffle:
    common_queue = data_flow_ops.RandomShuffleQueue(
        capacity=capacity,
        min_after_dequeue=min_after_dequeue,
        dtypes=dtypes,
        seed=seed,
        name='common_queue')
  else:
    common_queue = data_flow_ops.FIFOQueue(
        capacity=capacity, dtypes=dtypes, name='common_queue')

  summary.scalar('fraction_of_%d_full' % capacity,
                 math_ops.to_float(common_queue.size()) * (1. / capacity))

  return ParallelReader(
      reader_class,
      common_queue,
      num_readers=num_readers,
      reader_kwargs=reader_kwargs).read(filename_queue)
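To see what parallel_read produces on its own, here is a self-contained sketch; the file pattern is a placeholder, and the queue runners that parallel_read registers must be started before anything can be dequeued:

import tensorflow as tf
slim = tf.contrib.slim

# Read (key, serialized-example) pairs from hypothetical TFRecord files.
key, value = slim.parallel_reader.parallel_read(
    './tfrecords/train-*.tfrecord',  # placeholder data_sources pattern
    reader_class=tf.TFRecordReader,
    num_readers=4,
    shuffle=True,
    capacity=256,
    min_after_dequeue=128)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # num_epochs uses a local counter
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    k, v = sess.run([key, value])  # v holds one serialized tf.train.Example
    coord.request_stop()
    coord.join(threads)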
2) Next, it calls items = dataset.decoder.list_items() to obtain the list of keys of the decoder's items_to_handlers dict.
items = dataset.decoder.list_items()
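With the items_to_handlers defined in step 1, this returns the handler keys; the exact order is a dict implementation detail and should not be relied on:

# items == ['image', 'shape', 'object/bbox', 'object/label',
#           'image/filename', 'object/difficult', 'object/truncated']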
3) With the data from 1) and the items from 2), it calls tensors = dataset.decoder.decode(data, items). During decoding, example = parsing_ops.parse_single_example(data, keys_to_features) is invoked first to parse the serialized record into a dict mapping feature keys to tensors. Each item's handler is then given the subset of example entries whose keys it was initialized with (a handler can be initialized with several keys from keys_to_features, so one handler may consume several entries of example); every handler returns one tensor, and these tensors are collected into the list tensors.
tensors = dataset.decoder.decode(data, items)
def decode(self, serialized_example, items=None):
  """Decodes the given serialized TF-example.

  Args:
    serialized_example: a serialized TF-example tensor.
    items: the list of items to decode. These must be a subset of the item
      keys in self._items_to_handlers. If `items` is left as None, then all
      of the items in self._items_to_handlers are decoded.

  Returns:
    the decoded items, a list of tensor.
  """
  example = parsing_ops.parse_single_example(serialized_example,
                                             self._keys_to_features)

  # Reshape non-sparse elements just once, adding the reshape ops in
  # deterministic order.
  for k in sorted(self._keys_to_features):
    v = self._keys_to_features[k]
    if isinstance(v, parsing_ops.FixedLenFeature):
      example[k] = array_ops.reshape(example[k], v.shape)

  if not items:
    items = self._items_to_handlers.keys()

  outputs = []
  for item in items:
    handler = self._items_to_handlers[item]
    keys_to_tensors = {key: example[key] for key in handler.keys}
    outputs.append(handler.tensors_to_item(keys_to_tensors))
  return outputs
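The line keys_to_tensors = {key: example[key] for key in handler.keys} is where "one handler, several keys" happens. As an illustration with the BoundingBox handler from step 1: it is built from four short names plus a prefix, so its keys property expands to the four full feature keys, and its tensors_to_item() stacks the four coordinate vectors into a single [num_boxes, 4] tensor:

bbox_handler = slim.tfexample_decoder.BoundingBox(
    ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/')
print(bbox_handler.keys)
# ['image/object/bbox/ymin', 'image/object/bbox/xmin',
#  'image/object/bbox/ymax', 'image/object/bbox/xmax']
# decode() looks these four sparse tensors up in `example`, and the
# handler concatenates them column-wise into a [num_boxes, 4] tensor.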
def parse_single_example(serialized, features, name=None, example_names=None):
  if not features:
    raise ValueError("Missing features.")
  if example_names is None:
    return parse_single_example_v2(serialized, features, name)
  features = _prepend_none_dimension(features)
  (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
   dense_shapes) = _features_to_raw_params(
       features,
       [VarLenFeature, FixedLenFeature, FixedLenSequenceFeature, SparseFeature])
  outputs = _parse_single_example_raw(
      serialized, example_names, sparse_keys, sparse_types, dense_keys,
      dense_types, dense_defaults, dense_shapes, name)
  return _construct_sparse_tensors_for_sparse_features(features, outputs)
4) Finally, the items from 2) are matched up one-to-one with the tensors from 3) to build the dict items_to_tensors.
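In the DatasetDataProvider source this matching is literally a zip of the two lists:

items_to_tensors = dict(zip(items, tensors))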
4. Call the provider's get method to fetch the tensors for the requested items from items_to_tensors

# Fetch the decoded image, its shape, the ground-truth labels and
# bounding boxes, and the filename from the TFRecord file.
[image, shape, glabels, gbboxes, xml_pic_name] = provider.get(
    ['image', 'shape', 'object/label', 'object/bbox', 'image/filename'])
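Putting the four steps together, a minimal end-to-end sketch; the dataset is built as in step 2, the queue-runner boilerplate is required because this pipeline is queue-based, and the batch size 32 stands in for FLAGS.batch_size:

provider = slim.dataset_data_provider.DatasetDataProvider(
    dataset,
    num_readers=4,
    common_queue_capacity=20 * 32,
    common_queue_min=10 * 32,
    shuffle=True)
[image, shape, glabels, gbboxes] = provider.get(
    ['image', 'shape', 'object/label', 'object/bbox'])

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    img, s, labels, boxes = sess.run([image, shape, glabels, gbboxes])
    # img: decoded HxWx3 uint8 image; labels: shape (num_boxes,);
    # boxes: shape (num_boxes, 4) in [ymin, xmin, ymax, xmax] order
    coord.request_stop()
    coord.join(threads)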