Tensorflow 2.0 TFrecord的输出与读入

前言

最近新建了一个conda环境,搞上了tensorflow 2.0 (Beat),,,TF2.0改变确实很多,比如删除了Session……这对于我等习惯了先建图——再Session执行的人来说,我现在方的雅痞……2.0如何以图形式运行我还没有一点头绪(刚发现了tf.compat里面有历史版本233)……所以还在瑟瑟发抖的使用新版TF强烈推荐的keras。

今天正准备用TF2.0小跑一个图像任务,首先就是数据的读入,然而这边数据集11G,所以打算整合进TFrecord,方便之后;

介绍

TFrecord是Tensorflow提供并推荐使用的一种统一一种二进制文件格式,用于存储数据,理论上它可以保存任何格式的信息。

type value
uint64 length
uint32 masked_crc32_of_length
byte data[length]
uint32 masked_crc32_of_data

如上:整个文件由文件长度信息、长度校验码、数据、数据校验码组成。

TFRecord 的核心内容在于内部有一系列的 Example ,Example 是 protocolbuf 协议下的消息体。
比如我这边使用的Example是这样的:

exam = tf.train.Example (
            features=tf.train.Features(
                feature={
                    'name' : tf.train.Feature(bytes_list=tf.train.BytesList (value=[splits[-1].encode('utf-8')])),
                    'shape': tf.train.Feature(int64_list=tf.train.Int64List (value=[img.shape[0], img.shape[1], img.shape[2]])),
                    'data' : tf.train.Feature(bytes_list=tf.train.BytesList (value=[bytes(img.numpy())]))
                }
            )
        )

可以看出,一个 Example 消息体包含了一个Features,而Features由诸多feature组成,其中每个feature 是一个 map,也就是 key-value 的键值对。其中,key 取值是 String 类型;而 value 是 Feature 类型的消息体,它的取值有 3 种:

  1. BytesList
  2. FloatList
  3. Int64List
    Tensorflow 2.0 TFrecord的输出与读入_第1张图片
    需要注意的是,他们都是列表的形式。

如何创建TFrecord文件

从上面我们知道,TFRecord 内由一系列Example组成,每个Example可以代表一组数据。

Tensorflow 2.0 Beat 中,输出TFrecord的API为tf.io.TFRecordWriter (filename, options=None), 其中第二个参数是用来控制文件的输出配置,一般不用管。第一个参数就是你要保存的文件名,调用该函数后,会返回一个Writer实例。

有了Writer,我们就可以不停的调用Writer.write (example)来把我们的Examples输出到文件中,需要注意的是,该函数接受的是一个string,所以我们应该先把example序列化为string类型,即Writer.write(example.SerializeToString())

当把所有的example输出到文件后,需要调用Writer.close()关闭文件。

例子:

writer = tf.io.TFRecordWriter (file_name)
for item in file_list:
    # item = .\\data\\xx(label)\\xxx.jpg
    splits = item.split ('\\')
    label = splits[2]
    img = tf.io.read_file (item)
    img = tf.image.decode_jpeg (img)
    exam = tf.train.Example (
    	features=tf.train.Features(
        	feature={
            	'name' : tf.train.Feature(bytes_list=tf.train.BytesList (value=[splits[-1].encode('utf-8')])),
            	'label': tf.train.Feature(int64_list=tf.train.Int64List (value=[int(label)])),
            	'shape': tf.train.Feature(int64_list=tf.train.Int64List (value=[img.shape[0], img.shape[1], img.shape[2]])),
            	'data' : tf.train.Feature(bytes_list=tf.train.BytesList (value=[bytes(img.numpy())]))
        	}
    	)
	)
    writer.write (exam.SerializeToString())
writer.close()

这里因为Tensorflow 2.0 默认使用的是Eager模式,所以img是一个 Eager Tensor,需要转为numpy。

如何读取TFrecord

老版本中,我们可以使用tf.TFrecordReader(),不过这个在2.0里我没找到,所以我们使用tf.data.TFRecordDataset(filename),调用后我们会得到一个Dataset(tf.data.Dataset),字面理解,这里面就存放着我们之前写入的所有Example。

还记得写入时,我们把每个example都进行了序列化么,所以我们要得到之前的example,还需要解析以下之前写入的序列化string。tf.io.parse_single_example(example_proto, feature_description)函数可以解析单条example.

