Code example:
import tensorflow as tf
import os
import pprint
import numpy as np
# Read the CSV files and convert them to TFRecord files
source_dir = "./customize_generate_csv/"
print(os.listdir(source_dir))
['test_00.csv', 'test_01.csv', 'test_02.csv', 'test_03.csv', 'test_04.csv', 'test_05.csv', 'test_06.csv', 'test_07.csv', 'test_08.csv', 'test_09.csv', 'train_00.csv', 'train_01.csv', 'train_02.csv', 'train_03.csv', 'train_04.csv', 'train_05.csv', 'train_06.csv', 'train_07.csv', 'train_08.csv', 'train_09.csv', 'train_10.csv', 'train_11.csv', 'train_12.csv', 'train_13.csv', 'train_14.csv', 'train_15.csv', 'train_16.csv', 'train_17.csv', 'train_18.csv', 'train_19.csv', 'valid_00.csv', 'valid_01.csv', 'valid_02.csv', 'valid_03.csv', 'valid_04.csv', 'valid_05.csv', 'valid_06.csv', 'valid_07.csv', 'valid_08.csv', 'valid_09.csv']
# Collect the filenames that start with a given prefix
def get_filenames_by_prefix(source_dir, prefix_name):
    all_files = os.listdir(source_dir)
    results = []
    for filename in all_files:
        if filename.startswith(prefix_name):
            results.append(os.path.join(source_dir, filename))
    return results
train_filenames = get_filenames_by_prefix(source_dir, "train")
valid_filenames = get_filenames_by_prefix(source_dir, "valid")
test_filenames = get_filenames_by_prefix(source_dir, "test")
pprint.pprint(train_filenames)
pprint.pprint(valid_filenames)
pprint.pprint(test_filenames)
['./customize_generate_csv/train_00.csv',
 './customize_generate_csv/train_01.csv',
 './customize_generate_csv/train_02.csv',
 './customize_generate_csv/train_03.csv',
 './customize_generate_csv/train_04.csv',
 './customize_generate_csv/train_05.csv',
 './customize_generate_csv/train_06.csv',
 './customize_generate_csv/train_07.csv',
 './customize_generate_csv/train_08.csv',
 './customize_generate_csv/train_09.csv',
 './customize_generate_csv/train_10.csv',
 './customize_generate_csv/train_11.csv',
 './customize_generate_csv/train_12.csv',
 './customize_generate_csv/train_13.csv',
 './customize_generate_csv/train_14.csv',
 './customize_generate_csv/train_15.csv',
 './customize_generate_csv/train_16.csv',
 './customize_generate_csv/train_17.csv',
 './customize_generate_csv/train_18.csv',
 './customize_generate_csv/train_19.csv']
['./customize_generate_csv/valid_00.csv',
 './customize_generate_csv/valid_01.csv',
 './customize_generate_csv/valid_02.csv',
 './customize_generate_csv/valid_03.csv',
 './customize_generate_csv/valid_04.csv',
 './customize_generate_csv/valid_05.csv',
 './customize_generate_csv/valid_06.csv',
 './customize_generate_csv/valid_07.csv',
 './customize_generate_csv/valid_08.csv',
 './customize_generate_csv/valid_09.csv']
['./customize_generate_csv/test_00.csv',
 './customize_generate_csv/test_01.csv',
 './customize_generate_csv/test_02.csv',
 './customize_generate_csv/test_03.csv',
 './customize_generate_csv/test_04.csv',
 './customize_generate_csv/test_05.csv',
 './customize_generate_csv/test_06.csv',
 './customize_generate_csv/test_07.csv',
 './customize_generate_csv/test_08.csv',
 './customize_generate_csv/test_09.csv']
# Parse one line of a CSV file
# n_fields: number of columns in each line
def parse_csv_line(line, n_fields=9):
    defs = [tf.constant(np.nan)] * n_fields
    parsed_fields = tf.io.decode_csv(line, record_defaults=defs)
    # tf.stack(): stack the parsed scalar tensors into one tensor
    x = tf.stack(parsed_fields[0:-1])
    y = tf.stack(parsed_fields[-1:])
    return x, y
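A quick sanity check of parse_csv_line (the sample line below is made up for illustration): a 9-column line yields 8 features and 1 label.

# Hypothetical 9-column line: "0.0,0.1,...,0.8"
sample_line = ",".join(str(i * 0.1) for i in range(9))
x, y = parse_csv_line(sample_line, n_fields=9)
print(x.shape, y.shape)  # (8,) (1,)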
# Build a Dataset that reads a list of CSV files
# n_reader: number of files read in parallel
# n_parse_threads: parallelism used when parsing lines
# shuffle_buffer_size: size of the shuffle buffer
def csv_reader_dataset(filenames, n_reader=5, batch_size=32,
                       n_parse_threads=5, shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    # repeat() without arguments repeats the data indefinitely.
    # Purpose: training consumes the training set more than once;
    # the epoch settings decide when training stops.
    dataset = dataset.repeat()
    # interleave(): read the files and merge them into one dataset,
    # skipping the header line of each file
    dataset = dataset.interleave(
        lambda filename: tf.data.TextLineDataset(filename).skip(1),
        cycle_length=n_reader
    )
    # shuffle() returns a new dataset, so the result must be reassigned
    dataset = dataset.shuffle(shuffle_buffer_size)
    # map(): apply parse_csv_line (which wraps tf.io.decode_csv) to each line
    dataset = dataset.map(parse_csv_line, num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset
batch_size = 32
train_set = csv_reader_dataset(train_filenames, batch_size=batch_size)
valid_set = csv_reader_dataset(valid_filenames, batch_size=batch_size)
test_set = csv_reader_dataset(test_filenames, batch_size=batch_size)
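Taking one batch from train_set confirms the pipeline works end to end (an inspection snippet, not part of the original output):

# Each batch holds batch_size rows: features (32, 8), labels (32, 1)
for x_batch, y_batch in train_set.take(1):
    print(x_batch.shape, y_batch.shape)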
# Store the data in TFRecord format
def serialize_example(x, y):
    """Converts x, y to tf.train.Example and serializes it."""
    # Note: eager tensors must be converted to numpy arrays via .numpy()
    input_features = tf.train.FloatList(value=x.numpy())
    label = tf.train.FloatList(value=y.numpy())
    features = tf.train.Features(
        feature={
            "input_features": tf.train.Feature(float_list=input_features),
            "label": tf.train.Feature(float_list=label)
        }
    )
    example = tf.train.Example(features=features)
    return example.SerializeToString()
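A round-trip sketch: serialize a dummy sample and parse it back with tf.io.parse_single_example. The dummy values and the feature spec below are assumptions for illustration; the [8] and [1] shapes match the 8-feature/1-label split above.

# Hypothetical sample: 8 features, 1 label
serialized = serialize_example(tf.constant([0.1] * 8), tf.constant([1.0]))
expected_features = {
    "input_features": tf.io.FixedLenFeature([8], dtype=tf.float32),
    "label": tf.io.FixedLenFeature([1], dtype=tf.float32),
}
parsed = tf.io.parse_single_example(serialized, expected_features)
print(parsed["input_features"], parsed["label"])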
# n_shards: number of shard files to write
# steps_per_shard: number of batches written into each shard
def csv_dataset_to_tfrecords(base_filename, dataset, n_shards,
                             steps_per_shard, compression_type=None):
    options = tf.io.TFRecordOptions(compression_type=compression_type)
    all_filenames = []
    for shard_id in range(n_shards):
        filename_fullpath = '{}_{:05d}-of-{:05d}'.format(
            base_filename, shard_id, n_shards)
        with tf.io.TFRecordWriter(filename_fullpath, options) as writer:
            # Write steps_per_shard batches into this shard
            for x_batch, y_batch in dataset.take(steps_per_shard):
                for x_example, y_example in zip(x_batch, y_batch):
                    writer.write(serialize_example(x_example, y_example))
        all_filenames.append(filename_fullpath)
    return all_filenames
n_shards = 20
train_steps_per_shard = 11610 // batch_size // n_shards
valid_steps_per_shard = 3880 // batch_size // n_shards
test_steps_per_shard = 5170 // batch_size // n_shards
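Worked out with the split sizes above (an arithmetic check, assuming the 11610/3880/5170 sample counts shown): train_steps_per_shard = 11610 // 32 // 20 = 18, valid_steps_per_shard = 3880 // 32 // 20 = 6, and test_steps_per_shard = 5170 // 32 // 20 = 8, so each train shard holds 18 * 32 = 576 samples.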
output_dir = "generate_tfrecords"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
train_basement = os.path.join(output_dir, "train")
valid_basement = os.path.join(output_dir, "valid")
test_basement = os.path.join(output_dir, "test")
train_tfrecord_filenames = csv_dataset_to_tfrecords(train_basement, train_set, n_shards, train_steps_per_shard, None)
valid_tfrecord_filenames = csv_dataset_to_tfrecords(valid_basement, valid_set, n_shards, valid_steps_per_shard, None)
test_tfrecord_filenames = csv_dataset_to_tfrecords(test_basement, test_set, n_shards, test_steps_per_shard, None)
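To confirm the shards are usable, here is a minimal read-back sketch. This reader function is not part of the original conversion code; the feature spec assumes the 8-feature/1-label layout used above, and the shards were written uncompressed, matching TFRecordDataset's default.

def tfrecords_reader_dataset(filenames, batch_size=32):
    expected_features = {
        "input_features": tf.io.FixedLenFeature([8], dtype=tf.float32),
        "label": tf.io.FixedLenFeature([1], dtype=tf.float32),
    }
    dataset = tf.data.TFRecordDataset(filenames)
    # Parse each serialized Example back into a feature dict
    dataset = dataset.map(
        lambda serialized: tf.io.parse_single_example(serialized,
                                                      expected_features))
    return dataset.batch(batch_size)

# Peek at one batch from the newly written train shards
for batch in tfrecords_reader_dataset(train_tfrecord_filenames).take(1):
    print(batch["input_features"].shape, batch["label"].shape)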