(一)Tensorflow图像数据转化TFRecord数据格式

1 TFRecord数据格式

Tensorflow提供的TFRecord文件数据是通过tf.train.Example Protocol Buffer的格式存储的,数据格式:

message Example {
	Features features = 1;
};
message Feature {
	map feature =1;
};
message Feature {
	oneof king {
		BytesList bytes_list = 1;
		FloatList float_list = 2;
		Int64List int64_list = 3;
	}
};

2 数据转化TFRecord

import tensorflow as tf
tf.reset_default_graph()
import os
from os.path import join
import matplotlib.pyplot as plt

def _parse_function(filename):
	'''TFRecord 解析函数
	参数:
	filename: 图片名称
	返回:
	图像Tensor.
	'''
    image_bytes = tf.read_file(filename)
    image_value = tf.image.decode_png(image_bytes, channels=3)
    if image_value.dtype == tf.float32: 
    	'''Image 图像数据转为unsigined integer,避免保存的图像尺寸改变'''
    	image_value = tf.image.convert_image_dtype(image_value, dtype=tf.uint8) 
    return image_value
'''TFRecord数据结构类型转换.'''
def _bytes_feature(value):
	'''字节类型'''
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _int64_feature(value):
	'''int64类型'''
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

'''图像路径.'''
image_path = "./CIFAR_images"
image_names = os.listdir(image_path)
'''文件名称'''
file_names = [join(image_path, f) for f in image_names]
'''Dataset建立文件名队列'''
filename_queue = tf.data.Dataset.from_tensor_slices(file_names)
'''通过文件名队列解析图像.'''
image_map = filename_queue.map(_parse_function)
'''遍历图像队列.'''
image_value = image_map.make_one_shot_iterator().get_next()
'''图像数量:100.'''
image_num = len(file_names)
'''存储图像值的list.'''
images = []
def image_info(images):
	'''获取图像信息
	参数:
	images: 图像矩阵列表.
	返回:
	图像:长,宽和通道数.
	'''
    width, height, channels = images[0].shape
#     print("image value: {}".format(images[0]))
    print("width: {}, height: {}, channels: {}".format(width, height, channels))
    
def image_show(sess, image_num, image_value):
	'''图像显示
	参数:
	sess: 会话.
	image_num: 图像数量.
	image_value: 图像队列.
	'''
    plt.figure(figsize=(10, 10))
    for i in range(image_num):
        '''image type: '''
#         print("image type: {}".format(image_value.dtype))
        if image_value.dtype != tf.float32:
            '''
            Convert image to float32 type, and value range in [0,1].
            image type: 
            '''
            image_value = tf.image.convert_image_dtype(image_value, dtype=tf.float32)
        image_value = tf.image.resize_images(image_value, [28, 28], method=0)
        if image_value.dtype == tf.float32:
            '''Convert image to unsigned integer type, and value range [0, 255]'''
            image_value = tf.image.convert_image_dtype(image_value, dtype=tf.uint8)
        image_values = sess.run(image_value)
        '''Show images.'''
        plt.subplot(10,10,i+1).set_title("fig{}".format(i+1))
        plt.subplots_adjust(hspace=0.1, wspace=0.6)
        plt.imshow(image_values)
        plt.axis("off")
    plt.show()
    
def process_image(sess, image_num, image_value):
	'''图像处理
	参数:
	sess: 会话.
	image_num: 图像数量.
	image_value: 图像队列.
	返回:
	images: 图像矩阵列表.
	'''
    for i in range(image_num):
        if image_value.dtype != tf.float32:
            '''Convert image to float32 type, and value range in [0,1], Tensor'''
            image_value = tf.image.convert_image_dtype(image_value, dtype=tf.float32)
        image_value = tf.image.resize_images(image_value, [28, 28], method=0)
        if image_value.dtype == tf.float32:
            '''Convert image to unsigned integer type, and value range [0, 255], Tensor'''
            image_value = tf.image.convert_image_dtype(image_value, dtype=tf.uint8)
            '''Image matrix.
            [[[ 63  48  18]
              [ 64  53  21]
              [ 62  54  22]
              ...
              [104  76  45]]]
            '''
        image_values = sess.run(image_value)
        images.append(image_values)
    return images
def save_tfrecord(images):
	'''保存图像数据为TFRecord.
	参数:
	images: 图像矩阵列表.
	'''
    if not os.path.exists("outputs/"):
        os.makedirs("outputs/")
    file_name = "./outputs/cifar10.tfrecords"
    writer = tf.python_io.TFRecordWriter(file_name)
    for i in range(image_num):
        image_raw = images[i].tostring()
        example = tf.train.Example(features=tf.train.Features(feature={
            'image_raw':_bytes_feature(image_raw),
            'image_num':_int64_feature(image_num),
            'height':_int64_feature(28),
            'width':_int64_feature(28)
        }))
        writer.write(example.SerializeToString())
    writer.close()
    print("Saved.")
    
