tensorflow 使用TFRECRD保存数据并读取

TFRECORD 是什么

我们训练文件夹的内部往往会存着成千上万的图片或文本等文件,这些文件被散列存着,这样不仅占用磁盘空间,并且再被一个个读取的时候会非常慢,繁琐。占用大量内存空间(有的大型数据不足以一次性加载)。此时我们TFRecord格式的文件存储形式会很合理的帮我们存储数据。TFRecord内部使用了“Protocol Buffer”二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可,简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个TFRecord文件,来提高处理效率。

1.TFRECORD的组成

分别是TFRecord生成器以及样本Example模块。

1.1 生成器

writer = tf.python_io.TFRecordWriter(record_path)
writer.write(tf_example.SerializeToString())
writer.close()

这里面writer就是我们TFrecord生成器。接着我们就可以通过writer.write(tf_example.SerializeToString())来生成我们所要的tfrecord文件了。这里需要注意的是我们TFRecord生成器在写完文件后需要关闭writer.close()。这里tf_example.SerializeToString()是将Example中的map压缩为二进制文件,更好的节省空间。接下来讲述tf_example是如何生成。

1.2 Example模块

Example协议块

message Example {
  Features features = 1;
};

message Features {
  map<string, Feature> feature = 1;
};

message Feature {
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};

我们可以看出上面的tf_example可以写入的数据形式有三种,分别是BytesList, FloatList以及Int64List的类型。

def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

tf_example = tf.train.Example(
        features=tf.train.Features(feature={
            'image/encoded': bytes_feature(encoded_jpg),
            'image/format': bytes_feature('jpg'.encode()),
            'image/class/label': int64_feature(label),
            'image/height': int64_feature(height),
            'image/width': int64_feature(width)}))

(1)tf.train.Example(features = None) 这里的features是tf.train.Features类型的特征实例。
(2)tf.train.Features(feature = None) 这里的feature是以字典的形式存在,*key:要保存数据的名字 value:要保存的数据,但是格式必须符合tf.train.Feature实例要求。
以上参考:https://www.jianshu.com/p/b480e5fcb638

2. 生成TFRECORD文件

循环读取图片

import os
import tensorflow as tf
from PIL import Image

cwd = 'opt/pyproject/demo/TFRECORD/matlab//' 
classes = {'noarm','onearm','run','twoarms'}

def create_record():
    writer = tf.python_io.TFRecordWriter("dog_train.tfrecords")
    for index, name in enumerate(classes):
        class_path = cwd +"/"+ name+"/"
        for img_name in os.listdir(class_path): #已经将四种不同类型的图片分在了四个文件夹内
            img_path = class_path + img_name
            img = Image.open(img_path)
            img = img.resize((64, 64))
            img_raw = img.tobytes() #将图片转化为原生bytes
            print (index,img_raw)
            example = tf.train.Example(
               features=tf.train.Features(feature={
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
               }))
            writer.write(example.SerializeToString())
    writer.close()

3. 读取TFRECORD文件

3.1 使用dataset读取文件

def write_test():
    writer = tf.python_io.TFRecordWriter('test.tfrecord')
    image=Images.open(cwd+'noarm_1.jpeg')
    #image=image.resize([500,500])
    image_data=image.tobytes()
    index=0
    # 创建 Example 对象,并且将 Feature 一一对应填充进去。
    example = tf.train.Example(features=tf.train.Features(feature={
                   'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                    'image_data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data]))
      }))
    # 将 example 序列化成 string 类型,然后写入。
    writer.write(example.SerializeToString())
    writer.close()
write_test()

def _get_images_labels(input_file):
    dataset=tf.data.TFRecordDataset(input_file)
    dataset=dataset.map(_parse_record)
    #dataset=dataset.prefetch(-1)
    #dataset=dataset.repeat().batch(128)
    iterator=dataset.make_one_shot_iterator()
    images, labels=iterator.get_next()
    return images, labels
   
def _parse_record(example_proto):
    features = {
         'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
         'image_data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data]))
         }
    parsed_features=tf.parse_single_example(example_proto, features=features)
    img=tf.decode_raw(parsed_features['image_data],out_type=uint8)
    #img=tf.reshape(img,shape=[500,500,3]) 
    # 如果前面使用了image.resize([500,500])这里就需要还原为[500,500,3]
    img=tf.reshape(img, shape=[656,875,3])
    label=parsed_features['label']
    #这里label不需要reshape
    label=tf.cast(label, tf.in32)
    
    return img, label 

with tf.Session() as sess:
   image, label = sess.run(_get_images_labels('test.tfrecord'))
   plt.figure()
   plt.imshow(image)
   plt.show()

参考:https://blog.csdn.net/briblue/article/details/80789608

3.2 使用队列读取文件

import os
import tensorflow as tf
from PIL import Image
 
cwd = 'E:/train_data/picture_dog//' 
classes = {'husky','jiwawa'}
 
 
#制作TFRecords数据
def create_record():
    writer = tf.python_io.TFRecordWriter("dog_train.tfrecords")
    for index, name in enumerate(classes):
        class_path = cwd +"/"+ name+"/"
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name
            img = Image.open(img_path)
            img = img.resize((64, 64))
            img_raw = img.tobytes() #将图片转化为原生bytes
            print (index,img_raw)
            example = tf.train.Example(
               features=tf.train.Features(feature={
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
               }))
            writer.write(example.SerializeToString())
    writer.close()
#-------------------------------------------------------------------------
 
#读取二进制数据
 
def read_and_decode(filename):
    # 创建文件队列,不限读取的数量
    filename_queue = tf.train.string_input_producer([filename])
    # create a reader from file queue
    reader = tf.TFRecordReader()
    # reader从文件队列中读入一个序列化的样本
    _, serialized_example = reader.read(filename_queue)
    # get feature from serialized example
    # 解析符号化的样本
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string)
        })
    label = features['label']
    img = features['img_raw']
    img = tf.decode_raw(img, tf.uint8)
    img = tf.reshape(img, [64, 64, 3])
    #img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(label, tf.int32)
    return img, label
#--------------------------------------------------------------------------    
#---------主程序----------------------------------------------------------
if __name__ == '__main__':
    create_record()
    batch = read_and_decode('dog_train.tfrecords')
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    
    with tf.Session() as sess: #开始一个会话  
        sess.run(init_op)  
        coord=tf.train.Coordinator()  
        threads= tf.train.start_queue_runners(coord=coord)  
        for i in range(40):  
            example, lab = sess.run(batch)#在会话中取出image和label  
            img=Image.fromarray(example, 'RGB')#这里Image是之前提到的  
            img.save(cwd+'/'+str(i)+'_Label_'+str(lab)+'.jpg')#存下图片;注意cwd后边加上‘/’  
            print(example, lab)  
        coord.request_stop()  
        coord.join(threads) 
        sess.close()

参考:https://blog.csdn.net/ywx1832990/article/details/78462582
整体参考:
读取并训练: https://my.oschina.net/u/3800567/blog/1637798?from=singlemessage
参考: https://www.cnblogs.com/puheng/p/9576521.html
tenforflow 官方文档:
https://tensorflow.google.cn/guide/datasets#parsing_tfexample_protocol_buffer_messages

你可能感兴趣的:(tensorflow)