Parsing tfrecords

I've recently been fine-tuning an earlier model, and awkwardly the original data prepared for it has been lost; only the generated tfrecords were kept, so I need to parse them to recover the original data. I tried quite a few tfrecord-parsing examples found online, but they kept throwing all kinds of errors, so I ended up writing my own.
Step 1: get the TFRecord file's feature schema (keys and types) and its number of records

import tensorflow as tf

def getTFRecordFormat(files):
    with tf.Session() as sess:
        # Load the TFRecord data
        ds = tf.data.TFRecordDataset(files)
        ds = ds.batch(1)
        ds = ds.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
        iterator = ds.make_one_shot_iterator()
        # To keep this fast, just grab a single batch and look at its structure
        batch_data = iterator.get_next()
        res = sess.run(batch_data)
        for serialized_example in res:
            example_proto = tf.train.Example.FromString(serialized_example)
            features = example_proto.features
            for key in features.feature:
                feature = features.feature[key]
                if len(feature.bytes_list.value) > 0:
                    ftype = 'bytes_list'
                    fvalue = feature.bytes_list.value
                elif len(feature.float_list.value) > 0:
                    ftype = 'float_list'
                    fvalue = feature.float_list.value
                elif len(feature.int64_list.value) > 0:
                    ftype = 'int64_list'
                    fvalue = feature.int64_list.value
                else:
                    ftype, fvalue = 'empty', []
                result = '{0} : {1} {2} {3}'.format(key, ftype, len(fvalue), fvalue)
                print(result)
            print("*" * 20)

# getTFRecordFormat('./train.tfrecords')
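
The helper above only prints the schema of the first record. Since the step title also mentions the record count, here is a minimal sketch for counting records with the TF 1.x tf.python_io.tf_record_iterator (countTFRecords is just an illustrative name, not part of the original script):

def countTFRecords(path):
    # Iterate over the serialized records once and count them
    return sum(1 for _ in tf.python_io.tf_record_iterator(path))

# print(countTFRecords('./train.tfrecords'))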

Step 2: parse the tfrecords using the feature keys obtained above

# Read the tfrecord
import tensorflow as tf
import cv2
import numpy as np
dataset = tf.data.TFRecordDataset('./train.tfrecords')
# dataset = tf.data.TFRecordDataset('/home/s2/shared_dir/PROJECTS/OD_200716_face_s2/data/face_valid.tfrecords')
# A feature is a key-value pair: the key is a string, and the value is one of three types:
# bytes_list (tf.string), float_list (tf.float32), int64_list (tf.int64)
# bytes_list: can hold string and byte data
# float_list: can hold float (float32) and double (float64) data
# int64_list: can hold bool, enum, int32, uint32, int64 and uint64 data
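# For reference, each record in such a file was originally built on the writer side
# roughly like this (a sketch only; the exact script that produced this tfrecord is
# unknown, and jpeg_bytes, h, w, y0 are placeholder names):
#
#   example = tf.train.Example(features=tf.train.Features(feature={
#       'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[jpeg_bytes])),
#       'image/height': tf.train.Feature(int64_list=tf.train.Int64List(value=[h])),
#       'image/width': tf.train.Feature(int64_list=tf.train.Int64List(value=[w])),
#       'image/object/bbox/ymin': tf.train.Feature(float_list=tf.train.FloatList(value=[y0])),
#   }))
#   writer.write(example.SerializeToString())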

features = {
        # value type must be one of tf.string, tf.float32, tf.int64
        # (each image/object/* field is declared as a scalar here; see the note after the script
        #  for the variable-length, multi-object case)
        "image/encoded": tf.FixedLenFeature([1], tf.string),
        'image/filename': tf.FixedLenFeature((), tf.string),
        'image/height': tf.FixedLenFeature((), tf.int64),
        'image/width': tf.FixedLenFeature((), tf.int64),
        'image/source_id': tf.FixedLenFeature((), tf.string),
        'image/object/bbox/ymin': tf.FixedLenFeature((), tf.float32),
        'image/object/bbox/ymax': tf.FixedLenFeature((), tf.float32),
        'image/object/bbox/xmin': tf.FixedLenFeature((), tf.float32),
        'image/object/bbox/xmax': tf.FixedLenFeature((), tf.float32),
        'image/object/class/text': tf.FixedLenFeature((), tf.string),
        'image/format': tf.FixedLenFeature((), tf.string),
        'image/object/class/label': tf.FixedLenFeature((), tf.int64)
    }

def _parse_image_function(example_proto):
    # tf.parse_single_example (tf.io.parse_single_example) takes a scalar string tensor
    # and returns a dict of tensors keyed by the feature names above
    data = tf.parse_single_example(example_proto, features)
    # The image is stored as encoded bytes (JPEG/PNG), so decode it back into an image tensor;
    # the remaining fields are already plain tensors after parse_single_example
    data['image/encoded'] = tf.image.decode_image(data['image/encoded'][0], channels=3)
    return data


import os

# Output locations for the recovered images and label txt files
images = "./images/"
txt = "./txt/"
os.makedirs(images, exist_ok=True)
os.makedirs(txt, exist_ok=True)
dataset = dataset.map(_parse_image_function)
dataset = dataset.batch(1)
# A one-shot iterator can only be iterated over once and needs no explicit initialization
iterator = dataset.make_one_shot_iterator()
data = iterator.get_next()
with tf.Session() as sess:
    # With tf.Variable / tf.get_variable in the graph, variables must be initialized before use;
    # with no variables (as here) the initializer is effectively a no-op, so running it once
    # outside the loop is enough.
    sess.run(tf.global_variables_initializer())
    # 65623 is the number of records in this particular tfrecord; adjust it for your file
    for i in range(65623):
        try:
            imageInfo = sess.run(data)
            img = imageInfo["image/encoded"]
            source_id = imageInfo["image/source_id"]
            filename = imageInfo["image/filename"]
            height = imageInfo["image/height"]
            width = imageInfo["image/width"]
            ymin = imageInfo["image/object/bbox/ymin"]
            ymax = imageInfo["image/object/bbox/ymax"]
            xmin = imageInfo["image/object/bbox/xmin"]
            xmax = imageInfo["image/object/bbox/xmax"]
            text = imageInfo["image/object/class/text"]
            img_format = imageInfo["image/format"]
            label = imageInfo["image/object/class/label"]
            filename = str(filename[0], 'UTF-8')
            print("filename:", filename)
            text = str(text[0], 'UTF-8')
            imgPath = images + filename.split("/")[-1]
            # decode_image returns RGB; OpenCV writes BGR, so convert before saving
            imgs = cv2.cvtColor(img[0], cv2.COLOR_RGB2BGR)
            cv2.imwrite(imgPath, imgs)
            txtPath = txt + filename.split("/")[-1][:-3] + "txt"
            # Write the label file: name, label, class text, image size and pixel box coordinates
            with open(txtPath, "a+") as f:
                f.write(" ".join([
                    filename.split("/")[-1], str(label[0]), text,
                    str(height[0]), str(width[0]),
                    str(int(ymin[0] * height[0])), str(int(ymax[0] * height[0])),
                    str(int(xmin[0] * width[0])), str(int(xmax[0] * width[0])),
                    str(len(ymax))]) + "\n")
        except tf.errors.OutOfRangeError:
            # No more records in the dataset
            break
        except Exception as e:
            # Do not silently swallow errors; at least report which record failed
            print("record {} failed: {}".format(i, e))


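One caveat about the feature dict above: the image/object/* fields are declared as scalar FixedLenFeatures, which only recovers a single box per image. In the usual TF Object Detection API layout these fields hold one value per object, so if a record stores several boxes, a parse function roughly like the sketch below would be needed (this is an assumption about the data, not the author's code; it reuses the same keys and the TF 1.x API, and tf.sparse.to_dense needs TF >= 1.13):

def _parse_multi_object(example_proto):
    multi_features = {
        "image/encoded": tf.FixedLenFeature([1], tf.string),
        "image/height": tf.FixedLenFeature((), tf.int64),
        "image/width": tf.FixedLenFeature((), tf.int64),
        # Variable-length per-object fields: one entry per box
        "image/object/bbox/ymin": tf.VarLenFeature(tf.float32),
        "image/object/bbox/ymax": tf.VarLenFeature(tf.float32),
        "image/object/bbox/xmin": tf.VarLenFeature(tf.float32),
        "image/object/bbox/xmax": tf.VarLenFeature(tf.float32),
        "image/object/class/label": tf.VarLenFeature(tf.int64),
    }
    data = tf.parse_single_example(example_proto, multi_features)
    data["image/encoded"] = tf.image.decode_image(data["image/encoded"][0], channels=3)
    # VarLenFeature yields SparseTensors; convert them to dense 1-D tensors
    for key in ["image/object/bbox/ymin", "image/object/bbox/ymax",
                "image/object/bbox/xmin", "image/object/bbox/xmax",
                "image/object/class/label"]:
        data[key] = tf.sparse.to_dense(data[key])
    return data

The export loop would then write one line per object by iterating over the dense per-object arrays instead of indexing element 0 only.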

Reference: https://blog.csdn.net/weixin_41558411/article/details/123456957
