Tensorflow提供的TFRecord文件数据是通过tf.train.Example Protocol Buffer的格式存储的,数据格式:
message Example {
  Features features = 1;
};
message Features {
  map<string, Feature> feature = 1;
};
message Feature {
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};
import tensorflow as tf
# Clear the default graph so that ops left over from an earlier run in the
# same Python process are discarded before this script builds its graph.
tf.reset_default_graph()
import os
from os.path import join
import matplotlib.pyplot as plt
def _parse_function(filename):
    """Read one PNG file from disk and decode it into an RGB image tensor.

    Used as the ``map`` function of the file-name ``tf.data.Dataset``.

    Args:
        filename: scalar string Tensor holding the path of a PNG image.

    Returns:
        A uint8 Tensor of shape (height, width, 3).
    """
    image_bytes = tf.read_file(filename)
    # decode_png returns uint8 by default (uint16 only if requested), never
    # float32, so the former float32 -> uint8 conversion branch could never
    # run and has been removed. channels=3 forces RGB output.
    return tf.image.decode_png(image_bytes, channels=3)
# Helpers that wrap Python values into tf.train.Feature protos for TFRecord.
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` (bytes_list)."""
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)
def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` (int64_list)."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)
# Directory holding the source PNG images.
image_path = "./CIFAR_images"
# os.listdir returns entries in an arbitrary, filesystem-dependent order;
# sort them so the record order in the TFRecord file is reproducible.
image_names = sorted(os.listdir(image_path))
# Full path of every image file.
file_names = [join(image_path, f) for f in image_names]
# Dataset of file names (the tf.data replacement for a filename queue).
filename_queue = tf.data.Dataset.from_tensor_slices(file_names)
# Decode each file name into an image tensor.
image_map = filename_queue.map(_parse_function)
# Tensor that yields the next decoded image each time it is evaluated.
image_value = image_map.make_one_shot_iterator().get_next()
# Number of images found in the directory (the tutorial's data set has 100).
image_num = len(file_names)
# List that collects processed image arrays.
images = []
def image_info(images):
    """Print the width, height and channel count of the first image.

    Args:
        images: non-empty sequence of image arrays, each shaped
            (height, width, channels).

    Raises:
        IndexError: if ``images`` is empty.
    """
    # Image arrays are row-major: shape is (height, width, channels), so the
    # first axis is the height, not the width.
    height, width, channels = images[0].shape
    print("width: {}, height: {}, channels: {}".format(width, height, channels))
def image_show(sess, image_num, image_value):
    """Display ``image_num`` images, resized to 28x28, in a square grid.

    Args:
        sess: active ``tf.Session`` used to evaluate the image tensor.
        image_num: number of images to draw.
        image_value: Tensor yielding the next image each time it is run
            (e.g. an iterator's ``get_next()`` output).
    """
    import math

    # Build the preprocessing ops ONCE. The original rebuilt (and chained)
    # them inside the loop, so the graph grew every iteration and each image
    # was resized/re-quantized once per previous iteration.
    display_image = image_value
    if display_image.dtype != tf.float32:
        # Convert to float32 in [0, 1]; resizing works on float images.
        display_image = tf.image.convert_image_dtype(display_image, dtype=tf.float32)
    # method=0 is ResizeMethod.BILINEAR.
    display_image = tf.image.resize_images(display_image, [28, 28], method=0)
    if display_image.dtype == tf.float32:
        # Back to uint8 in [0, 255] for display.
        display_image = tf.image.convert_image_dtype(display_image, dtype=tf.uint8)

    # Smallest square grid that fits every image (10x10 for 100 images);
    # a fixed 10x10 grid made plt.subplot fail for more than 100 images.
    grid = max(1, math.ceil(math.sqrt(image_num)))
    plt.figure(figsize=(10, 10))
    for i in range(image_num):
        # Each run pulls the next element from the iterator.
        image_values = sess.run(display_image)
        plt.subplot(grid, grid, i + 1).set_title("fig{}".format(i + 1))
        plt.imshow(image_values)
        plt.axis("off")
    # Apply the spacing once, after all axes exist.
    plt.subplots_adjust(hspace=0.1, wspace=0.6)
    plt.show()
