如何使用tfrecord?

目的

了解tfrecord文件格式,并学会如何write和read此文件。

tfrecord文件内部快速浏览

一、传统方法

如果你的每一条特征都是列表,且列表中包含着相同类型的值,例如

图像等。

 

  • 1.创建包含特征的列表,使用tf.train.BytesList,tf.train.FLoatList,tf.train.Int64List

  •  

字段features包含一个或多个: feature={"key": tf.train.Feature()} 
feature是基于key-value对的存储,key是字符串,其映射到的是value 包含3种数据类型: 
1. BytesList: 字符串列表: tf.train.BytesList(value=[value]) 
2. FloatList: 浮点数列表tf.train.FloatList() 
3. Int64List: 64位整数列表tf.train.Int64List() 
对于图片的numpy数组,可以.tostring之后存到BytesList,可以tf.gfile.FastGFile读入成bytes存到BytesList,可以.flatten后存到FloatList

原文:https://blog.csdn.net/weiweixiao3/article/details/82352062 
 

movie_name_list = tf.train.BytesList(value=[b'The Shawshank Redemption', b'Fight Club'])
movie_rating_list = tf.train.FloatList(value=[9.0, 9.7])
  • 2.使用tf.train.Feature创建包装后的列表,以便于tensorflow可以理解。

举例:

movie_names = tf.train.Feature(bytes_list=movie_name_list)
movie_ratings = tf.train.Feature(float_list=movie_rating_list)

3.以特征名称为键值,特征为对应值,创建字典。将字典赋予tf.train.Featuresfeature属性,创建tf.train.Features对象

movie_dict = {
  'Movie Names': movie_names,
  'Movie Ratings': movie_ratings
}
movies = tf.train.Features(feature=movie_dict)

4.使用tf.train.Exampletf.train.Features对象存储进tf.train.Examplefeatures属性中

example = tf.train.Example(features=movies)

可以使用tf.train.Example.FromString()来解析信息

 

example_proto = tf.train.Example.FromString(serialized_example)
example_proto

 

features {
  feature {
    key: "feature0"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "feature1"
    value {
      int64_list {
        value: 4
      }
    }
  }
  feature {
    key: "feature2"
    value {
      bytes_list {
        value: "goat"
      }
    }
  }
  feature {
    key: "feature3"
    value {
      float_list {
        value: 0.9876000285148621
      }
    }
  }
}

5.将文件路径传给tf.python_io.TFRecordWriter,创建tf.python_io.TFRecordWriter对象writer。调用tf.train.Features对象的serializeToString方法,将结构化数据序列化。调用对象writer将序列化数据写入磁盘。

with tf.python_io.TFRecordWriter('customer_1.tfrecord') as writer:
    writer.write(example.SerializeToString())

如何读取TFRecords结构数据

  1. 创建tf.TFRecordReader对象reader

  2. 使用reader从.tfrecords文件中读取序列化数据serialized_sample

  3. 创建features字典,字典中包含着你想从tfrecord中读取的关键字以及对应值的类型,之后将features字典和序列化数据传入tf.parse_single_example()进行解析。得到包含期望数据的字典。

1)tf.parse_single_example(serialized,features=None,name= None

解析一个单一的Example原型
serialized : 标量字符串的Tensor,一个序列化的Example,文件经过文件阅读器之后的value
features :字典数据,key为读取的名字,value为FixedLenFeature
return : 一个键值对组成的字典,键为读取的名字
(2)tf.FixedLenFeature(shape,dtype)

shape : 输入数据的形状,一般不指定,为空列表
dtype : 输入数据类型,与存储进文件的类型要一致,类型只能是float32,int 64, string
return : Tensor (即使有零的部分也存储)

 https://blog.csdn.net/chengshuhao1991/article/details/78656724

# Read and print data:
sess = tf.InteractiveSession()

# Read TFRecord file
reader = tf.TFRecordReader()
filename_queue = tf.train.string_input_producer(['customer_1.tfrecord'])

_, serialized_example = reader.read(filename_queue)

# Define features
read_features = {
    'Age': tf.FixedLenFeature([], dtype=tf.int64),
    'Movie': tf.VarLenFeature(dtype=tf.string),
    'Movie Ratings': tf.VarLenFeature(dtype=tf.float32),
    'Suggestion': tf.FixedLenFeature([], dtype=tf.string),
    'Suggestion Purchased': tf.FixedLenFeature([], dtype=tf.float32),
    'Purchase Price': tf.FixedLenFeature([], dtype=tf.float32)}

# Extract features from serialized data
read_data = tf.parse_single_example(serialized=serialized_example,
                                    features=read_features)

# Many tf.train functions use tf.train.QueueRunner,
# so we need to start it before we read
tf.train.start_queue_runners(sess)

# Print features
for name, tensor in read_data.items():
    print('{}: {}'.format(name, tensor.eval()))

二、使用tf.data读写tfrecord文件

在创建tfrecord数据特征feature0, feature1,feature2,feature3之后,使用tf.data读写数据:

创建dataset对象,

feature_dataset = tf.data.Dataset.from_tensor_slices(feature0, feature1, feature2, feature3)

# Use `take(1)` to only pull one example from the dataset.
for f0,f1,f2,f3 in features_dataset.take(1):
  print(f0)
  print(f1)
  print(f2)
  print(f3)
 
  

 使用tf.data.Dataset.map方法映射函数到Dataset的每一个元素。

map函数必须操作并返回tf.Tensors,一个非张量的函数例如 必须用tf.py_func包装。

serialized_features_dataset = features_dataset.map(serialize_example)

写入数据

filename = 'test.tfrecord'
writer = tf.data.experimental.TFRecordWriter(filename)
writer.write(serialized_features_dataset)

读取数据

filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset

dataset包含序列化的tf.train.Example信息。当迭代结束,它返回字符串张量。

 

for raw_record in raw_dataset.take(10):
  print(repr(raw_record))




'>

'>
\xe4?'>

\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02'>
# Create a description of the features.  
feature_description = {
    'feature0': tf.FixedLenFeature([], tf.int64, default_value=0),
    'feature1': tf.FixedLenFeature([], tf.int64, default_value=0),
    'feature2': tf.FixedLenFeature([], tf.string, default_value=''),
    'feature3': tf.FixedLenFeature([], tf.float32, default_value=0.0),
}

def _parse_function(example_proto):
  # Parse the input tf.Example proto using the dictionary above.
  return tf.parse_single_example(example_proto, feature_description)
parsed_dataset = raw_dataset.map(_parse_function)
parsed_dataset 

 

更多内容和细节参考链接:

1.https://medium.com/mostly-ai/tensorflow-records-what-they-are-and-how-to-use-them-c46bc4bbb564

2.https://www.tensorflow.org/tutorials/load_data/tf_records

3.https://github.com/Hvass-Labs/TensorFlow-Tutorials/blob/master/18_TFRecords_Dataset_API.ipynb

4.https://www.tensorflow.org/tutorials/load_data/tf_records

你可能感兴趣的:(tensorflow学习)