Reading TFRecord Data in the SSD Algorithm

Contents

I. Steps for reading TFRecord files with slim

1. Set up the decoder

2. Define the dataset

3. Define the dataset's data provider

4. Call the provider's get method to fetch the tensors corresponding to the requested items from items_to_tensors


I. Steps for reading TFRecord files with slim

1. Set up the decoder

The decoder is normally set up as slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers). The keys_to_features argument is a dictionary that must match the feature dictionary written into the TFRecord file. The keys of items_to_handlers can be arbitrary, but each handler must be initialized with keys taken from keys_to_features.

keys_to_features = {
    'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
    'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
    'image/height': tf.FixedLenFeature([1], tf.int64),
    'image/width': tf.FixedLenFeature([1], tf.int64),
    'image/channels': tf.FixedLenFeature([1], tf.int64),
    'image/shape': tf.FixedLenFeature([3], tf.int64),
    'image/filename': tf.FixedLenFeature((), tf.string, default_value=''),
    'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
    'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
    'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
    'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
    'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
    'image/object/bbox/difficult': tf.VarLenFeature(dtype=tf.int64),
    'image/object/bbox/truncated': tf.VarLenFeature(dtype=tf.int64),
}
items_to_handlers = {
    'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
    'shape': slim.tfexample_decoder.Tensor('image/shape'),
    'object/bbox': slim.tfexample_decoder.BoundingBox(
        ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
    'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
    'image/filename': slim.tfexample_decoder.Tensor('image/filename'),
    'object/difficult': slim.tfexample_decoder.Tensor('image/object/bbox/difficult'),
    'object/truncated': slim.tfexample_decoder.Tensor('image/object/bbox/truncated'),
}
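With both dictionaries in place, the decoder itself is constructed by handing them to TFExampleDecoder:

decoder = slim.tfexample_decoder.TFExampleDecoder(
    keys_to_features, items_to_handlers)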

2. Define the dataset

Use slim.dataset.Dataset, which bundles the data sources, reader, decoder, number of samples, and related parameters into a single object:

dataset = slim.dataset.Dataset(
    data_sources=file_pattern,
    reader=reader,
    decoder=decoder,
    num_samples=split_to_sizes[split_name],
    items_to_descriptions=items_to_descriptions,
    num_classes=num_classes,
    labels_to_names=labels_to_names)
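For context, reader is the reader class itself (not an instance), and file_pattern is a glob over the TFRecord files. A minimal sketch, where the naming scheme below is only an assumption about how the files were written:

import os
import tensorflow as tf

reader = tf.TFRecordReader  # passed as a class; instantiated per reader thread
# Hypothetical file naming; adjust to your own TFRecord names.
file_pattern = os.path.join(dataset_dir, 'voc_2007_%s_*.tfrecord' % split_name)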

3. Define the dataset's data provider

This is usually done with slim.dataset_data_provider.DatasetDataProvider:

provider = slim.dataset_data_provider.DatasetDataProvider(
    dataset,
    num_readers=FLAGS.num_readers,
    common_queue_capacity=20 * FLAGS.batch_size,
    common_queue_min=10 * FLAGS.batch_size,
    shuffle=True)

1) The first thing this class calls is _, data = parallel_reader.parallel_read(...). This function uses tf.train.string_input_producer() to build a queue of TFRecord file names (filename_queue), then creates a common queue (a RandomShuffleQueue if shuffle is set, otherwise a FIFOQueue). It initializes a ParallelReader with reader_class, common_queue, num_readers, and reader_kwargs, and calls its read(filename_queue) method. read() has each reader pull records from filename_queue and enqueue them into the common queue, then dequeues from the common queue, producing (filename, data) pairs, where data is the serialized example.

# In DatasetDataProvider.__init__ (dataset_data_provider.py):
_, data = parallel_reader.parallel_read(
    dataset.data_sources,
    reader_class=dataset.reader,
    num_epochs=num_epochs,
    num_readers=num_readers,
    reader_kwargs=reader_kwargs,
    shuffle=shuffle,
    capacity=common_queue_capacity,
    min_after_dequeue=common_queue_min,
    seed=seed,
    scope=scope)

# Inside parallel_reader.parallel_read (parallel_reader.py):
data_files = get_data_files(data_sources)
with ops.name_scope(scope, 'parallel_read'):
  filename_queue = tf_input.string_input_producer(
      data_files, num_epochs=num_epochs, shuffle=shuffle, seed=seed,
      name='filenames')
  dtypes = dtypes or [tf_dtypes.string, tf_dtypes.string]
  if shuffle:
    common_queue = data_flow_ops.RandomShuffleQueue(
        capacity=capacity,
        min_after_dequeue=min_after_dequeue,
        dtypes=dtypes,
        seed=seed,
        name='common_queue')
  else:
    common_queue = data_flow_ops.FIFOQueue(
        capacity=capacity, dtypes=dtypes, name='common_queue')

  # Track how full the common queue is in TensorBoard.
  summary.scalar('fraction_of_%d_full' % capacity,
                 math_ops.to_float(common_queue.size()) * (1. / capacity))

  return ParallelReader(
      reader_class,
      common_queue,
      num_readers=num_readers,
      reader_kwargs=reader_kwargs).read(filename_queue)
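The read(filename_queue) call on the last line is where the enqueue/dequeue described in 1) happens. Paraphrased from ParallelReader (a sketch from memory, not the verbatim source):

def read(self, queue, name=None):
  # Each internal reader pulls (key, value) records via the filename
  # queue and enqueues them into the common queue; a QueueRunner keeps
  # these enqueue ops running, and the caller receives the next dequeue.
  enqueue_ops = []
  for reader in self._readers:
    enqueue_ops.append(self._common_queue.enqueue(reader.read(queue)))
  queue_runner.add_queue_runner(
      queue_runner.QueueRunner(self._common_queue, enqueue_ops))
  return self._common_queue.dequeue(name=name)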


2) Next, it calls items = dataset.decoder.list_items() to obtain the list of keys of the decoder's items_to_handlers dictionary.

items = dataset.decoder.list_items()
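For the items_to_handlers dictionary defined in step 1, this returns the handler keys, i.e. (ordering may vary):

# items == ['image', 'shape', 'object/bbox', 'object/label',
#           'image/filename', 'object/difficult', 'object/truncated']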

3) Given the data from 1) and the items from 2), it calls tensors = dataset.decoder.decode(data, items). During decoding, example = parsing_ops.parse_single_example(data, keys_to_features) is called first to parse the serialized record into a dictionary of features. Then, for each item, the entries of example that belong to that item's handler are gathered and passed to the handler (because a handler is initialized with one or more keys from keys_to_features, one handler can correspond to several entries in example). Each handler returns a single tensor, and these tensors are collected into the list tensors.

tensors = dataset.decoder.decode(data, items)

def decode(self, serialized_example, items=None):
    """Decodes the given serialized TF-example.

    Args:
      serialized_example: a serialized TF-example tensor.
      items: the list of items to decode. These must be a subset of the item
        keys in self._items_to_handlers. If `items` is left as None, then all
        of the items in self._items_to_handlers are decoded.

    Returns:
      the decoded items, a list of tensors.
    """
    example = parsing_ops.parse_single_example(serialized_example,
                                               self._keys_to_features)

    # Reshape non-sparse elements just once, adding the reshape ops in
    # deterministic order.
    for k in sorted(self._keys_to_features):
      v = self._keys_to_features[k]
      if isinstance(v, parsing_ops.FixedLenFeature):
        example[k] = array_ops.reshape(example[k], v.shape)

    if not items:
      items = self._items_to_handlers.keys()

    outputs = []
    for item in items:
      handler = self._items_to_handlers[item]
      keys_to_tensors = {key: example[key] for key in handler.keys}
      outputs.append(handler.tensors_to_item(keys_to_tensors))
    return outputs



def parse_single_example(serialized, features, name=None, example_names=None):
  if not features:
    raise ValueError("Missing features.")
  if example_names is None:
    return parse_single_example_v2(serialized, features, name)
  features = _prepend_none_dimension(features)
  (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
   dense_shapes) = _features_to_raw_params(
       features,
       [VarLenFeature, FixedLenFeature, FixedLenSequenceFeature, SparseFeature])
  outputs = _parse_single_example_raw(
      serialized, example_names, sparse_keys, sparse_types, dense_keys,
      dense_types, dense_defaults, dense_shapes, name)
  return _construct_sparse_tensors_for_sparse_features(features, outputs)
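The point above, that one handler can own several entries of example, is easiest to see with the BoundingBox handler from step 1: it is built from four coordinate names plus a prefix, so handler.keys expands to the four full feature keys, and tensors_to_item assembles the four sparse coordinate vectors into one [num_boxes, 4] tensor. A rough functional equivalent (a sketch, not the library's code):

import tensorflow as tf

def decode_bboxes(example):
  # Densify each sparse coordinate vector and place them side by side,
  # mirroring BoundingBox(['ymin', 'xmin', 'ymax', 'xmax'],
  # 'image/object/bbox/').
  sides = []
  for coord in ['ymin', 'xmin', 'ymax', 'xmax']:
    sides.append(
        tf.sparse_tensor_to_dense(example['image/object/bbox/' + coord]))
  return tf.stack(sides, axis=1)  # shape [num_boxes, 4]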


4) Finally, the items obtained in 2) and the tensors obtained in 3) are matched up to form a dictionary, items_to_tensors, as sketched below.
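In DatasetDataProvider this matching is essentially one line handed to the DataProvider base class:

super(DatasetDataProvider, self).__init__(
    items_to_tensors=dict(zip(items, tensors)),
    num_samples=dataset.num_samples)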

4. Call the provider's get method to fetch the tensors corresponding to the requested items from items_to_tensors

# Fetch the image, its shape, the ground-truth labels and boxes, and the
# file name from the TFRecord data.
[image, shape, glabels, gbboxes, xml_pic_name] = provider.get(
    ['image', 'shape', 'object/label', 'object/bbox', 'image/filename'])
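get itself is then just a lookup into that dictionary; in slim's DataProvider base class it amounts to (paraphrased):

def get(self, items):
  self._validate_items(items)
  return [self._items_to_tensors[item] for item in items]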

 