'''执行保存.'''    
with tf.Session() as sess:
    images = process_image(sess, image_num, image_value)
    print("image number: {}".format(len(images)))
    '''width: 28, height: 28, channels: 3'''
    image_info(images)
    save_tfrecord(images)

3 读取TFRecord

Dataset方式读取

import tensorflow as tf
tf.reset_default_graph()
import matplotlib.pyplot as plt

def parse(record):
	'''TFRecord文件解析函数.
	参数:
	record: 
	返回:
	features["image_raw"]:图像数据.
	features["image_num"]:图像数量.
	features["height"]:图像高度.
	features["width"]:图像宽度.
	'''
    features = tf.parse_single_example(
        record,
        features={"image_raw":tf.FixedLenFeature([], tf.string),
                  "image_num":tf.FixedLenFeature([], tf.int64),
                  "height":tf.FixedLenFeature([], tf.int64),
                  "width":tf.FixedLenFeature([], tf.int64),
                 }
    )
    image_raw = features["image_raw"]
    image_num = features["image_num"]
    height = features["height"]
    width = features["width"]
    return image_raw, image_num, height, width

'''TFRecord文件路径.'''
input_files = ["./outputs/cifar10.tfrecords"]
'''读取TFRecord文件'''
dataset = tf.data.TFRecordDataset(input_files)
'''数据映射解析.'''
dataset = dataset.map(parse)
'''遍历初始化设置:当有变量时使用此方法.当没有变量是可使用:dataset.make_one_shot_iterator()'''
iterator = dataset.make_initializable_iterator()
'''获取TFRecord存储的数据,使用get_next遍历读取
数据类型:Tensor("IteratorGetNext:0", shape=(), dtype=string)
'''
images, num, height, width = iterator.get_next()
def iterator_data_subplot(sess, num, images, height, width):
	'''可视化读取的图像
	参数:
	sess: 会话.
	num: 显示图像数量.
	images: 图像矩阵列表.
	height:图像高度.
	width: 图像宽度.
	'''
    plt.figure(figsize=(10, 10))
    for i in range(num):
        '''数据恢复:转换为uint8类型.
        Tensor("DecodeRaw:0", shape=(?,), dtype=uint8)
        '''
        image = tf.decode_raw(images, tf.uint8)
        '''图像值:
        Tensor("DecodeRaw:0", shape=(?,), dtype=uint8)
        '''
        height = tf.cast(height, tf.int32)
        width = tf.cast(width, tf.int32)
        '''图像信息.
        image shape: (49152,), height: 28, width: 28
        返回的图形数据为一个列向量,行数为:H*W*C
        需要重新整理为标准图像(h,w,c)
        '''
        image = tf.reshape(image, [height, width, 3])
        '''Image matrix.
        [[[165 170 176]
          [161 167 173]
          ...
          [139 148 155]]]
        '''
        image = sess.run(image)
        '''Show image.'''
        plt.subplot(10,10,i+1).set_title("fig{}".format(i+1))
        plt.subplots_adjust(hspace=0.8)
        plt.axis("off")
        plt.imshow(image)
    plt.show()  
with tf.Session() as sess:
    '''初始化 iterator.'''
    sess.run(iterator.initializer)    
    iterator_data_subplot(sess, 100, images, height, width)

(一)Tensorflow图像数据转化TFRecord数据格式_第1张图片

图2.1 读取结果

3 总结

(1) 图像矩阵数据有两种取值:[0,1]float类型,[0,255]int类型,Tensorflow对图像处理(包括剪裁,变换尺寸等操作)需要将图像转换为float格式;
(2) 图像数据保存为TFRecord格式时,需要将图像数据调整为int类型,否则保存的图像尺寸会扩大一倍,如 32 × 32 32 \times 32 32×32的float类型数据,转换后的尺寸为 64 × 64 64 \times 64 64×64,图像复原后出现错误,CIFAR数据的处理如下图所示;
(3) 读取TFRecord数据有两种方法,一种是直接使用TFRecordReader读取(即将停止维护),另一种是使用Dataset读取,Tensorflow1.3版本后,官方推荐使用Dataset方法。
(4) 主程序中执行一次sess.run即运行了一次get.next。(如TFRecord中保存了图片的数量为100, 通过sess.run获取数量,在遍历数据集,会从第二个数据开始遍历,因为运行一次sess.run已经开始了一次遍历,所以注意数据集读取的完整性。)


[参考文献]
[1]https://tensorflow.google.cn/versions/r1.12/api_docs/python/tf/data/Dataset
[2]https://blog.csdn.net/fu6543210/article/details/80269215


你可能感兴趣的:(Tensorflow)