【TensorFlow】将数据集保存为TFRecords格式的文件

目录

tf.python_io介绍

tf.train.Example介绍


TFRecords是TensorFlow官方推荐的存储数据的格式,方便对数据存储转移以及操作。

一个TFRecords文件包含一个带有CRC散列的字符串序列,其存储格式有如下四种:

uint64 length
uint32 masked_crc32_of_length
byte   data[length]
uint32 masked_crc32_of_data

tf.python_io介绍

                with tf.python_io.TFRecordWriter(output_filename) as tf_writer:
                    start_ndx =shard_id * num_per_shard
                    end_ndx = min((shard_id+1) * num_per_shard, len(file_names))
                    for i in range(start_ndx, end_ndx):

首先是定义输入文件的位置以及名称 output_filename。

通过 tf.python_io,将数据写入文件,用with打开的话,在写入结束之后后自动关闭文件,不用再特意写一个关闭文件的语句了,这个操作与python打开文件的操作是一样的。

                        example = data_utils.image_to_tfexample(image,  b'jpg', height,                     
                                                width, class_label)
                        tf_writer.write(example.SerializeToString())

最后在写入的时候有一个操作是image_to_tfexample,写入文件的话需要转化成对应的格式,然后在进行序列化操作。格式的转换操作就在函数image_to_tfexample中完成的。

tf.train.Example介绍

tf.train.Example 是一块 buffer即协议缓冲区,其中包含了各种feature并以字典的形式赋值。


def image_to_tfexample(image_data, image_format, height, width, class_label):
    features = {
        'image/encoded': bytes_feature(image_data),
        'image/format': bytes_feature(image_format),
        'image/class/label': int64_feature(class_label),
        'image/height': int64_feature(height),
        'image/width': int64_feature(width)
    }
    return tf.train.Example(features=tf.train.Features(feature=features))

 这个就是上面提到的格式转换函数。

def int64_feature(values):
    "Return a TF_feature of int64"

    if not isinstance(values, (tuple, list)):
        values = [values]

    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def bytes_feature(values):
    "Return a TF-feature of bytes"

    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))

格式化原始数据可以使用tf.train.BytesList、tf.train.Int64List、tf.train.FloatList三个类。这三个类都有实例属性value用于我们将值传进去,一般tf.train.Int64List、tf.train.FloatList对应处理整数和浮点数,tf.train.BytesList用于处理二进制类型的数据。

需要注意的是:传进去的值都需要转换为列表,不然就会报错。
格式转换操作可以看成是一系列的操作:

tf.train.Example(features=tf.train.Features(feature=feature))
feature = {
'image/encoded':tf.train.Feature(int64list=tf.train.Int64List)
}

tf.train.Feature的属性显然对应于上面的类型,有int64list,bytelist等。

tf.train.Features是复数的Feature,它的属性是feature,可以理解为将上面的单个feature特征赋值过来,通过这个操作进行整合(以字典的方式传过来)。

tf.train.Example的属性是features,可以理解为对整合后的feature进行操作。

最后是通过tf.train.Example完成操作的。

还有一个方法SerializeToString()需要说一下,这个方法的作用是把tf.train.Example对象序列化为字符串,因为我们写入文件的时候不能直接处理对象,需要将其转化为字符串才能处理。

这就是TFRecords文件的保存方式了。

 

 

 

参考链接:

https://blog.csdn.net/hfutdog/article/details/86244944#tftrainBytesList_5

 

你可能感兴趣的:(TensorFlow)