TFRecords文件的创建与读取

一 TFRecords文件

TFRecords文件是TensorFlow专门的文件存储及读取格式,其中包含了tf.train.Example 协议内存块(protocol buffer),存储特征值与数据内容。通过tf.python_io.TFRecordWrite类,可以获取相应的数据并将其填入到Example协议内存块中,最终生成TFRecords文件。简单地说,tf.train.Example有若干数据特征(Features),而Features又有若干Feature字典,其中Feature只接受FloatList, ByteList, Int64List三种数据格式。TFRecords文件就是通过一个包含着二进制文件的数据文件,将特征与标签进行保存以便于Tensorflow读取。

TFRecords文件的创建与读取_第1张图片

 

二 案例分析

本例实现对daisy,dandelion,rose进行分类,项目结构如下:

TFRecords文件的创建与读取_第2张图片

其中,Data文件夹下有daisy,dandelion,rose三类植物,每类四张JPG格式图片,TFRecords_Writer.py负责创建TFRecords文件,TFRecords_Reader.py负责读取TFRecords文件。

1. TFRecords_Writer.py

import os
import tensorflow as tf
from PIL import Image

path = "Data"
dirnames = os.listdir(path)
writer = tf.python_io.TFRecordWriter("train.tfrecords")

for name in dirnames:
    class_path = path + os.sep + name
    for img_name in os.listdir(class_path):
        img_path = class_path + os.sep + img_name
        img = Image.open(img_path)
        img = img.resize((500, 500))
        img_raw = img.tobytes()    # 将图片转化成二进制形式
        example = tf.train.Example(
            features=(tf.train.Features(
                feature={
                    'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[name.encode()])),
                    'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
                }
            ))
        )
        writer.write(example.SerializeToString())

2. TFRecords_Reader.py

import tensorflow as tf
import cv2


def read_and_decode(filename):
    filename_queue = tf.train.string_input_producer([filename])

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.string),
            'image': tf.FixedLenFeature([], tf.string),
        })

    img = tf.decode_raw(features['image'], tf.uint8)
    img = tf.reshape(img, [500, 500, 3])

    img = tf.cast(img, tf.float32) * (1. / 128) - 0.5
    label = tf.cast(features['label'], tf.string)

    return img, label


filename = "train.tfrecords"
img, label = read_and_decode(filename)
img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=1, capacity=10, min_after_dequeue=1)


sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
threads = tf.train.start_queue_runners(sess=sess)

for _ in range(10):
    # 同时取图片及标签,否则图片与标签无法对应
    val, label = sess.run([img_batch, label_batch])
    val.resize((500, 500, 3))
    cv2.imshow("cool", val)
    cv2.waitKey()
    print(label)


注:读取数据的格式必须与写入TFRecords文件的数据格式一致。

 

 

你可能感兴趣的:(DeepLearning)