TFRecord 是 TensorFlow 专用的二进制数据存储格式,方便在训练时快速读取和迁移数据。
下面就基于VOC数据集,介绍一下TFRecord文件的制作与读取。
首先就是封装数据集,其具体方法如下:
# Read the raw (still JPEG-encoded) bytes of the image stored at `full_path`.
with tf.io.gfile.GFile(full_path, 'rb') as image_file:
    encoded_jpg = image_file.read()
def int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` backed by an ``Int64List``."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def int64_list_feature(value):
    """Wrap a sequence of integers in a ``tf.train.Feature`` backed by an ``Int64List``."""
    int64_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_list)
def bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` backed by a ``BytesList``."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def bytes_list_feature(value):
    """Wrap a sequence of bytes values in a ``tf.train.Feature`` backed by a ``BytesList``."""
    bytes_list = tf.train.BytesList(value=value)
    return tf.train.Feature(bytes_list=bytes_list)
def float_list_feature(value):
    """Wrap a sequence of floats in a ``tf.train.Feature`` backed by a ``FloatList``."""
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)
# 'image/filename' and 'image/source_id' both store the same UTF-8 encoded
# file name, so encode it once and reuse the bytes.
filename_bytes = data['filename'].encode('utf8')

# Mapping from feature key to tf.train.Feature for one annotated image.
feature_map = {
    # Image-level fields: one value each.
    'image/height': dataset_util.int64_feature(height),
    'image/width': dataset_util.int64_feature(width),
    'image/filename': dataset_util.bytes_feature(filename_bytes),
    'image/source_id': dataset_util.bytes_feature(filename_bytes),
    'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
    'image/encoded': dataset_util.bytes_feature(encoded_jpg),
    'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
    # Object-level fields: stored as lists.
    'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
    'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
    'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
    'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
    'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
    'image/object/class/label': dataset_util.int64_list_feature(classes),
    'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
    'image/object/truncated': dataset_util.int64_list_feature(truncated),
    'image/object/view': dataset_util.bytes_list_feature(poses),
}
example = tf.train.Example(features=tf.train.Features(feature=feature_map))
return example
# Write every example to a TFRecord file.
# The writer is used as a context manager so the output file is closed even
# when building or serializing an example raises.
with tf.io.TFRecordWriter(output_path) as writer:
    for idx, example in enumerate(examples_list):
        # NOTE(review): in this excerpt `data` is not derived from `example`;
        # presumably the annotation for `example` is loaded into `data`
        # before this call -- confirm against the full script.
        tf_example = dict_to_tf_example(data, data_dir, VOC_NAME_LABEL,
                                        ignore_difficult_instances)
        writer.write(tf_example.SerializeToString())
