TFRecords in TensorFlow: a data format

# 1. What is TFRecords
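TFRecords is a binary file format native to TensorFlow.
Any data can be converted to TFRecords, as long as it can be expressed as byte strings, floats, or 64-bit integers.
It makes better use of memory, is easier to copy and move around, and needs no separate label file because labels are stored in the same record as the data.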
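
To make "binary file format" a little more concrete: as far as I understand the layout described in the TensorFlow documentation, each record is framed as an 8-byte little-endian length, a 4-byte masked CRC-32C of that length, the payload bytes, and a 4-byte masked CRC-32C of the payload. The following is only a pure-Python sketch of that framing. The names `frame_record` and `iter_records` are made up for illustration, and real code should go through TensorFlow's own writer and reader.

```python
import struct

def _make_crc32c_table():
    # Lookup table for CRC-32C (Castagnoli), reflected polynomial 0x82F63B78.
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table

_CRC32C_TABLE = _make_crc32c_table()

def crc32c(data):
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _CRC32C_TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF

def masked_crc32c(data):
    # Mask step: rotate right by 15 bits, then add a constant (mod 2**32).
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF

def frame_record(payload):
    # length | crc(length) | payload | crc(payload)
    length = struct.pack('<Q', len(payload))
    return (length + struct.pack('<I', masked_crc32c(length))
            + payload + struct.pack('<I', masked_crc32c(payload)))

def iter_records(buf):
    # Walk a byte buffer and yield each payload, checking both checksums.
    offset = 0
    while offset < len(buf):
        length_bytes = buf[offset:offset + 8]
        (length,) = struct.unpack('<Q', length_bytes)
        (length_crc,) = struct.unpack('<I', buf[offset + 8:offset + 12])
        if length_crc != masked_crc32c(length_bytes):
            raise ValueError('corrupt length field')
        end = offset + 12 + length
        payload = buf[offset + 12:end]
        (payload_crc,) = struct.unpack('<I', buf[end:end + 4])
        if payload_crc != masked_crc32c(payload):
            raise ValueError('corrupt payload')
        yield payload
        offset = end + 4

blob = frame_record(b'first') + frame_record(b'second record')
print(list(iter_records(blob)))
```

In a real file each payload would be a serialized `tf.train.Example` rather than the short byte strings used here.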

# 2. Usage

## 2.1 Saving data as TFRecords

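A TFRecords file contains tf.train.Example protocol buffers (which contain Features as a field). You write a small program that gets your data, fills it into an Example protocol buffer, serializes the protocol buffer to a string, and writes the string to a TFRecords file with tf.python_io.TFRecordWriter (available as tf.io.TFRecordWriter in TensorFlow 2.x).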

The workflow:
1. Fill the data into an Example protocol buffer
2. Serialize the protocol buffer to a string
3. Write the string to a TFRecords file with tf.python_io.TFRecordWriter


import tensorflow as tf

# Helpers that wrap a Python value in the matching tf.train.Feature type
def _bytes_feature(value):
    # a list of byte strings as the feature value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    # a list of integers as the feature value
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _float_feature(value):
    # a list of floats as the feature value
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

writer = tf.python_io.TFRecordWriter('filename')    # create a writer
# Fill in the data: an Example holds Features, Features holds a dict of key -> Feature,
# and each Feature is tf.train.Feature(<type>_list=tf.train.<Type>List(value=[value]))
example = tf.train.Example(
    features=tf.train.Features(
        feature={
            "name": _bytes_feature(b'xianyi'),      # BytesList needs bytes, not str, on Python 3
            "age": _int64_feature(23)
        }
    )
)
writer.write(example.SerializeToString())           # write the serialized Example to the file
writer.close()                                      # flush and close the file

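Building an Example takes several nested steps.
An Example is created with tf.train.Example().
It has one field, features=tf.train.Features().
The features field holds one or more entries: feature={"key": tf.train.Feature()}
A feature is stored as a key-value pair. The key is a string, and the value is one of three list types:
1. BytesList: a list of byte strings, tf.train.BytesList(value=[value])
2. FloatList: a list of floats, tf.train.FloatList(value=[value])
3. Int64List: a list of 64-bit integers, tf.train.Int64List(value=[value])
For an image held as a NumPy array there are three options: call .tostring() (named .tobytes() in current NumPy) and store the result in a BytesList; read the encoded image file as bytes with tf.gfile.FastGFile and store that in a BytesList; or .flatten() the array and store it in a FloatList.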
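
One caveat with the raw-bytes route: the byte string carries neither the element type nor the array shape, so both have to be known again when decoding. Below is a small stdlib-only sketch of that point, using `struct` in place of NumPy. The helpers `to_raw` and `from_raw` are hypothetical names, and little-endian float64 is assumed to mirror what `np.array([1., 2.]).tostring()` produces on typical machines.

```python
import struct

def to_raw(values):
    # Pack floats as little-endian float64, 8 bytes per value.
    return struct.pack('<%dd' % len(values), *values)

def from_raw(raw, fmt_char, item_size):
    # Reinterpret the same bytes with a caller-chosen element type.
    count = len(raw) // item_size
    return list(struct.unpack('<%d%s' % (count, fmt_char), raw))

raw = to_raw([1.0, 2.0])

as_float64 = from_raw(raw, 'd', 8)   # same type as written: values come back
as_uint8 = from_raw(raw, 'B', 1)     # wrong type: 16 unrelated small integers

print(len(raw), as_float64, len(as_uint8))
```

The same applies to the shape: a flat run of bytes does not say whether it was 56x56 or 28x112, so the dimensions either have to be fixed by convention or stored as extra features next to the image.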

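A few consistency rules apply to Examples:
1. If feature K has data type T in one Example, feature K must have type T in every other Example as well.
2. The number of items in the value list of feature K may differ from one Example to another, depending on your needs.
3. If an Example has no feature K and a default value was specified for parsing, the default value is returned.
4. If feature K is present but contains no values, an empty tensor is returned rather than the default value.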
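
The rules above can be modelled with plain dictionaries. This is only an illustrative sketch, not TensorFlow's parser: an Example is represented as a dict from key to a list of values, and `parse_feature` and `check_consistent_types` are hypothetical helpers.

```python
_MISSING = object()  # sentinel meaning "no default was given"

def parse_feature(example, key, default=_MISSING):
    # Rule 3: a missing key falls back to the default, if one was given.
    if key not in example:
        if default is _MISSING:
            raise KeyError('feature %r is missing and has no default' % key)
        return default
    # Rule 4: a present but empty list stays empty; the default is not used.
    # Rule 2: the list may have any length.
    return list(example[key])

def check_consistent_types(examples):
    # Rule 1: a key must hold the same value type in every example.
    seen = {}
    for example in examples:
        for key, values in example.items():
            for value in values:
                expected = seen.setdefault(key, type(value))
                if type(value) is not expected:
                    raise TypeError('feature %r mixes %s and %s' % (
                        key, expected.__name__, type(value).__name__))

examples = [
    {'age': [29.0], 'movie': [b'Fight Club']},
    {'movie': []},   # 'age' missing, 'movie' present but empty
]
check_consistent_types(examples)
print(parse_feature(examples[1], 'age', default=[0.0]))
print(parse_feature(examples[1], 'movie', default=[b'unknown']))
```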

A sample Example, in protocol buffer text format:
features {
    feature {
        key: "age"
        value { float_list {
            value: 29.0
       }}
     }
    feature {
        key: "movie"
        value { bytes_list {
            value: "The Shawshank Redemption"
            value: "Fight Club"
       }}
     }
    feature {
        key: "movie_ratings"
        value { float_list {
            value: 9.0
            value: 9.7
       }}
     }
    feature {
        key: "suggestion"
        value { bytes_list {
            value: "Inception"
       }}
     }
}

Here is a piece of code from the official repository, tensorflow/tensorflow/examples/how_tos/reading_data/convert_to_records.py:

def _int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def convert_to(data_set, name):
  images = data_set.images
  labels = data_set.labels
  num_examples = data_set.num_examples

  rows = images.shape[1]
  cols = images.shape[2]
  depth = images.shape[3]

  filename = os.path.join(FLAGS.directory, name + '.tfrecords')
  print('Writing', filename)
  with tf.python_io.TFRecordWriter(filename) as writer:
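    for index in range(num_examples):  # one Example is written per image (per row of the array)
      # Convert the NumPy array to raw bytes and store it as a bytes feature.
      # This takes a lot of space; a more compact method (storing the encoded file bytes) is shown later.
      image_raw = images[index].tostring()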
      example = tf.train.Example(
          features=tf.train.Features(
              feature={
                  'height': _int64_feature(rows),
                  'width': _int64_feature(cols),
                  'depth': _int64_feature(depth),
                  'label': _int64_feature(int(labels[index])),
                  'image_raw': _bytes_feature(image_raw)
              }))
      writer.write(example.SerializeToString())

An example that converts image files to TFRecords:

import os
import tensorflow as tf
from imageio import imread

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def img2feature(path):
    # Method 1: store the decoded pixel array as raw bytes
    # np.array([1.,2.]).tostring()-->b'\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\x00@'
    img = imread(path)
    feature = tf.train.Features(feature={
        'img':_bytes_feature(img.tostring())
    })
    return feature

def img2feature2(img_path):
    # Method 2: store the still-encoded file contents, a different (and usually smaller) encoding
    # tf.gfile.FastGFile(img_path, mode='rb').read()-->b'\x00H\x9c\x01I\x9e\x02J\x9f\x03J\xa4\x04K\xa7\x05K\xaa\x05K'
    with tf.gfile.GFile(img_path, 'rb') as fid:
        encoded_jpg = fid.read()
    feature = tf.train.Features(feature={
        'img': _bytes_feature(encoded_jpg)
    })
    return feature

def convert_to_tfrecord():
    input_dir = './data'
    save_path = './data2.tfrecord'
    # The original called an undefined get_file_list(); os.listdir is used here
    # on the assumption that input_dir contains only image files.
    file_list = sorted(os.listdir(input_dir))

    # Write with ZLIB compression; the reader must be given the same option
    writer_options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
    writer = tf.python_io.TFRecordWriter(path=save_path, options=writer_options)
    for file in file_list:
        img_path = os.path.join(input_dir, file)
        print(img_path)
        feature = img2feature(img_path)   # or img2feature2(img_path)
        example = tf.train.Example(features=feature)
        writer.write(example.SerializeToString())
    writer.close()

## 2.2 Reading TFRecords

Records are read with tf.TFRecordReader and decoded with tf.parse_single_example.
To read data efficiently, TensorFlow 1.x usually feeds it through queues.

def read_and_decode(filename):
    # Build a queue of file names
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()  # create a reader
    _, serialized_example = reader.read(filename_queue)   # returns (key, serialized Example)

    # Describe the features the same way they were written.
    # FixedLenFeature reads a feature with a fixed number of values (here a scalar) and needs its dtype;
    # VarLenFeature reads a variable-length list and also needs the dtype.
    feature = { 'label': tf.FixedLenFeature([], tf.int64),
                'img_raw' : tf.FixedLenFeature([], tf.string), }
    # Parse one Example with tf.parse_single_example
    features = tf.parse_single_example(serialized_example, features=feature)

    # For images, decode the raw bytes with tf.decode_raw using the dtype they were written with, then reshape
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [224, 224, 3])
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5  # scale pixels to [-0.5, 0.5]
    label = tf.cast(features['label'], tf.int32)

    return img, label

img, label = read_and_decode('train.tfrecords')
# During training, shuffle_batch shuffles the examples and groups them into batches
img_batch, label_batch = tf.train.shuffle_batch([img, label],
                                                batch_size=30,
                                                capacity=2000,  # maximum number of elements in the queue
                                                num_threads=1,  # number of threads enqueuing elements
                                                min_after_dequeue=1000) # minimum number of elements left after a dequeue, used to ensure a level of mixing of elements

# Queue runners must be started inside a session before any data can be read
init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
with tf.Session() as sess:
    sess.run(init_op)

    coord = tf.train.Coordinator()  # create a Coordinator to coordinate the threads
    threads = tf.train.start_queue_runners(coord=coord)  # start the QueueRunner threads that fill the queues

    try:  # recommended pattern
        while not coord.should_stop():
            # Run training steps or whatever (train_op stands for your own training op)
            sess.run(train_op)
    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        # When done, ask the threads to stop.
        coord.request_stop()

    # Wait for threads to finish.
    coord.join(threads)
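
The effect of the `min_after_dequeue` argument passed to `tf.train.shuffle_batch` above can be seen in a toy model. This is a deliberately simplified, single-threaded sketch and not how TensorFlow's queue is implemented: it ignores `capacity` and threading, and the generator name `shuffled` is made up for illustration. An item is only handed out once more than `min_after_dequeue` items are buffered, so a larger value mixes elements over a wider window.

```python
import random

def shuffled(stream, min_after_dequeue, seed=0):
    rng = random.Random(seed)
    buffer = []
    for item in stream:
        buffer.append(item)
        # Dequeue only if at least min_after_dequeue items stay behind.
        if len(buffer) > min_after_dequeue:
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    # Input exhausted: hand out whatever is left, in random order.
    rng.shuffle(buffer)
    for item in buffer:
        yield item

# With min_after_dequeue=0 there is never anything to mix with,
# so the order is unchanged; a larger value gives a wider mixing window.
print(list(shuffled(range(10), min_after_dequeue=0)))
print(list(shuffled(range(10), min_after_dequeue=5)))
```

This is also why the value is a trade-off: more buffered elements mean better mixing but more memory and a longer wait before the first batch.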

Complete code that writes TFRecords and reads them back:

import tensorflow as tf
import numpy as np
import os



#=============================================================================#
# write images and label in tfrecord file and read them out
def encode_to_tfrecords(tfrecords_filename, data_num): 
    ''' write into tfrecord file '''
    if os.path.exists(tfrecords_filename):
        os.remove(tfrecords_filename)

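    writer = tf.python_io.TFRecordWriter('./'+tfrecords_filename) # create the .tfrecords file for writing

    for i in range(data_num):
        # fix the dtype explicitly so it matches tf.decode_raw(..., tf.int64) on every platform
        img_raw = np.random.randint(0,255,size=(56,56),dtype=np.int64)
        img_raw = img_raw.tostring()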
        example = tf.train.Example(features=tf.train.Features(
                feature={
                'label': tf.train.Feature(int64_list = tf.train.Int64List(value=[i])),     
                'img_raw':tf.train.Feature(bytes_list = tf.train.BytesList(value=[img_raw]))
                }))
        writer.write(example.SerializeToString()) 

    writer.close()
    return 0

def decode_from_tfrecords(filename_queue, is_batch):

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)   # returns (key, serialized Example)
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw' : tf.FixedLenFeature([], tf.string),
                                       })  # parsed dict holding the image and label tensors
    image = tf.decode_raw(features['img_raw'],tf.int64)
    image = tf.reshape(image, [56,56])
    label = tf.cast(features['label'], tf.int64)

    if is_batch:
        batch_size = 3
        min_after_dequeue = 10
        capacity = min_after_dequeue+3*batch_size
        image, label = tf.train.shuffle_batch([image, label],
                                              batch_size=batch_size, 
                                              num_threads=3, 
                                              capacity=capacity,
                                              min_after_dequeue=min_after_dequeue)
    return image, label

#=============================================================================#

if __name__=='__main__':
    # make train.tfrecord
    train_filename = "train.tfrecords"
    encode_to_tfrecords(train_filename,100)
    # make test.tfrecord
    test_filename = 'test.tfrecords'
    encode_to_tfrecords(test_filename,10)

    filename_queue = tf.train.string_input_producer([train_filename],num_epochs=None) # queue of input files
    train_image, train_label = decode_from_tfrecords(filename_queue, is_batch=True)

    filename_queue = tf.train.string_input_producer([test_filename],num_epochs=None) # queue of input files
    test_image, test_label = decode_from_tfrecords(filename_queue, is_batch=True)
    with tf.Session() as sess: # start a session
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        coord=tf.train.Coordinator()
        threads= tf.train.start_queue_runners(coord=coord)

        try:
            # while not coord.should_stop():
            for i in range(2):
                example, l = sess.run([train_image,train_label]) # fetch a batch of images and labels in the session
                print('train:')
                print(example, l) 
                texample, tl = sess.run([test_image, test_label])
                print('test:')
                print(texample,tl)
        except tf.errors.OutOfRangeError:
            print('Done reading')
        finally:
            coord.request_stop()

        coord.join(threads)

TensorFlow also provides a newer way to read TFRecords, based on tf.data.TFRecordDataset.
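
A minimal sketch of what reading with `tf.data.TFRecordDataset` might look like, assuming a file that stores the same `label` and `img_raw` features, the same uint8 pixels, and the same 224x224x3 shape as the queue-based reader earlier in this article. The function name `read_with_dataset` and its default arguments are hypothetical.

```python
def read_with_dataset(path, batch_size=30, shuffle_buffer=1000):
    # Imported here so this file can still be loaded where TensorFlow is not installed.
    import tensorflow as tf

    # Must describe the features exactly as they were written (assumed layout).
    feature_spec = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'img_raw': tf.io.FixedLenFeature([], tf.string),
    }

    def _parse(serialized_example):
        parsed = tf.io.parse_single_example(serialized_example, feature_spec)
        img = tf.io.decode_raw(parsed['img_raw'], tf.uint8)
        img = tf.reshape(img, [224, 224, 3])
        img = tf.cast(img, tf.float32) * (1.0 / 255) - 0.5
        label = tf.cast(parsed['label'], tf.int32)
        return img, label

    dataset = tf.data.TFRecordDataset([path])
    dataset = dataset.map(_parse)
    dataset = dataset.shuffle(buffer_size=shuffle_buffer)
    dataset = dataset.batch(batch_size)
    return dataset
```

In TensorFlow 2.x the result can be iterated directly, for example `for images, labels in read_with_dataset('train.tfrecords'): ...`. No filename queue, Coordinator, or queue runners are needed with this approach.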

# 3. Summary

To read data from a TFRecords file, use tf.TFRecordReader together with the tf.parse_single_example parser. This op parses Example protocol buffers into tensors.

tf.train.Example defines the data format:

The recommended format for TensorFlow is a TFRecords file containing tf.train.Example protocol buffers (which contain Features as a field).

tf.python_io.TFRecordWriter writes the TFRecords file:

You write a little program that gets your data, stuffs it in an Example protocol buffer, serializes the protocol buffer to a string, and then writes the string to a TFRecords file using the tf.python_io.TFRecordWriter class.
Example: convert_to_records.py

tf.parse_single_example decodes, tf.TFRecordReader reads:

To read a file of TFRecords, use tf.TFRecordReader with the tf.parse_single_example decoder.
Example: fully_connected_reader.py
