TFRecord和tf.Example

TFRecord / tf.Example

写TFRecord文件

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

def int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` holding an ``Int64List``."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)

def bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` holding a ``BytesList``."""
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)

# Load MNIST with uint8 pixels and one-hot labels.
mnist = input_data.read_data_sets('./data', dtype=tf.uint8, one_hot=True)
images = mnist.train.images
labels = mnist.train.labels
size = images.shape[1]  # length of each row of `images`, stored as the 'size' feature
num_examples = mnist.train.num_examples

# Path of the output TFRecord file.
filename = './output.tfrecord'

# Use the writer as a context manager so the file is flushed and closed even if
# an exception is raised part-way through the loop.
with tf.io.TFRecordWriter(filename) as writer:
    for image, one_hot_label in zip(images[:num_examples], labels[:num_examples]):
        # Serialize the pixel array to raw bytes; tobytes() produces the same
        # bytes as the deprecated ndarray.tostring().
        image_raw = image.tobytes()
        # Pack one sample into an Example protocol buffer.
        example = tf.train.Example(features=tf.train.Features(feature={
            'size': int64_feature(size),
            # argmax of the one-hot vector recovers the class index.
            'label': int64_feature(int(np.argmax(one_hot_label))),
            'image_raw': bytes_feature(image_raw),
        }))
        # Append the serialized Example to the TFRecord file.
        writer.write(example.SerializeToString())

读取TFRecord文件

import tensorflow as tf

# Reader that pulls serialized examples out of TFRecord files.
reader = tf.TFRecordReader()
# Queue that holds the list of input files.
filename_queue = tf.train.string_input_producer(['./output.tfrecord'])

# Read one serialized example; reader.read_up_to can read several at once.
_, serialized_example = reader.read(filename_queue)

# Parse a single example; use parse_example to parse several at once.
features = tf.parse_single_example(
    serialized_example,
    features={
        'image_raw': tf.FixedLenFeature([], tf.string),
        'size': tf.FixedLenFeature([], tf.int64),
        'label': tf.FixedLenFeature([], tf.int64)})

image = tf.decode_raw(features['image_raw'], tf.uint8)
label = tf.cast(features['label'], tf.int32)
size = tf.cast(features['size'], tf.int32)

# Start the queue-runner threads that feed filename_queue.
coord = tf.train.Coordinator()
with tf.Session() as sess:
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        # Each run yields one example; once every example has been read, this
        # example starts again from the beginning of the file.
        for _ in range(10):
            print(sess.run([image, label, size]))
    finally:
        # Stop the queue-runner threads and wait for them before the session
        # closes, so they are not left running against a closed session.
        coord.request_stop()
        coord.join(threads)

tf.Example

import tensorflow as tf
import numpy as np

# Helpers that convert a Python value into a tf.Example-compatible Feature.
def _bytes_feature(value):
    """Wrap a single bytes/str value in a ``tf.train.Feature`` (bytes_list).

    If *value* has the same type that ``tf.constant`` returns (an EagerTensor
    under eager execution), it is converted with ``.numpy()`` first, because
    ``BytesList`` cannot unpack a string from an EagerTensor.
    """
    tensor_type = type(tf.constant(0))
    payload = value.numpy() if isinstance(value, tensor_type) else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))

def _float_feature(value):
    """Wrap a single float32/float64 value in a ``tf.train.Feature`` (float_list)."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
    """Wrap a single bool/enum/(u)int32/(u)int64 value in a ``tf.train.Feature`` (int64_list)."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)

def serialize_example(feature0, feature1, feature2, feature3):
    """Build a ``tf.train.Example`` from four values and return its serialized bytes.

    feature0 and feature1 are stored as int64, feature2 as bytes and
    feature3 as float.
    """
    # Map each feature name to its tf.Example-compatible Feature.
    features = tf.train.Features(feature={
        'feature0': _int64_feature(feature0),
        'feature1': _int64_feature(feature1),
        'feature2': _bytes_feature(feature2),
        'feature3': _float_feature(feature3),
    })
    return tf.train.Example(features=features).SerializeToString()


# Print one Feature of each kind; the printed text is shown below.
for make_feature, raw_value in (
    (_bytes_feature, b'test_string'),
    (_bytes_feature, u'test_bytes'.encode('utf-8')),
    (_float_feature, np.exp(1)),
    (_int64_feature, True),
    (_int64_feature, 1),
):
    print(make_feature(raw_value))
# Output:
'''
bytes_list {
  value: "test_string"
}

bytes_list {
  value: "test_bytes"
}

float_list {
  value: 2.7182817459106445
}

int64_list {
  value: 1
}

int64_list {
  value: 1
}
'''
# Serialize a sample record, then parse the bytes back into an Example.
serialized_example = serialize_example(False, 4, b'goat', 0.9876)
# Map entries are not necessarily serialized in insertion order (feature1 comes
# first in the bytes below), so the exact bytes may differ between runs/versions.
print(serialized_example) # b'\nR\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04[\xd3|?'
example_proto = tf.train.Example()
example_proto.ParseFromString(serialized_example)
print(example_proto)
# Output:
'''
features {
  feature {
    key: "feature0"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "feature1"
    value {
      int64_list {
        value: 4
      }
    }
  }
  feature {
    key: "feature2"
    value {
      bytes_list {
        value: "goat"
      }
    }
  }
  feature {
    key: "feature3"
    value {
      float_list {
        value: 0.9876000285148621
      }
    }
  }
}
'''

本示例中处理的都是单个整数、浮点数或字节串。Int64List/FloatList/BytesList 的 value 参数要求是可迭代对象,所以 value=[value] 用 [] 把单个值包装成列表。如果要存储的数据本身已经是可迭代的(例如列表或一维数组),就不能再加 [],否则会得到嵌套列表而报错。例如要存储 [1.1, 1.2, 1.3] 时,对应的 _float_feature 应该写成:

def _float_feature(value):
    """Wrap an iterable of floats (e.g. [1.1, 1.2, 1.3]) in a float_list Feature."""
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)

完整示例

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Helpers that convert a Python value into a tf.Example-compatible Feature.
def _bytes_feature(value):
    """Wrap a single bytes/str value in a ``tf.train.Feature`` (bytes_list).

    If *value* has the same type that ``tf.constant`` returns (an EagerTensor
    under eager execution), it is converted with ``.numpy()`` first, because
    ``BytesList`` cannot unpack a string from an EagerTensor.
    """
    tensor_type = type(tf.constant(0))
    payload = value.numpy() if isinstance(value, tensor_type) else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))

def _float_feature(value):
    """Wrap a single float32/float64 value in a ``tf.train.Feature`` (float_list)."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
    """Wrap a single bool/enum/(u)int32/(u)int64 value in a ``tf.train.Feature`` (int64_list)."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)

# Create a dictionary with features that may be relevant.
# Newer TensorFlow can read the shape directly from the decoded tensor
# (this presumably requires eager execution so the dimensions are known).
def image_example(img_raw, label):
    """Return a ``tf.train.Example`` describing one JPEG-encoded image.

    Args:
        img_raw: encoded JPEG bytes; stored unchanged under 'image_raw'.
        label: integer label stored under 'label'.

    Returns:
        A ``tf.train.Example`` with 'height', 'width', 'depth', 'label' and
        'image_raw' features.
    """
    img_tensor = tf.image.decode_jpeg(img_raw)
    # as_list() yields plain Python ints. Indexing the TensorShape directly
    # returns Dimension objects under TF 1.x, which Int64List may reject.
    height, width, depth = img_tensor.shape.as_list()
    feature = {
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'depth': _int64_feature(depth),
        'label': _int64_feature(label),
        'image_raw': _bytes_feature(img_raw),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

# Create a dictionary with features that may be relevant.
# Older TensorFlow cannot read the shape from the decoded tensor directly, so the
# image is evaluated to a numpy.ndarray in the given session and its shape is used.
def image_example_sess(img_raw, label, sess):
    """Return a ``tf.train.Example`` for one JPEG image, using *sess* to get its shape.

    Args:
        img_raw: encoded JPEG bytes; stored unchanged under 'image_raw'.
        label: integer label stored under 'label'.
        sess: a ``tf.Session`` whose graph the decode op is added to.

    Returns:
        A ``tf.train.Example`` with 'height', 'width', 'depth', 'label' and
        'image_raw' features.
    """
    img_tensor = tf.image.decode_jpeg(img_raw)
    # NOTE(review): each call adds a new decode op to the graph; fine for a few
    # images, but consider a placeholder-fed decode op for large datasets.
    img_data = sess.run(img_tensor)  # numpy.ndarray, so .shape holds Python ints
    height, width, depth = img_data.shape
    feature = {
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'depth': _int64_feature(depth),
        'label': _int64_feature(label),
        'image_raw': _bytes_feature(img_raw),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

'''
with tf.io.gfile.GFile('test1.jpg', 'rb') as image_file:
    img_raw = image_file.read()
label = 0
with tf.Session() as sess:
    print(image_example_sess(img_raw, label, sess))
'''
#####################################################################################################
# Write the TFRecord file.
# Write the raw image files to `images.tfrecords`.
# First, process the two images into `tf.Example` messages.
# Then, write to a `.tfrecords` file.
image_labels = {'test1.jpg' : 0, 'test2.jpg': 1}
record_file = 'images.tfrecords'
with tf.Session() as sess, tf.io.TFRecordWriter(record_file) as writer:
    for filename, label in image_labels.items():
        # Read the encoded bytes inside a context manager so each file handle is
        # closed promptly (tf.gfile.FastGFile is deprecated in favour of GFile).
        with tf.io.gfile.GFile(filename, 'rb') as image_file:
            img_raw = image_file.read()
        tf_example = image_example_sess(img_raw, label, sess)
        writer.write(tf_example.SerializeToString())

#####################################################################################################
# Read the TFRecord file back.
input_files = ['images.tfrecords']  # several files may be listed here
raw_image_dataset = tf.data.TFRecordDataset(filenames=input_files)

def _parse_image_function(example_proto):
    """Parse one serialized image Example into a dict of tensors.

    height, width, depth and label are scalar int64 features, so their shape
    is []. A label holding several numbers (for example 4 box coordinates for
    detection) must be declared with that length, e.g. [4], matching what was
    written. Otherwise parsing fails with "Can't parse serialized Example."
    """
    schema = {name: tf.io.FixedLenFeature([], tf.int64)
              for name in ('height', 'width', 'depth', 'label')}
    schema['image_raw'] = tf.io.FixedLenFeature([], tf.string)
    return tf.io.parse_single_example(example_proto, schema)

parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
iterator = parsed_image_dataset.make_one_shot_iterator()
feature_dict = iterator.get_next()
# Build the decode op once, outside the loop. Calling tf.io.decode_image inside
# the loop would add a new op to the graph on every iteration.
decoded_image = tf.io.decode_image(feature_dict['image_raw'])
with tf.Session() as sess:
    for _ in range(len(image_labels)):
        # One run fetches the parsed features and the decoded image of the same record.
        feature_dict_val, img = sess.run([feature_dict, decoded_image])
        print('height: ', feature_dict_val['height'])
        print('width: ', feature_dict_val['width'])
        print('depth: ', feature_dict_val['depth'])
        print('label: ', feature_dict_val['label'])
        plt.imshow(img)
        plt.show()

#####################################################################################################
# Read TFRecord files whose paths are supplied at run time through a placeholder.
input_files = tf.placeholder(tf.string)
dataset = tf.data.TFRecordDataset(input_files)
dataset = dataset.map(_parse_image_function)
# An initializable iterator is initialized with the fed file paths.
iterator = dataset.make_initializable_iterator()
feature_dict = iterator.get_next()
# Build the decode op once so the graph does not grow with every record.
decoded_image = tf.io.decode_image(feature_dict['image_raw'])
with tf.Session() as sess:
    sess.run(iterator.initializer,feed_dict={input_files : ['images.tfrecords', 'images.tfrecords']})
    # Iterate one full epoch. OutOfRangeError signals the end, so the number of
    # records need not be known in advance (useful when inputs are chosen at run time).
    while True:
        try:
            feature_dict_val, img = sess.run([feature_dict, decoded_image])
        except tf.errors.OutOfRangeError:
            break
        print('height: ', feature_dict_val['height'])
        print('width: ', feature_dict_val['width'])
        print('depth: ', feature_dict_val['depth'])
        print('label: ', feature_dict_val['label'])
        plt.imshow(img)
        plt.show()

#####################################################################################################
# Read the TFRecord file in shuffled batches and show each batch in a 2x5 grid.
input_files = ['images.tfrecords']  # several files may be listed here
dataset = tf.data.TFRecordDataset(input_files)

dataset = dataset.map(_parse_image_function).shuffle(10).batch(10)
dataset = dataset.repeat(5)
iterator = dataset.make_one_shot_iterator()
feature_dict = iterator.get_next()
# Images in one batch may differ in size, so they are decoded one at a time.
# A single placeholder-fed decode op is built once instead of adding new ops
# to the graph for every image of every batch.
raw_image = tf.placeholder(tf.string)
decoded_image = tf.io.decode_image(raw_image)
with tf.Session() as sess:
    while True:
        try:
            feature_dict_val = sess.run(feature_dict)
        except tf.errors.OutOfRangeError:
            break
        print('height: ', feature_dict_val['height'])
        print('width: ', feature_dict_val['width'])
        print('depth: ', feature_dict_val['depth'])
        print('label: ', feature_dict_val['label'])
        # A batch can hold fewer than 10 images: batch() is applied before
        # repeat(), so a batch never spans epochs, and images.tfrecords holds
        # only as many records as were written. Plot only the images present.
        fig = plt.figure()
        for idx, raw in enumerate(feature_dict_val['image_raw'][:10]):
            ax = fig.add_subplot(2, 5, idx + 1)
            ax.imshow(sess.run(decoded_image, feed_dict={raw_image: raw}))
        plt.show()

你可能感兴趣的:(Tensorflow,tfrecord)