Creating TFRecords with TensorFlow: the VOC Dataset as an Example

TFRecord is TensorFlow's own file format for serialized training data; it makes the data fast to read and easy to move around during training.
This post walks through creating and parsing one, using the VOC dataset as the example.

1. Generating the TFRecord

The first step is to pack each image and its annotations into a tf.train.Example proto.

The implementation code is:

# Read the raw JPEG bytes of the image at full_path
with tf.io.gfile.GFile(full_path, 'rb') as fid:
    encoded_jpg = fid.read()
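The key, height, and width values used when building the Example below are not defined in this snippet; a minimal sketch of how they can be derived from the encoded bytes (following the approach of the TF Object Detection API's create_pascal_tf_record.py; the variable names are chosen to match the snippet below):

import hashlib
import io
from PIL import Image

# SHA-256 of the raw JPEG bytes, stored later under 'image/key/sha256'
key = hashlib.sha256(encoded_jpg).hexdigest()

# Open the image only to read its dimensions
image = Image.open(io.BytesIO(encoded_jpg))
width, height = image.size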
      
# Helpers that wrap Python values into tf.train.Feature protos
def int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def int64_list_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def bytes_list_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
def float_list_feature(value):
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))
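These are the same helpers that the snippet below calls through dataset_util (the dataset_util module of the TF Object Detection API provides identical functions). Each one simply wraps a Python value into a tf.train.Feature proto, for example:

print(int64_feature(375))
# int64_list {
#   value: 375
# }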

# Inside dict_to_tf_example: assemble one tf.train.Example per image
example = tf.train.Example(features=tf.train.Features(feature={
      'image/height': dataset_util.int64_feature(height),
      'image/width': dataset_util.int64_feature(width),
      'image/filename': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/source_id': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
      'image/encoded': dataset_util.bytes_feature(encoded_jpg),
      'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
      'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
      'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
      'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
      'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
      'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
      'image/object/class/label': dataset_util.int64_list_feature(classes),
      'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
      'image/object/truncated': dataset_util.int64_list_feature(truncated),
      'image/object/view': dataset_util.bytes_list_feature(poses),
  }))
return example
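The lists xmin, xmax, ymin, ymax, classes_text, classes, difficult_obj, truncated, and poses used above are collected from the VOC XML annotation of the image. A hedged sketch of that loop, assuming data is the parsed annotation dict and VOC_NAME_LABEL maps class names to integer ids:

xmin, ymin, xmax, ymax = [], [], [], []
classes, classes_text = [], []
difficult_obj, truncated, poses = [], [], []

for obj in data.get('object', []):
    difficult = bool(int(obj['difficult']))
    if ignore_difficult_instances and difficult:
        continue
    difficult_obj.append(int(difficult))
    # Box coordinates are normalized to [0, 1] by the image size
    xmin.append(float(obj['bndbox']['xmin']) / width)
    ymin.append(float(obj['bndbox']['ymin']) / height)
    xmax.append(float(obj['bndbox']['xmax']) / width)
    ymax.append(float(obj['bndbox']['ymax']) / height)
    classes_text.append(obj['name'].encode('utf8'))
    classes.append(VOC_NAME_LABEL[obj['name']])
    truncated.append(int(obj['truncated']))
    poses.append(obj['pose'].encode('utf8'))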

# Write the serialized Examples into a TFRecord file
writer = tf.io.TFRecordWriter(output_path)
for idx, example in enumerate(examples_list):
    # For each example id, the VOC XML annotation is parsed into `data`
    # (see the sketch below), then converted and serialized
    tf_example = dict_to_tf_example(data, data_dir, VOC_NAME_LABEL,
                                    ignore_difficult_instances)
    writer.write(tf_example.SerializeToString())

writer.close()
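In the loop above, examples_list holds the image ids from the VOC split file and data is the parsed XML annotation. A sketch of how they can be produced (paths assume the standard VOCdevkit layout; recursive_parse_xml_to_dict is the helper from the TF Object Detection API's dataset_util):

import os
from lxml import etree

# Image ids for the split, e.g. VOC2007/ImageSets/Main/train.txt
examples_path = os.path.join(data_dir, 'ImageSets', 'Main', 'train.txt')
with tf.io.gfile.GFile(examples_path) as fid:
    examples_list = [line.strip() for line in fid.readlines()]

# Per image (inside the loop): parse the XML annotation into a nested dict
annotation_path = os.path.join(data_dir, 'Annotations', example + '.xml')
with tf.io.gfile.GFile(annotation_path, 'r') as fid:
    xml = etree.fromstring(fid.read())
data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']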

2. Parsing the TFRecord

Parsing boils down to defining the feature map IMAGE_FEATURE_MAP and the parsing function parse_example, which reassembles the parsed fields into the final training inputs.

Note how the two feature types map to the parsed values (the snippet after the feature map below reproduces these outputs):
tf.io.FixedLenFeature([], tf.int64) ==> tf.Tensor(375, shape=(), dtype=int64)

tf.io.VarLenFeature(tf.int64) ==> SparseTensor(indices=tf.Tensor([[0]], shape=(1, 1), dtype=int64), values=tf.Tensor([12], shape=(1,), dtype=int64), dense_shape=tf.Tensor([1], shape=(1,), dtype=int64))

# Feature map describing how each serialized field is parsed
IMAGE_FEATURE_MAP = {
    'image/height': tf.io.FixedLenFeature([], tf.int64),
    'image/width': tf.io.FixedLenFeature([], tf.int64),
    'image/filename': tf.io.FixedLenFeature([], tf.string),
    'image/source_id': tf.io.FixedLenFeature([], tf.string),
    'image/key/sha256': tf.io.FixedLenFeature([], tf.string),
    'image/encoded': tf.io.FixedLenFeature([], tf.string),
    'image/format': tf.io.FixedLenFeature([], tf.string),
    'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32), # per-object fields have a different length in each example, so parse them with VarLenFeature
    'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
    'image/object/class/text': tf.io.VarLenFeature(tf.string),
    'image/object/class/label': tf.io.VarLenFeature(tf.int64),
    'image/object/difficult': tf.io.VarLenFeature(tf.int64),
    'image/object/truncated': tf.io.VarLenFeature(tf.int64),
    'image/object/view': tf.io.VarLenFeature(tf.string),
}
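The difference between FixedLenFeature and VarLenFeature outputs noted earlier can be reproduced by parsing one raw record with this map and printing individual fields:

raw_dataset = tf.data.TFRecordDataset(['/data/data/VOC2007/train.tfrecord'])
for raw_record in raw_dataset.take(1):
    parsed = tf.io.parse_single_example(raw_record, IMAGE_FEATURE_MAP)
    print(parsed['image/height'])              # FixedLenFeature -> dense scalar tf.Tensor
    print(parsed['image/object/class/label'])  # VarLenFeature   -> tf.sparse.SparseTensor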

def parse_example(serialized_example,height=512,width=512):
  # Parse the serialized Example with the feature map
  x = tf.io.parse_single_example(serialized_example, IMAGE_FEATURE_MAP)
  # Individual fields can now be looked up by key
  x_train = tf.image.decode_jpeg(x['image/encoded'], channels=3)
  x_train = tf.image.resize(x_train, (height,width))
#  class_text = x['image/object/class/text'] # raw type is SparseTensor, https://blog.csdn.net/JsonD/article/details/73105490
#  class_text = tf.sparse.to_dense(x['image/object/class/text'], default_value='')
  labels = tf.cast(tf.sparse.to_dense(x['image/object/class/label']), tf.float32)
  y_train = tf.stack([tf.sparse.to_dense(x['image/object/bbox/xmin']), # shape: [m]
                      tf.sparse.to_dense(x['image/object/bbox/ymin']), # shape: [m]
                      tf.sparse.to_dense(x['image/object/bbox/xmax']), # shape: [m]
                      tf.sparse.to_dense(x['image/object/bbox/ymax']), # shape: [m]
                      labels  # shape: [m]
                      ], axis=1) # shape: [m, 5]; m is the number of objects in the image and varies per image

  # Each image is assumed to contain at most 100 objects
  paddings = [[0, 100 - tf.shape(y_train)[0]], [0, 0]] # pad (100 - m) zero rows after dim 0, nothing on dim 1
  # The padded size of each dimension D of the output is:
  # paddings[D, 0] + tensor.dim_size(D) + paddings[D, 1]
  y_train = tf.pad(y_train, paddings)
  return x_train, y_train


def _parse_function(example_proto):
    # Parse the input `tf.Example` proto using the dictionary above.
    return tf.io.parse_single_example(example_proto, IMAGE_FEATURE_MAP)


if __name__ == '__main__':
    dataset = tf.data.TFRecordDataset(filenames=['/data/data/VOC2007/train.tfrecord'])
    print(dataset)
    # raw_example = next(iter(dataset))
    # parsed = tf.train.Example.FromString(raw_example.numpy())
    # print(parsed)

    # for index ,record in enumerate(dataset):
    #     example = tf.io.parse_single_example(record,features=IMAGE_FEATURE_MAP)
    #     for key,value in example.items():
    #         print(key,'=>',value)

    # parsed_dataset = dataset.map(_parse_function)
    parsed_dataset = dataset.map(parse_example)  # map applies parse_example to every serialized record

    for parsed_record in parsed_dataset.take(10):
        print(repr(parsed_record))
        print('=========')
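Because parse_example pads every label tensor to a fixed [100, 5] shape, the parsed dataset can be batched directly (note that tf.pad would raise an error if an image ever contained more than 100 objects). A typical follow-up pipeline, continuing from parsed_dataset above (batch size and shuffle buffer are arbitrary choices):

train_ds = (parsed_dataset
            .shuffle(buffer_size=512)
            .batch(8)
            .prefetch(tf.data.experimental.AUTOTUNE))

for images, boxes in train_ds.take(1):
    print(images.shape)  # (8, 512, 512, 3)
    print(boxes.shape)   # (8, 100, 5); rows beyond the real objects are all zeros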

This article mainly draws on yinghuang/yolov2-tensorflow2.
