tfrecord读写数据

tfrecord 的数据读写是真的麻烦,各种不方便,而且还有些坑,本来不太想讲这个东西,所以这里只打算写一个简单的读写模板,可以作为参考。

其实 tfrecord 本质上只支持三种数据类型:bytes、int64、float。所以只要把要保存的数据转成这三种类型之一就行了。

另外,这几种类型的数据都是以一维 list 的形式保存的,不支持多维数组;如果数据是多维的,就要先展平成 1 维再写入,读进来以后再 reshape 回原来的形状。

还有一点是变长数据,这种一般会padding到定长数组,然后再保存,不过这会浪费一些空间,tensorflow其实还是支持变长数据的,但是和定长数据不能一起用……也不知道这是怎么搞的,我觉得理论上应该是可以支持的吧。

写文件:

# coding: utf-8

import os, sys
import time, io
import tensorflow as tf

from PIL import Image
import numpy as np
import random

def _bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature.

    The value is placed in a one-element list because BytesList stores a
    list of byte strings.
    """
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)

def _int64_feature(value):
    """Wrap a sequence of integers in a tf.train.Feature.

    Unlike _bytes_feature, ``value`` is passed through unchanged, so the
    caller must already supply a list-like object (e.g. ``[width]``).
    """
    int64_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_list)

def _float_feature(value):
    """Wrap a sequence of floats in a tf.train.Feature.

    ``value`` is passed through unchanged, so the caller must already
    supply a flat list-like object.
    """
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)

def get_data_item():
    """Build one synthetic sample used to demonstrate TFRecord writing.

    Returns:
        A dict with the keys:
          * 'image_path': path of the image file to embed.
          * 'float_values': a fixed list of three floats.
          * 'float_values2': a 2x3 array flattened to one dimension.
          * 'var_values': a 1-D numpy array holding between 1 and 10
            random floats (variable-length data).
    """
    count = random.randint(1, 10)
    var_values = np.array([random.random() for _ in range(count)])

    # Multi-dimensional arrays are not supported: flatten to 1-D here and
    # reshape the tensor back after reading.
    matrix = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    flattened = matrix.reshape(-1)

    item = {}
    item["image_path"] = 't.jpg'
    # A numpy array such as np.array([1.0, 2.0, 3.0]) is supported as well.
    item['float_values'] = [1.0, 2.0, 3.0]
    item['float_values2'] = flattened
    item['var_values'] = var_values
    return item

def get_example(data_item):
    """Convert one sample dict into a tf.train.Example.

    Args:
        data_item: dict providing 'image_path', 'float_values' and
            'float_values2' (the shape produced by get_data_item).

    Returns:
        A tf.train.Example holding the raw image bytes, the image size,
        a (randomly truncated) image name and the two float features.
    """
    image_path = data_item['image_path']

    # 1. Image data. There are several ways to read it; all that matters is
    #    ending up with the encoded bytes. Width/height are optional extras.
    with tf.gfile.GFile(image_path, 'rb') as image_file:
        encoded_jpg = image_file.read()
    pil_image = Image.open(io.BytesIO(encoded_jpg))
    img_width, img_height = pil_image.size

    # Alternative: open the image first, then re-encode it through BytesIO:
    #   img = Image.open(data_item['image_path']).convert('RGB')
    #   img_width, img_height = img.size
    #   f = io.BytesIO()
    #   img.save(f, 'JPEG')
    #   f.seek(0)
    #   encoded_jpg = f.read()

    # 2. Text data: a prefix of the file name with a random length of 1..5.
    base_name = os.path.basename(image_path)
    prefix_len = random.randint(1, 5)
    image_name = base_name[:prefix_len]

    feature_map = {
        'image/encoded': _bytes_feature(encoded_jpg),
        'image/format': _bytes_feature(b"JPEG"),
        'image/width': _int64_feature([img_width]),
        'image/height': _int64_feature([img_height]),
        'image/name': _bytes_feature(image_name.encode('utf-8')),
        'float_values': _float_feature(data_item['float_values']),
        'float_values2': _float_feature(data_item['float_values2']),
        # 'var_values': _float_feature(data_item['var_values']),
    }
    features = tf.train.Features(feature=feature_map)
    return tf.train.Example(features=features)


def run(save_path='train.tfrecord', num_examples=100):
    """Write synthetic examples to a ZLIB-compressed TFRecord file.

    Args:
        save_path: destination path of the TFRecord file.
        num_examples: number of examples to generate and write.

    The writer is used as a context manager so the file is flushed and
    closed even if building or writing an example raises.
    """
    writer_options = tf.python_io.TFRecordOptions(
        tf.python_io.TFRecordCompressionType.ZLIB)
    with tf.python_io.TFRecordWriter(path=save_path, options=writer_options) as writer:
        for _ in range(num_examples):
            example = get_example(get_data_item())
            writer.write(example.SerializeToString())

# Module-level call: the writer runs whenever this file is executed or imported.
run()

读文件:

# coding: utf-8

import tensorflow as tf
import os
slim_example_decoder = tf.contrib.slim.tfexample_decoder

def read_single_example_and_decode(file_name_list):
    """Read one serialized example from TFRecord files and decode it.

    Args:
        file_name_list: list of ZLIB-compressed TFRecord file paths.

    Returns:
        A dict mapping item names ('image', 'img_width', 'img_height',
        'img_name', 'float_values', 'float_values2') to decoded tensors,
        after post-processing by decode_examples.
    """
    filename_queue = tf.train.string_input_producer(file_name_list)

    # The reader is configured for ZLIB-compressed records.
    zlib_options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
    record_reader = tf.TFRecordReader(options=zlib_options)
    _, serialized_example = record_reader.read(filename_queue)

    # How each stored feature is parsed out of the serialized proto.
    keys_to_features = dict([
        ('image/encoded', tf.FixedLenFeature((), tf.string, default_value='')),
        ('image/format', tf.FixedLenFeature((), tf.string, default_value='jpeg')),
        ('image/width', tf.FixedLenFeature((), tf.int64, 1)),
        ('image/height', tf.FixedLenFeature((), tf.int64, 1)),
        ('image/name', tf.FixedLenFeature((), tf.string, default_value='')),
        ('float_values', tf.FixedLenFeature([3], tf.float32)),
        ('float_values2', tf.FixedLenFeature([6], tf.float32)),
        # ('var_values', tf.VarLenFeature(tf.float32)),
    ])

    # Output item name -> stored feature key for the plain tensor items.
    tensor_items = [
        ('img_width', 'image/width'),
        ('img_height', 'image/height'),
        ('img_name', 'image/name'),
        ('float_values', 'float_values'),
        ('float_values2', 'float_values2'),
        # ('var_values', 'var_values'),
    ]
    items_to_handlers = {
        'image': slim_example_decoder.Image(
            image_key='image/encoded', format_key='image/format', channels=3),
    }
    for item_name, feature_key in tensor_items:
        items_to_handlers[item_name] = slim_example_decoder.Tensor(feature_key)
    # tf.decode_raw

    # The decoder is given a scalar string tensor.
    serialized_example = tf.reshape(serialized_example, shape=[])
    decoder = slim_example_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
    item_names = decoder.list_items()
    decoded = decoder.decode(serialized_example, items=item_names)
    tensor_dict = dict(zip(item_names, decoded))

    return decode_examples(tensor_dict)

def decode_examples(tensor_dict):
    """Post-process decoded tensors in place and return the same dict.

    Restores 'float_values2' from its flat stored form to shape (2, 3).
    """
    flat_values = tensor_dict['float_values2']
    tensor_dict['float_values2'] = tf.reshape(flat_values, (2, 3))
    # Other preprocessing, such as image augmentation or normalization,
    # can also be done at this stage.
    return tensor_dict

def get_batch_data(file_name_list, batch_size):
    """Build batched tensors from the given TFRecord files.

    Args:
        file_name_list: list of TFRecord file paths.
        batch_size: number of examples per batch.

    Returns:
        A tuple (batch_img, batch_w, batch_h, batch_img_name,
        batch_fvalues, batch_fvalues2) of batched tensors.
    """
    tensor_dict = read_single_example_and_decode(file_name_list)

    # Order here fixes the order of the returned batch tensors.
    batch_keys = ('image', 'img_width', 'img_height',
                  'img_name', 'float_values', 'float_values2')
    single_tensors = [tensor_dict[key] for key in batch_keys]

    batched = tf.train.batch(
        single_tensors,
        batch_size=batch_size,
        capacity=batch_size * 10,
        num_threads=16,
        dynamic_pad=True)
    batch_img, batch_w, batch_h, batch_img_name, batch_fvalues, batch_fvalues2 = batched
    return batch_img, batch_w, batch_h, batch_img_name, batch_fvalues, batch_fvalues2


if __name__ == '__main__':
    # Build the input pipeline, then fetch and print a single batch.
    batch_img, batch_w, batch_h, batch_img_name, batch_fvalues, batch_fvalues2 = get_batch_data(
        ['train.tfrecord'], batch_size=2)
    init_op = tf.global_variables_initializer()
    config = tf.ConfigProto()
    with tf.Session(config=config) as sess:
        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            res = sess.run([batch_img, batch_w, batch_h, batch_img_name,
                            batch_fvalues, batch_fvalues2])
            print(res)
        finally:
            # Always stop and join the queue-runner threads, even when
            # sess.run raises, so the process does not hang on exit.
            coord.request_stop()
            coord.join(threads)

使用tf.data读文件

tensorflow某个版本把tf.data这东西放了上来,我一开始还以为这个东西可以直接读文件不用写tfrecord了,后来发现不大行,还是照样写tfrecord,这tm不是脱裤子放屁?有多少场景会把数据都先读到内存里?有可能tf.data优化了结构和效率之类的,但是对于使用的人来说只不过多了一些学习成本,对于开发效率完全没有提升。真的是无语了……

代码的修改倒是不多,简化了一点点代码,体验一般吧。

# coding: utf-8

import tensorflow as tf
import os
slim_example_decoder = tf.contrib.slim.tfexample_decoder

def read_single_example_and_decode(serialized_example):
    """Decode one serialized tf.train.Example into a dict of tensors.

    Args:
        serialized_example: string tensor holding one serialized example
            (reshaped to a scalar before decoding).

    Returns:
        A dict mapping item names ('image', 'img_width', 'img_height',
        'img_name', 'float_values', 'float_values2') to decoded tensors,
        after post-processing by decode_examples.
    """
    # How each stored feature is parsed out of the serialized proto.
    keys_to_features = dict([
        ('image/encoded', tf.FixedLenFeature((), tf.string, default_value='')),
        ('image/format', tf.FixedLenFeature((), tf.string, default_value='jpeg')),
        ('image/width', tf.FixedLenFeature((), tf.int64, 1)),
        ('image/height', tf.FixedLenFeature((), tf.int64, 1)),
        ('image/name', tf.FixedLenFeature((), tf.string, default_value='')),
        ('float_values', tf.FixedLenFeature([3], tf.float32)),
        ('float_values2', tf.FixedLenFeature([6], tf.float32)),
        # ('var_values', tf.VarLenFeature(tf.float32)),
    ])

    # Output item name -> stored feature key for the plain tensor items.
    tensor_items = [
        ('img_width', 'image/width'),
        ('img_height', 'image/height'),
        ('img_name', 'image/name'),
        ('float_values', 'float_values'),
        ('float_values2', 'float_values2'),
        # ('var_values', 'var_values'),
    ]
    items_to_handlers = {
        'image': slim_example_decoder.Image(
            image_key='image/encoded', format_key='image/format', channels=3),
    }
    for item_name, feature_key in tensor_items:
        items_to_handlers[item_name] = slim_example_decoder.Tensor(feature_key)
    # tf.decode_raw

    scalar_example = tf.reshape(serialized_example, shape=[])
    decoder = slim_example_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
    item_names = decoder.list_items()
    decoded = decoder.decode(scalar_example, items=item_names)
    tensor_dict = dict(zip(item_names, decoded))

    return decode_examples(tensor_dict)

def decode_examples(tensor_dict):
    """Post-process decoded tensors in place and return the same dict.

    Restores 'float_values2' from its flat stored form to shape (2, 3).
    """
    flat_values = tensor_dict['float_values2']
    tensor_dict['float_values2'] = tf.reshape(flat_values, (2, 3))
    return tensor_dict


if __name__ == '__main__':
    # Read the ZLIB-compressed TFRecord file and print image names per batch.
    dataset = tf.data.TFRecordDataset(['train.tfrecord'], compression_type="ZLIB")
    # Shuffle individual serialized records before decoding and batching.
    # Shuffling after batch() only reorders whole batches, and shuffling after
    # repeat() makes the buffer hold decoded batches instead of small strings.
    dataset = dataset.shuffle(1000)
    dataset = dataset.map(read_single_example_and_decode)
    dataset = dataset.batch(10).repeat()

    iterator = dataset.make_one_shot_iterator()
    tensor_dict = iterator.get_next()

    init_op = tf.global_variables_initializer()
    config = tf.ConfigProto()
    with tf.Session(config=config) as sess:
        sess.run(init_op)

        for _ in range(100):
            res = sess.run(tensor_dict)
            print(res['img_name'])

你可能感兴趣的:(人工智能深度学习)