【Tensorflow】数据读取——tfrecords文件的读取与存储

tfrecords文件介绍

tfrecords文件是Tensorflow设计的一种内置文件格式,是一种二进制文件,它能更好的利用内存,更方便复制和移动。

为了将二进制数据和标签(训练的类别标签)数据存储在同一个文件中

文件格式:*.tfrecords

写入文件内容:Example协议块(是一种类字典的格式)

TFRecords存储

1、建立TFRecords存储器

  • tf.python_io.TFRecordWriter(path)
    • 写入tfrecords文件
    • path:TFRecords文件的路径
    • return:写文件
    • 方法:
    • write(record):向文件中写入一个字符串记录(一个Example)
    • close():关闭文件写入器

注:字符串为一个序列化的Example,Example.SerializeToString()

2、TFRecords存储

构造每一个样本的Example协议块

  • tf.train.Example(features=None)
    • 写入tfrecords文件
    • features:tf.train.Features类型的特征实例
    • return:example格式协议块
  • tf.train.Features(feature=None)
    • 构建每个样本的信息键值对
    • feature:字典数据,key为要保存的名字,value为tf.train.Feature实例
    • return:Features实例
  • tf.train.Feature(**options)
    • **options:例如
    • bytes_list=tf.train.BytesList(value=[Bytes])
    • int64_list=tf.train.Int64List(value=[Value])
    • 有下列三种格式:
    • tf.train.Int64List(value=[Value])
    • tf.train.BytesList(value=[Bytes])
    • tf.train.FloatList(value=[value])

TFRecords读取

说明:读取API详见之前章节

同文件阅读器流程一样,但是中间需要解析过程

解析TFRecords的example协议块

  • tf.parse_single_example(serialized, features=None, name=None)
    • 解析一个单一的Example原型
    • serialized:标量字符串Tensor,一个序列化的Example
    • features:dict字典数据,键为读取的名字,值为FixedLenFeature
    • return:一个键值对组成的字典,键为读取的名字
  • tf.FixedLenFeature(shape, dtype)
    • shape:输入数据的形状,一般不指定,为空列表
    • dtype:输入数据类型,与存储进文件的类型要一致,类型只能是float32,int64,string

完整代码

说明:以下代码为读取tfrecords文件,并解码。而读取二进制文件并存储为tfrecords文件的代码已被注释。

#! /usr/bin/env python 
# -*- coding:utf-8 -*-
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'   # 设置告警级别

# 定义cifar的数据等命令行参数
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("cifar_dir", "./cifar10/cifar-10-batches-py/", "文件目录")
tf.app.flags.DEFINE_string("cifar_tfrecords", "./cifar.tfrecords", "存储tfrecords文件")


