TensorFlow使用TFRecord格式来统一存储数据。该格式可以将图像数据、标签信息、图像路径以及宽高等不同类型的信息放在一起统一存储,从而方便、有效地管理不同的属性。
将训练数据集转成TFRecord
这里采用的数据集来自目前正在做的项目,共包含两个图像文件夹(目标图像与搜索图像,各包含100幅图像)及对应的label.txt。label.txt中的每一行依次记录两个文件夹中各一幅图像的路径,以及目标物的位置信息,即左上顶点和右下顶点的坐标(每行以逗号分隔,格式为:目标图像路径,搜索图像路径,后接4个坐标值)。
根据读取图像数据方式的不同,共有两种方式将自己的数据集转换成TFRecord格式;相应地,也有两种方式对TFRecord文件进行解析。代码中分别以method_1和method_2标注这两种方式,实际使用时二者选其一即可。具体代码如下:
# Convert own_data to TFRecord of TF-Example protos.
import tensorflow as tf
from PIL import Image
import numpy as np
import os
# Build an integer-valued feature.
def int64_feature(values):
    """Wrap one or more integers in a ``tf.train.Feature`` (``int64_list``).

    Args:
        values: a single integer, or an iterable of integers. A scalar is
            wrapped in a one-element list, because ``Int64List(value=...)``
            requires an iterable and raises ``TypeError`` on a bare int.
            Iterables are passed through unchanged.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds the value(s).
    """
    try:
        iter(values)
    except TypeError:
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
# Build a float-valued feature.
def float_feature(values):
    """Wrap one or more floats in a ``tf.train.Feature`` (``float_list``).

    Args:
        values: a single number, or an iterable of numbers. A scalar is
            wrapped in a one-element list, because ``FloatList(value=...)``
            requires an iterable. Iterables (e.g. the 4 box coordinates) are
            passed through unchanged.

    Returns:
        A ``tf.train.Feature`` whose ``float_list`` holds the value(s).
    """
    try:
        iter(values)
    except TypeError:
        values = [values]
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))
# Build a single-value bytes feature.
def bytes_feature(values):
    """Wrap ONE bytes/string value in a ``tf.train.Feature`` (``bytes_list``).

    Args:
        values: a single value (despite the plural name). ``bytes`` are stored
            as-is. Text (``str`` on Python 3) is UTF-8 encoded first, because
            ``BytesList`` rejects text on Python 3. On Python 2, where ``str``
            is ``bytes``, byte strings are never re-encoded.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds exactly one element.
    """
    if isinstance(values, str) and not isinstance(values, bytes):
        values = values.encode("utf-8")
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
# Path of label.txt. Each line: target_path,search_path,<4 box coordinates>
# (top-left and bottom-right corners of the object), comma-separated.
dataset_dir = "/Users/**/**/label.txt"
# Root directory that the image paths inside label.txt are relative to.
# (The original literal started with a stray '"' that became part of the path.)
root_dir = "/Users/**/**/"
# Path of the TFRecord file to write.
output_filename = "/Users/**/**/output.tfrecord"
# How image bytes are obtained. The snippet documents two alternatives:
#   "gfile" (method_1): store the encoded file bytes as-is (e.g. JPEG); parse
#                       later with tf.image.decode_jpeg.
#   "pil"   (method_2): store raw pixel bytes from PIL's Image.tobytes(); parse
#                       later with tf.decode_raw(..., tf.uint8) plus a reshape.
# The original code executed both and kept method_2's bytes, hence the default.
read_method = "pil"
if read_method not in ("gfile", "pil"):
    raise ValueError("read_method must be 'gfile' or 'pil', got %r" % (read_method,))

sess = None
if read_method == "gfile":
    # Build the JPEG-decode op once and evaluate it in an explicit session.
    # The original called .eval() with no default session registered (which
    # fails) and created new decode ops for every image, growing the graph.
    encoded_ph = tf.placeholder(tf.string)
    decoded_shape = tf.shape(tf.image.decode_jpeg(encoded_ph))
    sess = tf.Session()


def _parse_label_line(line):
    """Split one label.txt line into its fields.

    Returns:
        ``(target_rel_path, search_rel_path, labels)`` with ``labels`` a list
        of 4 floats, or ``None`` if the line is not exactly two paths followed
        by 4 numeric values (this also covers blank lines).
    """
    fields = [field.strip() for field in line.strip().split(",")]
    if len(fields) != 6:
        return None
    try:
        labels = [float(x) for x in fields[2:]]
    except ValueError:
        return None
    return fields[0], fields[1], labels


def _read_image(path):
    """Read one image according to ``read_method``.

    Returns:
        ``(data, height, width, channels)``; ``data`` is the bytes to store.
    """
    if read_method == "gfile":
        # Binary mode: encoded image data is bytes, not text.
        with tf.gfile.FastGFile(path, "rb") as f:
            data = f.read()
        height, width, channels = sess.run(decoded_shape, feed_dict={encoded_ph: data})
        return data, int(height), int(width), int(channels)
    with Image.open(path) as img:
        data = img.tobytes()
        # PIL reports size as (width, height); the original unpacked it as
        # (height, width), which swapped the two dimensions.
        width, height = img.size
        channels = len(img.getbands())
    return data, height, width, channels


