Tensorflow2.0之TFRecords制作

TFRecords制作

  • 为了高效地读取数据,可以将数据进行序列化存储,这样也便于网络流式读取数据。TFRecord是一种比较常用的存储二进制序列数据的方法
  • tf.Example类是一种将数据表示为{"string": value}形式的meassage类型,Tensorflow经常使用tf.Example来写入、读取TFRecord数据
  • 平台:Jupter Notebook
#首先导入相关的包
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import tensorflow as tf

通常情况下,tf.Example中可以使用以下几种格式:

  • tf.train.BytesList: 可以使用的类型包括 string和byte
  • tf.train.FloatList: 可以使用的类型包括 float和double
  • tf.train.Int64List: 可以使用的类型包括 enum,bool, int32, uint32, int6
#转换实例
def _bytes_feature(value):
    """Returns a bytes_list from a string/byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    """Return a float_list form a float/double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    """Return a int64_list from a bool/enum/int/uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


#测试输出
# tf.train.BytesList
print(_bytes_feature(b'test_string'))
print(_bytes_feature('test_string'.encode('utf8')))

# tf.train.FloatList
print(_float_feature(np.exp(1)))

# tf.train.Int64List
print(_int64_feature(True))
print(_int64_feature(1))

#输出结果
bytes_list {
  value: "test_string"
}

bytes_list {
  value: "test_string"
}

float_list {
  value: 2.7182817459106445
}

int64_list {
  value: 1
}

int64_list {
  value: 1
}

tfrecord制作方法

  • 创建tf.Example
#自定义数据测试

def serialize_example(feature0, feature1, feature2, feature3):
    """
    创建tf.Example
    """
    
    # 转换成相应类型
    feature = {
        'feature0': _int64_feature(feature0),
        'feature1': _int64_feature(feature1),
        'feature2': _bytes_feature(feature2),
        'feature3': _float_feature(feature3),
    }
    #使用tf.train.Example来创建
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    #SerializeToString方法转换为二进制字符串
    return example_proto.SerializeToString()


#测试数据生成

# 数据量
n_observations = int(1e4)

# Boolean feature
feature0 = np.random.choice([False, True], n_observations)

# Integer feature
feature1 = np.random.randint(0, 5, n_observations)

# String feature
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]

# Float feature
feature3 = np.random.randn(n_observations)


#测试数据写入

filename = 'tfrecord-1'

with tf.io.TFRecordWriter(filename) as writer:
    for i in range(n_observations):
        example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])
        writer.write(example)

加载tfrecord文件

filenames = [filename]

# 读取
raw_dataset = tf.data.TFRecordDataset(filenames)

 

以上为模拟tfrecord数据制作方法

链接实战图像制作实战

 

你可能感兴趣的:(机器学习,python,tensorflow,机器学习)