Reading and writing TFRecords is genuinely a hassle, with all kinds of inconveniences and a few pitfalls, so rather than walking through every detail I'll just put down a simple read/write template that can serve as a reference.
Writing a TFRecord really only involves three types: bytes, int64, and float. Whatever data you want to save just has to be converted into one of these three.
Also, each of these types is stored as a flat list; multi-dimensional arrays are not supported. If your data is multi-dimensional, flatten it to 1-D before saving and reshape it back after reading.
One more thing is variable-length data. The usual trick is to pad it to a fixed length before saving, which wastes some space. TensorFlow does in fact support variable-length features, but they don't work together with the fixed-length ones in the pipeline below... I'm not sure why; in theory it should be supportable.
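For completeness, here is a rough sketch (not part of the template below) of how a variable-length float feature could be read back on its own with tf.parse_single_example. A VarLenFeature comes back as a SparseTensor, so it has to be densified if a regular tensor is needed. The key 'var_values' matches the commented-out feature in the template; the helper name parse_var_len is just illustrative.

import tensorflow as tf

def parse_var_len(serialized_example):
    # VarLenFeature yields a SparseTensor; densify it for downstream use.
    parsed = tf.parse_single_example(
        serialized_example,
        features={'var_values': tf.VarLenFeature(tf.float32)})
    return tf.sparse_tensor_to_dense(parsed['var_values'], default_value=0.0)

With that aside, the write-side template: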
# coding: utf-8
import os, sys
import time, io
import tensorflow as tf
from PIL import Image
import numpy as np
import random


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def get_data_item():
    n = random.randint(1, 10)
    fvalues = []
    for i in range(n):
        fvalues.append(random.random())
    return {
        'image_path': 't.jpg',
        'float_values': [1.0, 2.0, 3.0],  # numpy arrays work too: np.array([1.0, 2.0, 3.0])
        'float_values2': np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]).reshape(-1),  # multi-dimensional arrays are not supported: flatten to 1-D here, reshape back after decoding
        'var_values': np.array(fvalues),
    }


def get_example(data_item):
    # 1. Image data. There are several ways to read it; any of them is fine as long as
    #    you end up with the encoded bytes. img_width / img_height are optional.
    with tf.gfile.GFile(data_item['image_path'], 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    img = Image.open(encoded_jpg_io)
    img_width, img_height = img.size
    # Alternative: open the image first, then re-encode it through BytesIO
    # img = Image.open(data_item['image_path']).convert('RGB')
    # img_width, img_height = img.size
    # f = io.BytesIO()
    # img.save(f, 'JPEG')
    # f.seek(0)
    # encoded_jpg = f.read()

    # 2. Text / string data (take a random-length prefix just to produce some demo strings)
    image_name = os.path.basename(data_item['image_path'])[:random.randint(1, 5)]

    example = tf.train.Example(features=tf.train.Features(
        feature={
            'image/encoded': _bytes_feature(encoded_jpg),
            'image/format': _bytes_feature(b"JPEG"),
            'image/width': _int64_feature([img_width]),
            'image/height': _int64_feature([img_height]),
            'image/name': _bytes_feature(image_name.encode('utf-8')),
            'float_values': _float_feature(data_item['float_values']),
            'float_values2': _float_feature(data_item['float_values2']),
            # 'var_values': _float_feature(data_item['var_values']),
        }
    ))
    return example


def run():
    save_path = 'train.tfrecord'
    writer_options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
    writer = tf.python_io.TFRecordWriter(path=save_path, options=writer_options)
    for i in range(100):
        data_item = get_data_item()
        example = get_example(data_item)
        writer.write(example.SerializeToString())
    writer.close()


run()
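As a quick sanity check on the file written above, the records can be iterated directly in Python without building a graph. This is just a verification sketch, assuming the same file name and ZLIB options as in run():

import tensorflow as tf

options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
for record in tf.python_io.tf_record_iterator('train.tfrecord', options=options):
    example = tf.train.Example.FromString(record)
    # Inspect a couple of fields from the first record, then stop.
    print(example.features.feature['image/width'].int64_list.value)
    print(example.features.feature['float_values'].float_list.value)
    break

Next, the queue-based read-side template: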
# coding: utf-8
import tensorflow as tf
import os

slim_example_decoder = tf.contrib.slim.tfexample_decoder


def read_single_example_and_decode(file_name_list):
    file_name_queue = tf.train.string_input_producer(file_name_list)
    tfrecord_options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
    reader = tf.TFRecordReader(options=tfrecord_options)
    _, serialized_example = reader.read(file_name_queue)
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/width': tf.FixedLenFeature((), tf.int64, 1),
        'image/height': tf.FixedLenFeature((), tf.int64, 1),
        'image/name': tf.FixedLenFeature((), tf.string, default_value=''),
        'float_values': tf.FixedLenFeature([3], tf.float32),
        'float_values2': tf.FixedLenFeature([6], tf.float32),
        # 'var_values': tf.VarLenFeature(tf.float32),
    }
    items_to_handlers = {
        'image': slim_example_decoder.Image(image_key='image/encoded', format_key='image/format', channels=3),
        'img_width': slim_example_decoder.Tensor('image/width'),
        'img_height': slim_example_decoder.Tensor('image/height'),
        'img_name': slim_example_decoder.Tensor('image/name'),
        'float_values': slim_example_decoder.Tensor('float_values'),
        'float_values2': slim_example_decoder.Tensor('float_values2'),
        # 'var_values': slim_example_decoder.Tensor('var_values'),
    }
    # tf.decode_raw
    serialized_example = tf.reshape(serialized_example, shape=[])
    decoder = slim_example_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
    keys = decoder.list_items()
    tensors = decoder.decode(serialized_example, items=keys)
    tensor_dict = dict(zip(keys, tensors))
    tensor_dict = decode_examples(tensor_dict)
    return tensor_dict


def decode_examples(tensor_dict):
    tensor_dict['float_values2'] = tf.reshape(tensor_dict['float_values2'], (2, 3))
    # Preprocessing such as image augmentation and normalization can also go here
    # (see the sketch after this listing).
    return tensor_dict


def get_batch_data(file_name_list, batch_size):
    tensor_dict = read_single_example_and_decode(file_name_list)
    batch_img, batch_w, batch_h, batch_img_name, batch_fvalues, batch_fvalues2 = tf.train.batch(
        [tensor_dict['image'], tensor_dict['img_width'], tensor_dict['img_height'],
         tensor_dict['img_name'], tensor_dict['float_values'], tensor_dict['float_values2']],
        batch_size=batch_size,
        capacity=batch_size * 10,
        num_threads=16,
        dynamic_pad=True)
    return batch_img, batch_w, batch_h, batch_img_name, batch_fvalues, batch_fvalues2


if __name__ == '__main__':
    batch_img, batch_w, batch_h, batch_img_name, batch_fvalues, batch_fvalues2 = get_batch_data(['train.tfrecord'], batch_size=2)
    init_op = tf.global_variables_initializer()
    config = tf.ConfigProto()
    with tf.Session(config=config) as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)
        res = sess.run([batch_img, batch_w, batch_h, batch_img_name, batch_fvalues, batch_fvalues2])
        print(res)
        coord.request_stop()
        coord.join(threads)
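As an example of the kind of preprocessing that could live in decode_examples, here is a minimal sketch that resizes the decoded image and scales it to [0, 1]. The target size (224x224) and the helper name preprocess_image are just illustrative assumptions, not part of the template.

import tensorflow as tf

def preprocess_image(image):
    # Hypothetical preprocessing: resize to a fixed 224x224 and scale pixel values to [0, 1].
    image = tf.image.resize_images(image, [224, 224])
    return image / 255.0

Inside decode_examples it would be used as tensor_dict['image'] = preprocess_image(tensor_dict['image']).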
At some point TensorFlow promoted tf.data into the main API. At first I thought it meant you could read raw files directly and skip writing TFRecords, but it turns out not really: you still end up writing TFRecords, which makes the whole thing feel like pointless extra work. How many real scenarios can afford to load all the data into memory first? Maybe tf.data does improve the internal structure and efficiency, but for the person using it, it's mostly just added learning cost with no gain in development efficiency. Pretty exasperating...
The code changes aren't big, it simplifies things a little; the experience is so-so.
# coding: utf-8
import tensorflow as tf
import os

slim_example_decoder = tf.contrib.slim.tfexample_decoder


def read_single_example_and_decode(serialized_example):
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/width': tf.FixedLenFeature((), tf.int64, 1),
        'image/height': tf.FixedLenFeature((), tf.int64, 1),
        'image/name': tf.FixedLenFeature((), tf.string, default_value=''),
        'float_values': tf.FixedLenFeature([3], tf.float32),
        'float_values2': tf.FixedLenFeature([6], tf.float32),
        # 'var_values': tf.VarLenFeature(tf.float32),
    }
    items_to_handlers = {
        'image': slim_example_decoder.Image(image_key='image/encoded', format_key='image/format', channels=3),
        'img_width': slim_example_decoder.Tensor('image/width'),
        'img_height': slim_example_decoder.Tensor('image/height'),
        'img_name': slim_example_decoder.Tensor('image/name'),
        'float_values': slim_example_decoder.Tensor('float_values'),
        'float_values2': slim_example_decoder.Tensor('float_values2'),
        # 'var_values': slim_example_decoder.Tensor('var_values'),
    }
    # tf.decode_raw
    serialized_example = tf.reshape(serialized_example, shape=[])
    decoder = slim_example_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
    keys = decoder.list_items()
    tensors = decoder.decode(serialized_example, items=keys)
    tensor_dict = dict(zip(keys, tensors))
    tensor_dict = decode_examples(tensor_dict)
    return tensor_dict


def decode_examples(tensor_dict):
    tensor_dict['float_values2'] = tf.reshape(tensor_dict['float_values2'], (2, 3))
    return tensor_dict


if __name__ == '__main__':
    dataset = tf.data.TFRecordDataset(['train.tfrecord'], compression_type="ZLIB")
    dataset = dataset.map(read_single_example_and_decode)
    # Shuffle before repeat/batch so the shuffling happens per example rather than on whole batches.
    dataset = dataset.shuffle(1000).repeat().batch(10)
    iterator = dataset.make_one_shot_iterator()
    tensor_dict = iterator.get_next()
    init_op = tf.global_variables_initializer()
    config = tf.ConfigProto()
    with tf.Session(config=config) as sess:
        sess.run(init_op)
        for i in range(100):
            res = sess.run(tensor_dict)
            print(res['img_name'])
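If you would rather not pull in tf.contrib.slim at all, the decoding in the map function can also be done with tf.parse_single_example and tf.image.decode_jpeg directly. A rough sketch, assuming the same feature keys as the templates above (the function name parse_fn is just illustrative):

import tensorflow as tf

def parse_fn(serialized_example):
    # Same feature keys as in the templates above.
    features = tf.parse_single_example(serialized_example, features={
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/name': tf.FixedLenFeature((), tf.string, default_value=''),
        'float_values': tf.FixedLenFeature([3], tf.float32),
        'float_values2': tf.FixedLenFeature([6], tf.float32),
    })
    image = tf.image.decode_jpeg(features['image/encoded'], channels=3)
    return {
        'image': image,
        'img_name': features['image/name'],
        'float_values': features['float_values'],
        'float_values2': tf.reshape(features['float_values2'], (2, 3)),
    }

# dataset = tf.data.TFRecordDataset(['train.tfrecord'], compression_type="ZLIB").map(parse_fn)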