with open(dataset_dir) as label_file:
    file_lines = label_file.readlines()

# Number of records actually written.
valid_record_count = 0
try:
    with tf.python_io.TFRecordWriter(output_filename) as writer:
        for line in file_lines:
            # Validate the label before doing any (expensive) image reads.
            parsed = _parse_label_line(line)
            if parsed is None:
                # Blank lines (e.g. a trailing newline) are skipped silently.
                if line.strip():
                    print("invalid label: " + line.strip('\n'))
                continue
            target_rel, search_rel, image_labels = parsed
            # NOTE: with read_method == "pil" the stored bytes are raw pixels,
            # even though the format feature still records the file extension.
            target_format = target_rel.split('.')[-1].lower()
            search_format = search_rel.split('.')[-1].lower()
            image_target_path = os.path.join(root_dir, target_rel)
            image_search_path = os.path.join(root_dir, search_rel)
            image_target_data, T_height, T_width, _ = _read_image(image_target_path)
            # Only one 'image/channels' value is stored. As in the original,
            # it comes from the search image. NOTE(review): this assumes the
            # target and search images share a channel count.
            image_search_data, S_height, S_width, channels = _read_image(image_search_path)
            # Scalars are wrapped in lists and text is UTF-8 encoded here, so
            # these calls work with the feature helpers as originally written
            # (Int64List needs an iterable; BytesList needs bytes on Python 3).
            example = tf.train.Example(features=tf.train.Features(feature={
                'image_target/encoded': bytes_feature(image_target_data),
                'image_search/encoded': bytes_feature(image_search_data),
                'image_target/format': bytes_feature(target_format.encode('utf-8')),
                'image_search/format': bytes_feature(search_format.encode('utf-8')),
                'image/class/label': float_feature(image_labels),
                'image_target/height': int64_feature([T_height]),
                'image_target/width': int64_feature([T_width]),
                'image_search/height': int64_feature([S_height]),
                'image_search/width': int64_feature([S_width]),
                'image/channels': int64_feature([channels]),
                'image_target/path': bytes_feature(image_target_path.encode('utf-8')),
                'image_search/path': bytes_feature(image_search_path.encode('utf-8')),
            }))
            writer.write(example.SerializeToString())
            valid_record_count += 1
finally:
    if sess is not None:
        sess.close()
print("\nvalid image count: " + str(valid_record_count))
读取TFRecord文件,具体代码如下(其中features为解析TFRecord样例后得到的特征字典,decode_path为解码后图像的保存目录,这两部分的定义此处从略):
# Decoding must mirror how the writer stored the image bytes:
#   "jpeg" (method_1): the record holds encoded JPEG bytes (tf.gfile path);
#                      decode with tf.image.decode_jpeg.
#   "raw"  (method_2): the record holds raw uint8 pixels (PIL path); decode
#                      with tf.decode_raw, then reshape below.
# The original built both ops in sequence and only the second ("raw") was used.
# NOTE(review): `features` (the parsed feature dict) and `decode_path` (output
# directory) are not defined in this snippet. They are assumed to be created
# earlier, presumably via a queue-based reader plus a parse step (the
# start_queue_runners call below suggests this); confirm against the full script.
decode_method = "raw"
if decode_method == "jpeg":
    image_target = tf.image.decode_jpeg(features['image_target/encoded'])
elif decode_method == "raw":
    image_target = tf.decode_raw(features['image_target/encoded'], tf.uint8)
else:
    raise ValueError("decode_method must be 'jpeg' or 'raw', got %r" % (decode_method,))
label = features['image/class/label']
T_height = tf.cast(features['image_target/height'], tf.int32)
T_width = tf.cast(features['image_target/width'], tf.int32)
channels = tf.cast(features['image/channels'], tf.int32)
image_target_path = features['image_target/path']
# Number of examples to dump (the dataset described above has 100 pairs).
num_samples = 100
sess = tf.Session()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
    # Each sess.run pulls the next example from the input pipeline.
    for i in range(num_samples):
        image_t, label_info, t_height, t_width, t_channels, path = sess.run(
            [image_target, label, T_height, T_width, channels, image_target_path])
        # String tensors come back as bytes on Python 3.
        if isinstance(path, bytes):
            path = path.decode("utf-8")
        image_name = os.path.basename(path).split(".")[0]
        # Reshape in NumPy; sess.run(tf.reshape(...)) here would add new ops
        # to the graph on every iteration.
        sample = np.reshape(image_t, (t_height, t_width, t_channels))
        # NOTE(review): mode 'RGB' assumes 3-channel uint8 images.
        image = Image.fromarray(sample, 'RGB')
        # Name the output "<image name>_<first label value>.jpg".
        image.save(os.path.join(decode_path, image_name + '_' + str(label_info[0]) + '.jpg'))
finally:
    coord.request_stop()
    coord.join(threads)
    sess.close()