def process_image(sess, image_num, image_value):
    """Fetch ``image_num`` images from ``image_value`` and resize them to 28x28.

    Args:
        sess: active ``tf.Session``.
        image_num: number of images to fetch.
        image_value: Tensor yielding the next image each time it is run.

    Returns:
        images: list of ``image_num`` uint8 arrays shaped (28, 28, channels).
    """
    # Build the conversion/resize ops once; rebuilding them inside the loop
    # kept growing the graph and re-resized every image once per prior
    # iteration.
    resized = image_value
    if resized.dtype != tf.float32:
        # float32 in [0, 1] for resizing.
        resized = tf.image.convert_image_dtype(resized, dtype=tf.float32)
    # method=0 is ResizeMethod.BILINEAR.
    resized = tf.image.resize_images(resized, [28, 28], method=0)
    if resized.dtype == tf.float32:
        # Store as uint8 in [0, 255]: float32 values take four bytes each and
        # would not decode correctly when read back as raw uint8.
        resized = tf.image.convert_image_dtype(resized, dtype=tf.uint8)

    # Build a fresh local list instead of appending to the module-level
    # ``images`` list, so repeated calls do not accumulate duplicates.
    return [sess.run(resized) for _ in range(image_num)]
def save_tfrecord(images):
    """Serialize image arrays to ``./outputs/cifar10.tfrecords``.

    Each image becomes one ``tf.train.Example`` with the features
    ``image_raw`` (raw pixel bytes), ``image_num`` (total number of images),
    ``height`` and ``width``.

    Args:
        images: list of uint8 image arrays shaped (height, width, channels).
    """
    os.makedirs("outputs/", exist_ok=True)
    file_name = "./outputs/cifar10.tfrecords"
    # Use the list's own length rather than the module-level ``image_num`` so
    # exactly the given images are written (no IndexError, no truncation).
    total = len(images)
    # The context manager closes (and flushes) the file even if serializing
    # an example raises.
    with tf.python_io.TFRecordWriter(file_name) as writer:
        for image in images:
            height, width = image.shape[0], image.shape[1]
            example = tf.train.Example(features=tf.train.Features(feature={
                # tobytes() is the non-deprecated spelling of tostring().
                'image_raw': _bytes_feature(image.tobytes()),
                'image_num': _int64_feature(total),
                # Record the real size instead of hard-coding 28x28.
                'height': _int64_feature(height),
                'width': _int64_feature(width),
            }))
            writer.write(example.SerializeToString())
    print("Saved.")
# Fetch, resize and persist every image within a single session.
with tf.Session() as sess:
    images = process_image(sess, image_num, image_value)
    print("image number: {}".format(len(images)))
    # Reports the first image's size, e.g. width: 28, height: 28, channels: 3
    image_info(images)
    save_tfrecord(images)
