代码承接上一篇
pprint.pprint(train_tfrecord_filenames)
pprint.pprint(valid_tfrecord_filenames)
pprint.pprint(test_tfrecord_fielnames)
首先打印一下我们所生成文件的文件名,下面是其中一个文件名。
‘generate_tfrecords_zip/test_00007-of-00020’
正如在基础API里面提到的,要想解析example,必须要定义解析每个field的字典
expected_features = {
"input_features": tf.io.FixedLenFeature([8], dtype=tf.float32),
"label": tf.io.FixedLenFeature([1], dtype=tf.float32)
}
def parse_example(serialized_example):
example = tf.io.parse_single_example(serialized_example,
expected_features)
return example["input_features"], example["label"]
def tfrecords_reader_dataset(filenames, n_readers=5,
batch_size=32, n_parse_threads=5,
shuffle_buffer_size=10000):
dataset = tf.data.Dataset.list_files(filenames)
dataset = dataset.repeat()
dataset = dataset.interleave(
lambda filename: tf.data.TFRecordDataset(
filename, compression_type = "GZIP"),
cycle_length = n_readers
)
dataset.shuffle(shuffle_buffer_size)
dataset = dataset.map(parse_example,
num_parallel_calls=n_parse_threads)
dataset = dataset.batch(batch_size)
return dataset
tfrecords_train = tfrecords_reader_dataset(train_tfrecord_filenames,
batch_size = 3)
for x_batch, y_batch in tfrecords_train.take(2):
print(x_batch)
print(y_batch)
首先定义了解析example的字典,机器就知道如何解析这个example。
然后要定义一个map函数,定义如何对每个样本进行处理:parse_example,参数意义是序列化之后的example。先把example解析出来,使用这个函数:tf.io.parse_single_example,参数是序列化的example和指定如何解析它的字典。解析完之后把输入特征和标签返回出去。
然后定义了一个完整的函数,完成从文件名列表到具体的dataset的转变。与读取csv文件的代码比较类似,就不再对读取的过程作讲解,运行之后可以看到,取出来的数据都是正常的。
接下来使用上面定义的函数读取生成训练中使用的数据集。
batch_size = 32
tfrecords_train_set = tfrecords_reader_dataset(
train_tfrecord_filenames, batch_size = batch_size)
tfrecords_valid_set = tfrecords_reader_dataset(
valid_tfrecord_filenames, batch_size = batch_size)
tfrecords_test_set = tfrecords_reader_dataset(
test_tfrecord_fielnames, batch_size = batch_size)
读取了训练集、验证集和测试集,接下来在Keras中使用这些数据。
model = keras.models.Sequential([
keras.layers.Dense(30, activation='relu',
input_shape=[8]),
keras.layers.Dense(1),
])
model.compile(loss="mean_squared_error", optimizer="sgd")
callbacks = [keras.callbacks.EarlyStopping(
patience=5, min_delta=1e-2)]
history = model.fit(tfrecords_train_set,
validation_data = tfrecords_valid_set,
steps_per_epoch = 11160 // batch_size,
validation_steps = 3870 // batch_size,
epochs = 100,
callbacks = callbacks)
训练完成后来用测试集评估一下模型效果。
model.evaluate(tfrecords_test_set, steps = 5160 // batch_size)
到此就完成了tfrecord的实战。我们读取csv文件,转化成为tfrecord文件,再把tfrecord文件读取出来,形成一个数据集,再在tf.Keras中进行使用。需要记住的是:tfrecors是tensorflow独有的一种数据格式,在tf中有很多优化,在读取数据方面有独特的优势