读取TFRecord时,需要先定义好解析字典IMAGE_FEATURE_MAP 和解析方法parse_example,再对解析出来的数据进行组合处理,最终输出结果。
需要注意的是以下2种特征类型与解析结果的对应关系:
tf.io.FixedLenFeature([], tf.int64) ==> tf.Tensor(375, shape=(), dtype=int64)
tf.io.VarLenFeature(tf.int64) ==> SparseTensor(indices=tf.Tensor([[0]], shape=(1, 1), dtype=int64), values=tf.Tensor([12], shape=(1,), dtype=int64), dense_shape=tf.Tensor([1], shape=(1,), dtype=int64))
# Feature spec describing how each key of a serialized example is parsed.
# It is passed as the `features` argument of tf.io.parse_single_example.
IMAGE_FEATURE_MAP = {
    # Fields parsed as scalars (FixedLenFeature with shape []).
    'image/height': tf.io.FixedLenFeature([], tf.int64),
    'image/width': tf.io.FixedLenFeature([], tf.int64),
    'image/filename': tf.io.FixedLenFeature([], tf.string),
    'image/source_id': tf.io.FixedLenFeature([], tf.string),
    'image/key/sha256': tf.io.FixedLenFeature([], tf.string),
    'image/encoded': tf.io.FixedLenFeature([], tf.string),
    'image/format': tf.io.FixedLenFeature([], tf.string),
    # Fields parsed as variable-length lists (VarLenFeature, i.e. SparseTensor).
    'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32), # If the stored list can be longer than 1, the data is variable-length and is parsed with VarLenFeature
    'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
    'image/object/class/text': tf.io.VarLenFeature(tf.string),
    'image/object/class/label': tf.io.VarLenFeature(tf.int64),
    'image/object/difficult': tf.io.VarLenFeature(tf.int64),
    'image/object/truncated': tf.io.VarLenFeature(tf.int64),
    'image/object/view': tf.io.VarLenFeature(tf.string),
}
def parse_example(serialized_example, height=512, width=512, max_boxes=100):
    """Decode one serialized example into an ``(image, boxes)`` training pair.

    Args:
        serialized_example: Serialized ``tf.train.Example`` whose features
            follow ``IMAGE_FEATURE_MAP``.
        height: Target image height passed to ``tf.image.resize``.
        width: Target image width passed to ``tf.image.resize``.
        max_boxes: Number of rows in the returned box tensor. Images with
            fewer objects are zero-padded; images with more objects are
            truncated to the first ``max_boxes`` entries.

    Returns:
        A tuple ``(x_train, y_train)``. ``x_train`` is the JPEG decoded with
        3 channels and resized to ``(height, width)``. ``y_train`` is a
        float32 tensor of shape ``[max_boxes, 5]`` whose rows are
        ``(xmin, ymin, xmax, ymax, label)``.
    """
    # Parse the serialized example into a dict of tensors.
    x = tf.io.parse_single_example(serialized_example, IMAGE_FEATURE_MAP)

    # The parsed values can now be looked up by feature key.
    x_train = tf.image.decode_jpeg(x['image/encoded'], channels=3)
    x_train = tf.image.resize(x_train, (height, width))

    # VarLenFeature fields are SparseTensors and must be densified before
    # stacking. 'image/object/class/text' could be recovered the same way
    # with tf.sparse.to_dense(..., default_value='') if class names are needed.
    labels = tf.cast(tf.sparse.to_dense(x['image/object/class/label']), tf.float32)
    y_train = tf.stack([
        tf.sparse.to_dense(x['image/object/bbox/xmin']),  # shape: [m]
        tf.sparse.to_dense(x['image/object/bbox/ymin']),  # shape: [m]
        tf.sparse.to_dense(x['image/object/bbox/xmax']),  # shape: [m]
        tf.sparse.to_dense(x['image/object/bbox/ymax']),  # shape: [m]
        labels,                                           # shape: [m]
    ], axis=1)  # shape: [m, 5]; m is the object count and varies per image

    # Keep at most max_boxes rows so the padding amount below can never be
    # negative (tf.pad rejects negative paddings).
    y_train = y_train[:max_boxes]

    # Pad only at the bottom of the first dimension. The padded size of each
    # dimension D is paddings[D, 0] + tensor.dim_size(D) + paddings[D, 1].
    paddings = [[0, max_boxes - tf.shape(y_train)[0]], [0, 0]]
    y_train = tf.pad(y_train, paddings)
    return x_train, y_train
def _parse_function(example_proto):
    """Parse a serialized ``tf.Example`` proto using ``IMAGE_FEATURE_MAP``.

    Returns the parsed features unchanged, without any decoding or padding.
    """
    parsed_features = tf.io.parse_single_example(
        example_proto, features=IMAGE_FEATURE_MAP)
    return parsed_features
if __name__ == '__main__':
    record_files = ['/data/data/VOC2007/train.tfrecord']
    dataset = tf.data.TFRecordDataset(filenames=record_files)
    print(dataset)

    # Other ways to inspect the file while debugging:
    #   * tf.train.Example.FromString(next(iter(dataset)).numpy()) prints one
    #     raw record;
    #   * tf.io.parse_single_example(record, features=IMAGE_FEATURE_MAP) inside
    #     a loop over `dataset` prints every parsed key/value pair;
    #   * dataset.map(_parse_function) yields the parsed feature dicts as-is.

    # map() applies parse_example to every serialized example.
    parsed_dataset = dataset.map(parse_example)
    preview = parsed_dataset.take(10)
    for sample in preview:
        print(repr(sample))
        print('=========')
本文主要参考了yinghuang/yolov2-tensorflow2