TFRecord 和 tf.Example
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
def int64_feature(value):
    """Wrap a single integer value in a ``tf.train.Feature``.

    Args:
        value: One value to store; it is placed in a one-element list
            because ``Int64List`` takes a sequence.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds ``value``.
    """
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature``.

    Args:
        value: One value to store; it is placed in a one-element list
            because ``BytesList`` takes a sequence.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds ``value``.
    """
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
# Load the MNIST data set as uint8 pixel data with one-hot labels.
mnist = input_data.read_data_sets('./data', dtype=tf.uint8, one_hot=True)
images = mnist.train.images
labels = mnist.train.labels
# Length of the second dimension of `images`; stored with every example.
size = images.shape[1]
num_examples = mnist.train.num_examples
# Path of the TFRecord file to produce.
filename = './output.tfrecord'
# The context manager closes the writer even if building or writing an
# example fails part-way through the loop.
with tf.io.TFRecordWriter(filename) as writer:
    for i in range(num_examples):
        # Serialize the image array to raw bytes. `tobytes()` replaces the
        # deprecated alias `tostring()`, which newer NumPy releases removed.
        image_raw = images[i].tobytes()
        # Pack one sample into an Example protocol buffer.
        example = tf.train.Example(features=tf.train.Features(feature={
            'size': int64_feature(size),
            # Labels are one-hot, so argmax recovers the class index.
            'label': int64_feature(np.argmax(labels[i])),
            'image_raw': bytes_feature(image_raw),
        }))
        # Append the serialized Example to the TFRecord file.
        writer.write(example.SerializeToString())
import tensorflow as tf
# Reader that pulls serialized examples out of TFRecord files.
reader = tf.TFRecordReader()
# Queue that maintains the list of input files.
filename_queue = tf.train.string_input_producer(['./output.tfrecord'])
# Read one example from the file; `read_up_to` can read several at once.
_, serialized_example = reader.read(filename_queue)
# Parse the single example that was read; use `parse_example` to parse
# several examples at once.
features = tf.parse_single_example(
    serialized_example,
    features={
        'image_raw': tf.FixedLenFeature([], tf.string),
        'size': tf.FixedLenFeature([], tf.int64),
        'label': tf.FixedLenFeature([], tf.int64),
    })
image = tf.decode_raw(features['image_raw'], tf.uint8)
label = tf.cast(features['label'], tf.int32)
size = tf.cast(features['size'], tf.int32)
# Coordinator used to start and stop the queue-runner threads.
coord = tf.train.Coordinator()
with tf.Session() as sess:
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        # Each run reads one example from the TFRecord file. Per the original
        # author's note, once every example has been read this program
        # starts again from the beginning of the file.
        for i in range(10):
            print(sess.run([image, label, size]))
    finally:
        # Ask the queue-runner threads to stop and wait for them before the
        # session closes, so no thread keeps using a closed session.
        coord.request_stop()
        coord.join(threads)
import tensorflow as tf
import numpy as np
# The following functions can be used to convert a value to a type compatible with tf.Example.
def _bytes_feature(value):
    """Wrap a single string / bytes value in a ``tf.train.Feature``.

    Args:
        value: The value to store. If it has the same type as the object
            returned by ``tf.constant(0)``, it is first converted with
            ``.numpy()``; the original note says ``BytesList`` won't unpack
            a string from an EagerTensor.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds ``value``.
    """
    tensor_type = type(tf.constant(0))
    if isinstance(value, tensor_type):
        value = value.numpy()
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def _float_feature(value):
    """Wrap a single float (float32) / double (float64) in a ``tf.train.Feature``.

    Args:
        value: One value to store; it is placed in a one-element list
            because ``FloatList`` takes a sequence.

    Returns:
        A ``tf.train.Feature`` whose ``float_list`` holds ``value``.
    """
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)
def _int64_feature(value):
    """Wrap a single bool / enum / int32 / uint32 / int64 / uint64 in a ``tf.train.Feature``.

    Args:
        value: One value to store; it is placed in a one-element list
            because ``Int64List`` takes a sequence.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds ``value``.
    """
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def serialize_example(feature0, feature1, feature2, feature3):
    """Build a ``tf.train.Example`` from four values and serialize it.

    Args:
        feature0: Value stored under key ``'feature0'`` as an int64 feature.
        feature1: Value stored under key ``'feature1'`` as an int64 feature.
        feature2: Value stored under key ``'feature2'`` as a bytes feature.
        feature3: Value stored under key ``'feature3'`` as a float feature.

    Returns:
        The result of ``SerializeToString()`` on the Example, ready to be
        written to a file.
    """
    # Map each feature name to its tf.Example-compatible wrapper.
    features_message = tf.train.Features(feature={
        'feature0': _int64_feature(feature0),
        'feature1': _int64_feature(feature1),
        'feature2': _bytes_feature(feature2),
        'feature3': _float_feature(feature3),
    })
    return tf.train.Example(features=features_message).SerializeToString()
# Print one Feature per helper to show what each wrapper produces.
# Each pair is (helper, sample value); they are converted and printed in order.
for _convert, _sample in (
    (_bytes_feature, b'test_string'),
    (_bytes_feature, 'test_bytes'.encode('utf-8')),
    (_float_feature, np.exp(1)),
    (_int64_feature, True),
    (_int64_feature, 1),
):
    print(_convert(_sample))
# 输出:
'''
bytes_list {
value: "test_string"
}
bytes_list {
value: "test_bytes"
}
float_list {
value: 2.7182817459106445
}
int64_list {
value: 1
}
int64_list {
value: 1
}
'''
# Serialize one sample record built from an int-like bool, an int, bytes and a float.
serialized_example = serialize_example(False, 4, b'goat', 0.9876)
# Sample output recorded by the original author follows in the trailing comment.
print(serialized_example) # b'\nR\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04[\xd3|?'
# Decode the serialized bytes back into a tf.train.Example message and print it.
example_proto = tf.train.Example.FromString(serialized_example)
print(example_proto)
# 输出:
'''
features {
feature {
key: "feature0"
value {
int64_list {
value: 0
}
}
}
feature {
key: "feature1"
value {
int64_list {
value: 4
}
}
}
feature {
key: "feature2"
value {
bytes_list {
value: "goat"
}
}
}
feature {
key: "feature3"
value {
float_list {
value: 0.9876000285148621
}
}
}
}
'''
本示例中处理的是单个整型、浮点类型、字节类型,因此 `value=[value]` 中对 `value` 加了 `[]` 使其具有可迭代性;如果需要存储的数据本身就有可迭代性,则不能再加 `[]`。例如,如果要存储 `[1.1,1.2,1.3]`,则对应的函数 `_float_feature` 应该写成:
def _float_feature(value):
    """Wrap an already-iterable collection of floats in a ``tf.train.Feature``.

    Args:
        value: An iterable of values, passed to ``FloatList`` as-is
            (no extra list is wrapped around it).

    Returns:
        A ``tf.train.Feature`` whose ``float_list`` holds the values.
    """
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# The following functions can be used to convert a value to a type compatible with tf.Example.
def _bytes_feature(value):
    """Wrap a single string / bytes value in a ``tf.train.Feature``.

    Args:
        value: The value to store. If it has the same type as the object
            returned by ``tf.constant(0)``, it is first converted with
            ``.numpy()``; the original note says ``BytesList`` won't unpack
            a string from an EagerTensor.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds ``value``.
    """
    tensor_type = type(tf.constant(0))
    if isinstance(value, tensor_type):
        value = value.numpy()
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def _float_feature(value):
    """Wrap a single float (float32) / double (float64) in a ``tf.train.Feature``.

    Args:
        value: One value to store; it is placed in a one-element list
            because ``FloatList`` takes a sequence.

    Returns:
        A ``tf.train.Feature`` whose ``float_list`` holds ``value``.
    """
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)
def _int64_feature(value):
    """Wrap a single bool / enum / int32 / uint32 / int64 / uint64 in a ``tf.train.Feature``.

    Args:
        value: One value to store; it is placed in a one-element list
            because ``Int64List`` takes a sequence.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds ``value``.
    """
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
# Create a dictionary with features that may be relevant.
# Newer TensorFlow versions allow reading the shape directly after decoding.
def image_example(img_raw, label):
    """Build a ``tf.train.Example`` describing one JPEG image and its label.

    Args:
        img_raw: Encoded JPEG bytes; stored unchanged under ``'image_raw'``
            and decoded only to read the image shape.
        label: Value stored under ``'label'`` as an int64 feature.

    Returns:
        A ``tf.train.Example`` with ``height``, ``width``, ``depth``,
        ``label`` and ``image_raw`` features.
    """
    decoded_shape = tf.image.decode_jpeg(img_raw).shape
    # The first three shape entries are stored as height, width and depth.
    feature = {
        name: _int64_feature(decoded_shape[axis])
        for axis, name in enumerate(('height', 'width', 'depth'))
    }
    feature['label'] = _int64_feature(label)
    feature['image_raw'] = _bytes_feature(img_raw)
    return tf.train.Example(features=tf.train.Features(feature=feature))
# Create a dictionary with features that may be relevant.
# Older TensorFlow versions cannot read the shape directly after decoding,
# so the tensor is evaluated to a numpy.ndarray first and the shape is taken
# from that array.
def image_example_sess(img_raw, label, sess):
    """Build a ``tf.train.Example`` for one JPEG image, using a session to get its shape.

    Args:
        img_raw: Encoded JPEG bytes; stored unchanged under ``'image_raw'``.
        label: Value stored under ``'label'`` as an int64 feature.
        sess: Session made the default while the decoded tensor is evaluated.

    Returns:
        A ``tf.train.Example`` with ``height``, ``width``, ``depth``,
        ``label`` and ``image_raw`` features.
    """
    decoded = tf.image.decode_jpeg(img_raw)
    with sess.as_default():
        img_data = decoded.eval()
    # Shows the Python type of the evaluated image.
    print(type(img_data))
    pixel_shape = img_data.shape
    feature = {
        name: _int64_feature(pixel_shape[axis])
        for axis, name in enumerate(('height', 'width', 'depth'))
    }
    feature['label'] = _int64_feature(label)
    feature['image_raw'] = _bytes_feature(img_raw)
    return tf.train.Example(features=tf.train.Features(feature=feature))
'''
img_raw = tf.gfile.FastGFile('test1.jpg', 'rb').read()
label = 0
with tf.Session() as sess:
print(image_example_sess(img_raw, label, sess))
'''
#####################################################################################################
# Write the TFRecord file.
# Write the raw image files to `images.tfrecords`.
# First, process the two images into `tf.Example` messages.
# Then, write to a `.tfrecords` file.
image_labels = {'test1.jpg': 0, 'test2.jpg': 1}
record_file = 'images.tfrecords'
with tf.Session() as sess:
    with tf.io.TFRecordWriter(record_file) as writer:
        for filename, label in image_labels.items():
            # Read the encoded image inside a `with` block so the file
            # handle is closed as soon as its bytes have been read.
            with tf.gfile.FastGFile(filename, 'rb') as image_file:
                img_raw = image_file.read()
            tf_example = image_example_sess(img_raw, label, sess)
            writer.write(tf_example.SerializeToString())
#####################################################################################################
# Read the TFRecord file.
input_files = ['images.tfrecords']  # several files may be listed
raw_image_dataset = tf.data.TFRecordDataset(input_files)


def _parse_image_function(example_proto):
    """Parse one serialized ``tf.Example`` into a dict of tensors.

    Args:
        example_proto: A serialized Example as produced by the writer.

    Returns:
        The dict returned by ``tf.io.parse_single_example`` with the keys
        ``height``, ``width``, ``depth``, ``label`` and ``image_raw``.
    """
    # height, width, depth and label each hold a single number, so the shape
    # argument is an empty list. Per the original note: if a label held four
    # numbers (e.g. a detection box, matching what was written), the shape
    # would have to be 4, otherwise parsing fails with
    # "Can't parse serialized Example."
    image_feature_description = {
        key: tf.io.FixedLenFeature([], tf.int64)
        for key in ('height', 'width', 'depth', 'label')
    }
    image_feature_description['image_raw'] = tf.io.FixedLenFeature([], tf.string)
    # Parse the input tf.Example proto using the description above.
    return tf.io.parse_single_example(example_proto, image_feature_description)


parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
iterator = parsed_image_dataset.make_one_shot_iterator()
feature_dict = iterator.get_next()
with tf.Session() as sess:
    # Fetch one record per entry in `image_labels`.
    for _ in image_labels:
        feature_dict_val = sess.run(feature_dict)
        print('height: ', feature_dict_val['height'])
        print('width: ', feature_dict_val['width'])
        print('depth: ', feature_dict_val['depth'])
        print('label: ', feature_dict_val['label'])
        # Decode the stored bytes back into pixel data and display them.
        decoded_image = tf.io.decode_image(feature_dict_val['image_raw'])
        img = decoded_image.eval()
        plt.imshow(img)
        plt.show()
#####################################################################################################
# Read the TFRecord file; the file paths are supplied through a placeholder.
input_files = tf.placeholder(tf.string)
dataset = tf.data.TFRecordDataset(input_files).map(_parse_image_function)
# Initializable iterator used to walk the dataset.
iterator = dataset.make_initializable_iterator()
feature_dict = iterator.get_next()
with tf.Session() as sess:
    record_paths = ['images.tfrecords', 'images.tfrecords']
    sess.run(iterator.initializer, feed_dict={input_files: record_paths})
    # Walk all the data for one epoch; OutOfRangeError marks the end. With
    # dynamically supplied inputs the amount of data is not known up front,
    # and this pattern avoids needing the exact count in advance.
    while True:
        try:
            feature_dict_val = sess.run(feature_dict)
        except tf.errors.OutOfRangeError:
            break
        print('height: ', feature_dict_val['height'])
        print('width: ', feature_dict_val['width'])
        print('depth: ', feature_dict_val['depth'])
        print('label: ', feature_dict_val['label'])
        decoded_image = tf.io.decode_image(feature_dict_val['image_raw'])
        img = decoded_image.eval()
        plt.imshow(img)
        plt.show()
#####################################################################################################
input_files = ['images.tfrecords']  # several files may be listed
dataset = tf.data.TFRecordDataset(input_files)
dataset = dataset.map(_parse_image_function).shuffle(10).batch(10)
dataset = dataset.repeat(5)
iterator = dataset.make_one_shot_iterator()
feature_dict = iterator.get_next()
with tf.Session() as sess:
    while True:
        # Fetch first, so no empty figure is created when the data runs out.
        try:
            feature_dict_val = sess.run(feature_dict)
        except tf.errors.OutOfRangeError:
            break
        print('height: ', feature_dict_val['height'])
        print('width: ', feature_dict_val['width'])
        print('depth: ', feature_dict_val['depth'])
        print('label: ', feature_dict_val['label'])
        # A batch holds at most 10 records but may hold fewer (for example
        # when the file was written from only two images), so plot however
        # many images this batch contains instead of indexing 0..9.
        fig = plt.figure()
        raw_images = feature_dict_val['image_raw']
        for position, raw_image in enumerate(raw_images, start=1):
            # Subplots fill a 2 x 5 grid, numbered from 1.
            ax = fig.add_subplot(2, 5, position)
            ax.imshow(tf.io.decode_image(raw_image).eval())
        plt.show()