class CifarRead(object):
    """
    完成读取二进制文件,写进tfrecords,读取tfrecords
    """
    def __init__(self, filelist):
        # 文件列表
        self.file_list = filelist

        # 定义读取图片的属性
        self.height = 32
        self.width = 32
        self.channel = 3

        # 二进制文件每张图片的字节
        self.lable_bytes = 1
        self.image_bytes = self.height * self.width * self.channel
        self.bytes = self.lable_bytes + self.image_bytes

    def read_and_decode(self):
        # 1、构造文件队列
        file_queue = tf.train.string_input_producer(self.file_list)

        # 2、构造二进制文件读取器,读取内容  参数:每个样本的字节数
        reader = tf.FixedLengthRecordReader(self.bytes)
        key, value = reader.read(file_queue)

        # 3、解码内容
        label_image = tf.decode_raw(value, tf.uint8)
        # print(label_image)  # Tensor("DecodeRaw:0", shape=(?,), dtype=uint8)

        # 4、分割出图片和标签数据
        label = tf.cast(tf.slice(label_image, [0], [self.lable_bytes]), tf.int32)
        image = tf.slice(label_image, [self.lable_bytes], [self.image_bytes])
        # print(label, image)  # Tensor("Cast:0", shape=(1,), dtype=int32) Tensor("Slice_1:0", shape=(3072,), dtype=uint8)

        # 5、可以对图片的特征数据进行形状的改变 [3072] --> [32, 32, 3]
        image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
        # print(label, image_reshape)  # Tensor("Cast:0", shape=(1,), dtype=int32) Tensor("Reshape:0", shape=(32, 32, 3), dtype=uint8)

        # 6、批处理数据
        image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
        print(image_batch, label_batch)  # Tensor("batch:0", shape=(10, 32, 32, 3), dtype=uint8) Tensor("batch:1", shape=(10, 1), dtype=int32)

        return image_batch, label_batch


    def write_to_tfrecords(self, image_batch, label_batch):
        """
        将图片的特征值和目标值存进tfrecords
        :param self:
        :param image_batch: 10张图片的特征值
        :param label_batch: 10张图片的目标值
        :return: None
        """
        # 1、建立tfrecord存储器
        writer = tf.python_io.TFRecordWriter(FLAGS.cifar_tfrecords)

        # 2、循环将所有样本写入文件,每张图片样本都要构造example协议
        for i in range(10):
            # 取出第i个图片数据的特征值和目标值    eval()出来是一个张量  label_batch[i].eval():[值]
            image= image_batch[i].eval().tostring()
            label = label_batch[i].eval()[0]

            # 构造一个样本的example
            example = tf.train.Example(features=tf.train.Features(feature={
                "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
            }))

            # 写入单独的样本
            writer.write(example.SerializeToString())

        # 关闭
        writer.close()


    def read_from_tfrecords(self):
        # 1、构建文件队列
        file_queue = tf.train.string_input_producer([FLAGS.cifar_tfrecords])

        # 2、构造文件阅读器,读取内容example
        reader = tf.TFRecordReader()
        key, value = reader.read(file_queue)  # value:一个样本的序列化example

        # 3、解析example
        features = tf.parse_single_example(value, features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64)
        })
        # print(features["image"], features["label"])
        # Tensor("ParseSingleExample/ParseSingleExample:0", shape=(), dtype=string) Tensor("ParseSingleExample/ParseSingleExample:1", shape=(), dtype=int64)

        # 4、解码内容  如果读取的内容格式是string需要解码,如果是int64,float32则不需要解码
        image = tf.decode_raw(features["image"], tf.uint8)
        label = tf.cast(features["label"], tf.int32)   # 默认是int64,然后int64其实也是先用int32存储,所以这里直接类型转换为int32
        # print(image, label)
        # Tensor("DecodeRaw:0", shape=(?,), dtype=uint8) Tensor("ParseSingleExample/ParseSingleExample:1", shape=(), dtype=int32)
        # 固定图片的形状,以方便批处理
        image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
        # print(image_reshape, label)  # Tensor("Reshape:0", shape=(32, 32, 3), dtype=uint8) Tensor("Cast:0", shape=(), dtype=int32)

        # 5、进行批处理
        image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)

        return image_batch, label_batch


if __name__ == '__main__':
    # 读取二进制文件案例
    # 找到文件,放入列表
    file_name = os.listdir(FLAGS.cifar_dir)
    filelist = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if file[0:10] == "data_batch"]

    cf = CifarRead(filelist)
    # 读取文件
    # image_batch, label_batch = cf.read_and_decode()
    # 读取tfrecords文件解码
    image_batch, label_batch = cf.read_from_tfrecords()

    # 开启会话运行结果
    with tf.Session() as sess:
        # 定义一个线程协调器
        coord = tf.train.Coordinator()

        # 开启读文件的线程
        threads = tf.train.start_queue_runners(sess, coord=coord)

        # 存进tfrecords文件   eval()必须在session里使用
        # print("开始存储成tfrecords文件")
        # cf.write_to_tfrecords(image_batch, label_batch)
        # print("存储成tfrecords文件完成")

        # 打印读取的内容
        print(sess.run([image_batch, label_batch]))

        # 回收子线程
        coord.request_stop()
        coord.join(threads)

运行后如下:

【Tensorflow】数据读取——tfrecords文件的读取与存储_第1张图片

【Tensorflow】数据读取——tfrecords文件的读取与存储_第2张图片

你可能感兴趣的:(深度学习,深度学习,tensorflow,python)