最近在微调之前的算法模型,尴尬的是之前准备的原始数据丢失了,只保存了生成的 tfrecords,所以需要解析它来找回原始数据。在网上找了不少解析 tfrecord 的实例,但总是会报各种不同的错误,于是这里直接自己写了一个。
第一步:# 获取TFRecord文件的特征属性以及行数
import tensorflow as tf
def getTFRecordFormat(files):
    """Print the name, type, length and values of each feature in the first record.

    Only the first serialized ``tf.train.Example`` read from ``files`` is
    inspected; that is enough to learn which keys and value types are needed
    to parse the file.

    Args:
        files: Value forwarded unchanged to ``tf.data.TFRecordDataset``.
    """
    with tf.Session() as sess:
        # Load the TFRecord data.
        ds = tf.data.TFRecordDataset(files)
        ds = ds.batch(1)
        ds = ds.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
        iterator = ds.make_one_shot_iterator()
        # To keep this fast, fetch a single batch just to look at the structure.
        batch_data = iterator.get_next()
        res = sess.run(batch_data)

    # The batch size is 1, so the first entry is the only serialized example.
    example_proto = tf.train.Example.FromString(res[0])
    features = example_proto.features
    for key in features.feature:
        feature = features.feature[key]
        if len(feature.bytes_list.value) > 0:
            ftype = 'bytes_list'
            fvalue = feature.bytes_list.value
        elif len(feature.float_list.value) > 0:
            ftype = 'float_list'
            fvalue = feature.float_list.value
        elif len(feature.int64_list.value) > 0:
            ftype = 'int64_list'
            fvalue = feature.int64_list.value
        else:
            # A feature with no values: report it explicitly instead of
            # reusing the previous key's type/values (or raising NameError).
            ftype = 'empty'
            fvalue = []
        result = '{0} : {1} {2} {3}'.format(key, ftype, len(fvalue), fvalue)
        print(result)
    print("*" * 20)
# getTFRecordFormat('./train.tfrecords')
第二步:根据获得的属性名来解析tfrecords
# # 读取tfrecord
import tensorflow as tf
import cv2
import numpy as np
dataset = tf.data.TFRecordDataset('./train.tfrecords')
# dataset = tf.data.TFRecordDataset('/home/s2/shared_dir/PROJECTS/OD_200716_face_s2/data/face_valid.tfrecords')

# A feature is a key/value pair: the key is a string and the value is one of
# bytes_list (tf.string), float_list (tf.float32) or int64_list (tf.int64).
#   bytes_list: can hold string and byte data.
#   float_list: can hold float (float32) and double (float64) data.
#   int64_list: can hold bool, enum, int32, uint32, int64 and uint64 data.
# Every dtype below must therefore be tf.string, tf.float32 or tf.int64.

# The encoded image is the only entry declared with shape [1].
features = {
    "image/encoded": tf.FixedLenFeature([1], tf.string),
}
# All remaining entries are scalars (shape ()), added in their original order.
features.update(
    (name, tf.FixedLenFeature((), dtype))
    for name, dtype in (
        ('image/filename', tf.string),
        ('image/height', tf.int64),
        ('image/width', tf.int64),
        ('image/source_id', tf.string),
        ('image/object/bbox/ymin', tf.float32),
        ('image/object/bbox/ymax', tf.float32),
        ('image/object/bbox/xmin', tf.float32),
        ('image/object/bbox/xmax', tf.float32),
        ('image/object/class/text', tf.string),
        ('image/format', tf.string),
        ('image/object/class/label', tf.int64),
    )
)
def _parse_image_function(example_proto):
    """Parse one serialized example and decode its image.

    Args:
        example_proto: Serialized ``tf.train.Example`` (a string tensor), as
            supplied by ``Dataset.map``.

    Returns:
        The dict produced by ``tf.parse_single_example`` for the module-level
        ``features`` spec, with ``'image/encoded'`` replaced by the image
        decoded to 3 channels. All other entries are returned as parsed.
    """
    # tf.parse_single_example takes a string tensor and returns a dict.
    data = tf.parse_single_example(example_proto, features)
    # 'image/encoded' is declared with shape [1]; index 0 is the encoded
    # image bytes, which tf.image.decode_image turns into an image tensor.
    data['image/encoded'] = tf.image.decode_image(data['image/encoded'][0], channels=3)
    return data
import os

# Output directories for the recovered images and their annotation files.
images = "./images/"
txt = "./txt/"
# cv2.imwrite only returns False when the target directory is missing, so
# make sure both directories exist before writing anything.
os.makedirs(images, exist_ok=True)
os.makedirs(txt, exist_ok=True)

dataset = dataset.map(_parse_image_function)
dataset = dataset.batch(1)
# A one-shot iterator supports a single pass over the dataset and needs no
# explicit initialisation.
iterator = dataset.make_one_shot_iterator()
data = iterator.get_next()

with tf.Session() as sess:
    # No tf.Variable is created by this script, so no variable initialiser is
    # needed (and none is added to the graph on every iteration).
    while True:
        try:
            imageInfo = sess.run(data)
        except tf.errors.OutOfRangeError:
            # The dataset is exhausted: every record has been processed.
            break
        except tf.errors.InvalidArgumentError as err:
            # Best effort: a record that cannot be parsed/decoded is skipped.
            print("skipping record that could not be parsed:", err)
            continue

        try:
            # Every value carries a leading batch dimension of size 1.
            filename = str(imageInfo["image/filename"][0], 'UTF-8')
            print("filename:", filename)
            text = str(imageInfo["image/object/class/text"][0], 'UTF-8')
            height = imageInfo["image/height"]
            width = imageInfo["image/width"]
            ymin = imageInfo["image/object/bbox/ymin"]
            ymax = imageInfo["image/object/bbox/ymax"]
            xmin = imageInfo["image/object/bbox/xmin"]
            xmax = imageInfo["image/object/bbox/xmax"]
            label = imageInfo["image/object/class/label"]

            baseName = filename.split("/")[-1]
            imgPath = images + baseName
            # Swap the channel order before handing the array to OpenCV
            # (BGR2RGB and RGB2BGR are the same channel swap).
            imgs = cv2.cvtColor(imageInfo["image/encoded"][0], cv2.COLOR_RGB2BGR)
            if not cv2.imwrite(imgPath, imgs):
                print("failed to write image:", imgPath)

            # Replace the extension, whatever its length, with ".txt".
            xmlPath = txt + os.path.splitext(baseName)[0] + ".txt"
            # NOTE(review): the box is scaled by height/width, so the stored
            # coordinates are presumably normalised to [0, 1] - confirm.
            fields = [
                baseName,
                str(label[0]),
                text,
                str(height[0]),
                str(width[0]),
                str(int(ymin[0] * height[0])),
                str(int(ymax[0] * height[0])),
                str(int(xmin[0] * width[0])),
                str(int(xmax[0] * width[0])),
                str(len(ymax)),
            ]
            # Append, so several records naming the same file share one txt.
            with open(xmlPath, "a+", encoding="utf-8") as f:
                f.write(" ".join(fields) + "\n")
        except (UnicodeDecodeError, OSError, cv2.error) as err:
            # Best effort: report the failing record and carry on.
            print("skipping record:", err)
参考:https://blog.csdn.net/weixin_41558411/article/details/123456957