使用Dataset方式读取TFRecord文件
import tensorflow as tf
# Clear the default graph so that ops left over from the writing script (or
# an earlier run in the same process) are discarded before building this one.
tf.reset_default_graph()
import matplotlib.pyplot as plt
def parse(record):
    """Decode one serialized ``tf.train.Example`` as written by ``save_tfrecord``.

    Args:
        record: scalar string Tensor holding one serialized Example.

    Returns:
        Tuple ``(image_raw, image_num, height, width)``: the raw image bytes
        (string Tensor) followed by three int64 scalar Tensors.
    """
    feature_spec = {
        "image_raw": tf.FixedLenFeature([], tf.string),
        "image_num": tf.FixedLenFeature([], tf.int64),
        "height": tf.FixedLenFeature([], tf.int64),
        "width": tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(record, features=feature_spec)
    return tuple(parsed[key] for key in ("image_raw", "image_num", "height", "width"))
# TFRecord file(s) to read.
input_files = ["./outputs/cifar10.tfrecords"]
# Read the serialized records and decode each one into its feature tensors.
dataset = tf.data.TFRecordDataset(input_files).map(parse)
# Initializable iterator: run iterator.initializer before the first fetch.
# (Per the original note, make_one_shot_iterator() would also do when the
# pipeline has no variables.)
iterator = dataset.make_initializable_iterator()
# Every sess.run that fetches any of these tensors advances the iterator by
# one record; ``images`` is a scalar string Tensor of raw bytes.
images, num, height, width = iterator.get_next()
def iterator_data_subplot(sess, num, images, height, width):
    """Decode and display up to ``num`` images read from the TFRecord dataset.

    Args:
        sess: active ``tf.Session`` whose iterator has been initialized.
        num: maximum number of images to display.
        images: scalar string Tensor holding one record's raw image bytes.
        height: int64 Tensor with that record's image height.
        width: int64 Tensor with that record's image width.
    """
    import math

    # Build the decode/reshape ops once. The original rebuilt them on every
    # iteration (re-casting its own shadowed height/width), so the graph
    # grew each time.
    # Raw bytes -> flat uint8 vector of length height * width * 3.
    flat_image = tf.decode_raw(images, tf.uint8)
    # Restore the (height, width, channels) layout; images were stored as RGB.
    image_op = tf.reshape(
        flat_image,
        [tf.cast(height, tf.int32), tf.cast(width, tf.int32), 3])

    # Smallest square grid holding ``num`` images (10x10 for 100); a fixed
    # 10x10 grid made plt.subplot fail for more than 100 images.
    grid = max(1, math.ceil(math.sqrt(num)))
    plt.figure(figsize=(10, 10))
    for i in range(num):
        try:
            # One run pulls one record from the iterator.
            image = sess.run(image_op)
        except tf.errors.OutOfRangeError:
            # The file holds fewer than ``num`` records: stop instead of crashing.
            break
        plt.subplot(grid, grid, i + 1).set_title("fig{}".format(i + 1))
        plt.axis("off")
        plt.imshow(image)
    # Apply the spacing once, after all axes exist.
    plt.subplots_adjust(hspace=0.8)
    plt.show()
with tf.Session() as sess:
    # An initializable iterator must be initialized before its first fetch.
    sess.run(iterator.initializer)
    # Display the first 100 records.
    iterator_data_subplot(sess, 100, images, height, width)
(1) 图像矩阵数据有两种常见的取值形式:float类型,取值范围为[0,1];int(uint8)类型,取值范围为[0,255]。Tensorflow对图像进行处理(包括剪裁、变换尺寸等操作)时,需要先将图像转换为float格式;
(2) 图像数据保存为TFRecord格式时,需要将图像数据调整为int(uint8)类型,否则复原出的图像尺寸会扩大一倍。原因是float32的每个数值占4个字节,而读取时按uint8(1个字节)解码,得到的数值个数是原来的4倍。例如 $32 \times 32$ 的float类型数据,复原后的尺寸为 $64 \times 64$,图像复原出现错误。CIFAR数据的处理如下图所示;
(3) 读取TFRecord数据有两种方法,一种是直接使用TFRecordReader读取(即将停止维护),另一种是使用Dataset读取,Tensorflow1.3版本后,官方推荐使用Dataset方法。
(4) 主程序中每执行一次sess.run(获取get_next的输出),迭代器就前进一次,相当于运行了一次get_next。例如TFRecord中保存了100张图片:若先通过sess.run获取图片数量,再遍历数据集,遍历会从第二个数据开始,因为获取数量的那次sess.run已经消耗了第一个数据。所以需注意数据集读取的完整性。
[参考文献]
[1]https://tensorflow.google.cn/versions/r1.12/api_docs/python/tf/data/Dataset
[2]https://blog.csdn.net/fu6543210/article/details/80269215