解释一下这个函数:
第一个参数就是要解析的string,重点在于第二个参数,他要我们指定解析出来的example的格式。为了能正确解析,这个要和我们写入时的example对应起来:
比如我们写入时example为:

exam = tf.train.Example (
    features=tf.train.Features(
        feature={
            'name' : tf.train.Feature(bytes_list=tf.train.BytesList (value=[splits[-1].encode('utf-8')])),
            'label': tf.train.Feature(int64_list=tf.train.Int64List (value=[int(label)])),
            'shape': tf.train.Feature(int64_list=tf.train.Int64List (value=[img.shape[0], img.shape[1], img.shape[2]])),
            'data' : tf.train.Feature(bytes_list=tf.train.BytesList (value=[bytes(img.numpy())]))
        }
    )
)

则我们需要指定的参数为:

feature_description = {
    'name' : tf.io.FixedLenFeature([], tf.string, default_value='Nan'),
    'label': tf.io.FixedLenFeature([] , tf.int64, default_value=-1), # 默认值自己定义
    'shape': tf.io.FixedLenFeature([3], tf.int64),
    'data' : tf.io.FixedLenFeature([], tf.string)
}

可以看到其中每一条都和之前的example中的feature对应(feature_description 中 map的key可以不对应,比如name改成id还是没问题的)。

OK,我们目前解决了解析一条example,但是一个Dataset中的example那么多。没关系tensorflow的dataset提供了Dataset.map(func),可以给定一个映射规则,将dataset中的所有条目按照该规则进行映射,其实和python的map函数差不多。

所以我们可以把我们的映射一条的函数呈递给Dataset.map(func),以解析所有的example。

reader = tf.data.TFRecordDataset(file_name) # 打开一个TFrecord

feature_description = {
    'name' : tf.io.FixedLenFeature([], tf.string, default_value='Nan'),
    'label': tf.io.FixedLenFeature([] , tf.int64, default_value=-1),
    'shape': tf.io.FixedLenFeature([3], tf.int64),
    'data' : tf.io.FixedLenFeature([], tf.string)
}
def _parse_function (exam_proto): # 映射函数,用于解析一条example
    return tf.io.parse_single_example (exam_proto, feature_description)
   
reader = reader.map (_parse_function)

读取的话,我们可以用for循环:

for row in reader.take(10): # 只取前10条
# for row in reader: # 枚举所有example
    print (row['name'])
    print (np.frombuffer(row['data'].numpy(), dtype=np.uint8)) # 如果要恢复成3d数组,可reshape

不过我们还可以完出花样:
dataset中还提供了很多方法,比如batch,shuffle,repeat。。。更多的可以自行去官网摸索(不知何时,访问TF官网突然就啥都不用了)

我们就可以这样:

reader = tf.data.TFRecordDataset(file_name)

feature_description = {
    'name' : tf.io.FixedLenFeature([], tf.string, default_value='Nan'),
    'label': tf.io.FixedLenFeature([] , tf.int64, default_value=-1),
    'shape': tf.io.FixedLenFeature([3], tf.int64),
    'data' : tf.io.FixedLenFeature([], tf.string)
}
def _parse_function (exam_proto):
    return tf.io.parse_single_example (exam_proto, feature_description)

reader = reader.repeat (1) # 读取数据的重复次数为:1次,这个相当于epoch
reader = reader.shuffle (buffer_size = 2000) # 在缓冲区中随机打乱数据
reader = reader.map (_parse_function) # 解析数据
batch  = reader.batch (batch_size = 10) # 每10条数据为一个batch,生成一个新的Dataset

shape = []
batch_data_x, batch_data_y = np.array([]), np.array([])
for item in batch.take(1): # 测试,只取1个batch
    shape = item['shape'][0].numpy()
    for data in item['data']: # 一个item就是一个batch
        img_data = np.frombuffer(data.numpy(), dtype=np.uint8)
        batch_data_x = np.append (batch_data_x, img_data)
    for label in item ['label']:
        batch_data_y = np.append (batch_data_y, label.numpy())

batch_data_x = batch_data_x.reshape ([-1, shape[0], shape[1], shape[2]])
print (batch_data_x.shape, batch_data_y.shape) # = (10, 480, 640, 3) (10,)
# 我的图片数据时480*640*3的

可以很方便的读取出数据的各批次,还能随即等等。

你可能感兴趣的:(Tensorflow)