上篇文章中我梳理一下在TensorFlow中几种不同类型数据读取的流程,但是没有具体说到TFRecords这种文件类型,这篇文章就来具体梳理这一文件格式。
TFRecords是TensorFlow中的设计的一种内置的文件格式,它是一种二进制文件,优点有如下几种:
在将其他数据存储为TFRecords文件的时候,需要经过两个步骤:
tf.python_io.TFRecordWriter(path)
注:此处的字符串为一个序列化的Example,通过Example.SerializeToString()
来实现,它的作用是将Example中的map压缩为二进制,节约大量空间。
message Example {
Features features = 1;
};
message Features {
map<string, Feature> feature = 1;
};
message Feature {
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
上面这段代码即为Example协议块的规则,详解如下:
(1)tf.train.Example(features = None)
(2)tf.train.Features(feature = None)
(3)tf.train.Feature(**options)
options可以选择如下三种格式数据:
bytes_list = tf.train.BytesList(value = [Bytes])
int64_list = tf.train.Int64List(value = [Value])
float_list = tf.trian.FloatList(value = [Value])
(4)将图片数据转化为TFRecords的例子:
对每一个样本,都做如下的处理:
example = tf.train.Example(feature = tf.train.Features(feature = {
"image":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image(bytes)]))
"label":tf.train.Feature(int64_list=tf.train.Int64List(value=[label(int)]))
}))
和文件阅读器的流程基本相同,只是中间多了一步解析过程
(1)tf.parse_single_example(serialized,features=None,name= None
(2)tf.FixedLenFeature(shape,dtype)
(3)上面(1)中features中的value还可以为tf.VarLenFeature()
,但是这种方式用的比较少,它返回的是SparseTensor数据,这是一种只存储非零部分的数据格式,了解即可。
import tensorflow as tf
import numpy as np
import pandas as pd
train_frame = pd.read_csv("train.csv")
print(train_frame.head())
train_labels_frame = train_frame.pop(item="label")
train_values = train_frame.values
train_labels = train_labels_frame.values
print("values shape: ", train_values.shape)
print("labels shape:", train_labels.shape)
writer = tf.python_io.TFRecordWriter("csv_train.tfrecords")
for i in range(train_values.shape[0]):
image_raw = train_values[i].tostring()
example = tf.train.Example(
features=tf.train.Features(
feature={
"image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[train_labels[i]]))
}
)
)
writer.write(record=example.SerializeToString())
writer.close()
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import tensorflow as tf
import pandas as pd
def get_label_from_filename(filename):
return 1
filenames = tf.train.match_filenames_once('.\data\*.png')
writer = tf.python_io.TFRecordWriter('png_train.tfrecords')
for filename in filenames:
img=mpimg.imread(filename)
print("{} shape is {}".format(filename, img.shape))
img_raw = img.tostring()
label = get_label_from_filename(filename)
example = tf.train.Example(
features=tf.train.Features(
feature={
"image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}
)
)
writer.write(record=example.SerializeToString())
writer.close()
"""
读取二进制文件转换成张量,写进TFRecords,同时读取TFRcords
"""
#命令行参数
FLAGS = tf.app.flags.FLAGS #获取值
tf.app.flags.DEFINE_string("tfrecord_dir","./tmp/cifar10.tfrecords","写入图片数据文件的文件名")
#读取二进制转换文件
class CifarRead(object):
"""
读取二进制文件转换成张量,写进TFRecords,同时读取TFRcords
"""
def __init__(self,file_list):
"""
初始化图片参数
:param file_list:图片的路径名称列表
"""
#文件列表
self.file_list = file_list
#图片大小,二进制文件字节数
self.height = 32
self.width = 32
self.channel = 3
self.label_bytes = 1
self.image_bytes = self.height * self.width * self.channel
self.bytes = self.label_bytes + self.image_bytes
def read_and_decode(self):
"""
解析二进制文件到张量
:return: 批处理的image,label张量
"""
#1.构造文件队列
file_queue = tf.train.string_input_producer(self.file_list)
#2.阅读器读取内容
reader = tf.FixedLengthRecordReader(self.bytes)
key ,value = reader.read(file_queue) #key为文件名,value为元组
print(value)
#3.进行解码,处理格式
label_image = tf.decode_raw(value,tf.uint8)
print(label_image)
#处理格式,image,label
#进行切片处理,标签值
#tf.cast()函数是转换数据格式,此处是将label二进制数据转换成int32格式
label = tf.cast(tf.slice(label_image,[0],[self.label_bytes]),tf.int32)
#处理图片数据
image = tf.slice(label_image,[self.label_bytes],[self.image_bytes])
print(image)
#处理图片的形状,提供给批处理
#因为image的形状已经固定,此处形状用动态形状来改变
image_tensor = tf.reshape(image,[self.height,self.width,self.channel])
print(image_tensor)
#批处理图片数据
image_batch,label_batch = tf.train.batch([image_tensor,label],batch_size=10,num_threads=1,capacity=10)
return image_batch,label_batch
def write_to_tfrecords(self,image_batch,label_batch):
"""
将文件写入到TFRecords文件中
:param image_batch:
:param label_batch:
:return:
"""
#建立TFRecords文件存储器
writer = tf.python_io.TFRecordWriter(FLAGS.tfrecord_dir) #传进去命令行参数
#循环取出每个样本的值,构造example协议块
for i in range(10):
#取出图片的值, #写进去的是值,而不是tensor类型,
# 写入example需要bytes文件格式,将tensor转化为bytes用tostring()来转化
image = image_batch[i].eval().tostring()
#取出标签值,写入example中需要使用int形式,所以需要强制转换int
label = int(label_batch[i].eval()[0])
#构造每个样本的example协议块
example = tf.train.Example(features = tf.train.Features(feature = {
"image":tf.train.Feature(bytes_list = tf.train.BytesList(value = [image])),
"label":tf.train.Feature(int64_list = tf.train.Int64List(value = [label]))
}))
#写进去序列化后的值
writer.write(example.SerializeToString()) #此处其实是将其压缩成一个二进制数据
writer.close()
return None
def read_from_tfrecords(self):
"""
从TFRecords文件当中读取图片数据(解析example)
:param self:
:return: image_batch,label_batch
"""
#1.构造文件队列
file_queue = tf.train.string_input_producer([FLAGS.tfrecord_dir]) #参数为文件名列表
#2.构造阅读器
reader = tf.TFRecordReader()
key,value = reader.read(file_queue)
#3.解析协议块,返回的值是字典
feature = tf.parse_single_example(value,features={
"image":tf.FixedLenFeature([],tf.string),
"label":tf.FixedLenFeature([],tf.int64)
})
#feature["image"],feature["label"]
#处理标签数据 ,cast()只能在int和float之间进行转换
label = tf.cast(feature["label"],tf.int32) #将数据类型int64 转换为int32
#处理图片数据,由于是一个string,要进行解码, #将字节转换为数字向量表示,字节为一字符串类型的张量
#如果之前用了tostring(),那么必须要用decode_raw()转换为最初的int类型
# decode_raw()可以将数据从string,bytes转换为int,float类型的
image = tf.decode_raw(feature["image"],tf.uint8)
#转换图片的形状,此处需要用动态形状进行转换
image_tensor = tf.reshape(image,[self.height,self.width,self.channel])
#4.批处理
image_batch,label_batch = tf.train.batch([image_tensor,label],batch_size=10,num_threads=1,capacity=10)
return image_batch,label_batch
if __name__ == '__main__':
# 找到文件路径,名字,构造路径+文件名的列表,"A.csv"...
# os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表
filename = os.listdir('./data/cifar10/cifar-10-batches-bin/')
#加上路径
file_list = [os.path.join('./data/cifar10/cifar-10-batches-bin/', file) for file in filename if file[-3:] == "bin"]
#初始化参数
cr = CifarRead(file_list)
#读取二进制文件
# image_batch,label_batch = cr.read_and_decode()
#从已经存储的TFRecords文件中解析出原始数据
image_batch, label_batch = cr.read_from_tfrecords()
with tf.Session() as sess:
#线程协调器
coord = tf.train.Coordinator()
#开启线程
threads = tf.train.start_queue_runners(sess,coord=coord)
print(sess.run([image_batch,label_batch]))
# print("存进TFRecords文件")
# cr.write_to_tfrecords(image_batch,label_batch)
# print("存进文件完毕")
#回收线程
coord.request_stop()
coord.join(threads)
注:
上段代码分为两个部分: