If all of your input data fits in memory, the simplest way to create a dataset from it is to convert it to tf.Tensor objects and use Dataset.from_tensor_slices().
train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train
images = images/255
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
print(dataset)
# <TensorSliceDataset shapes: ((28, 28), ()), types: (tf.float64, tf.uint8)>
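Once the dataset exists it can be consumed like any other. As a rough sketch (the tiny model below is only illustrative and not part of the original example), the dataset can be shuffled, batched, and passed straight to Keras:
train_ds = dataset.shuffle(10000).batch(32)  # shuffle buffer and batch size are arbitrary choices

# Minimal illustrative model: 10 output classes, as in Fashion-MNIST.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
])
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
model.fit(train_ds, epochs=1)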
Note: while Python generators are convenient, they have limited portability and scalability. They must run in the same Python process that created the generator, and they are still subject to the Python GIL.
def count(stop):
  i = 0
  while i < stop:
    yield i
    i += 1

for n in count(5):
  print(n)
# 0
# 1
# 2
# 3
# 4
The Dataset.from_generator constructor converts a Python generator into a tf.data.Dataset.
The constructor takes a callable as input, not an iterator. This lets it restart the generator when it reaches the end. It also takes an optional args argument, which is passed to the callable as its arguments.
The output_types argument is required because tf.data builds a tf.Graph internally, and graph edges require a tf.dtype.
ds_counter = tf.data.Dataset.from_generator(count,
                                            args=[25],
                                            output_types=tf.int32,
                                            output_shapes=())

for count_batch in ds_counter.repeat().batch(10).take(10):
  print(count_batch.numpy())
# [0 1 2 3 4 5 6 7 8 9]
# [10 11 12 13 14 15 16 17 18 19]
# [20 21 22 23 24 0 1 2 3 4]
# [ 5 6 7 8 9 10 11 12 13 14]
# [15 16 17 18 19 20 21 22 23 24]
# [0 1 2 3 4 5 6 7 8 9]
# [10 11 12 13 14 15 16 17 18 19]
# [20 21 22 23 24 0 1 2 3 4]
# [ 5 6 7 8 9 10 11 12 13 14]
# [15 16 17 18 19 20 21 22 23 24]
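In newer TensorFlow releases (2.4 and later), the output_signature argument with tf.TensorSpec is the recommended way to describe the generator's elements. The following sketch should be equivalent to the output_types/output_shapes call above:
# output_signature is available from TF 2.4 onward.
ds_counter = tf.data.Dataset.from_generator(
    count, args=[25],
    output_signature=tf.TensorSpec(shape=(), dtype=tf.int32))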
If the length of a particular axis is unknown or variable, set it to None in output_shapes. Also note that output_shapes and output_types follow the same nesting rules as other dataset methods.
Here is an example that returns tuples of arrays, where the second array is a vector of unknown length.
import numpy as np  # np is used below for the random sizes and values

def gen_series():
  i = 0
  while True:
    size = np.random.randint(0, 10)
    yield i, np.random.normal(size=(size,))
    i += 1

for i, series in gen_series():
  print(i, ":", str(series))
  if i > 5:
    break
# 0 : [-0.4164 1.7885 -1.1574]
# 1 : [ 0.336 -0.7567 1.1686 0.737 ]
# 2 : [ 1.4963e+00 -1.9522e-03 1.2167e+00 -2.3682e+00 8.8495e-01 -4.0644e-01 -1.1557e+00]
# 3 : [ 1.7163 -0.4952 0.6011 0.627 ]
# 4 : []
# 5 : [ 0.5659 0.7346 1.2605 -0.3007 -1.7873 0.5895 -0.9043 -0.0809]
# 6 : [-0.9592 0.7137 -0.6669 -0.2512 -0.6094 -0.3598 0.4001 0.5433]
The first output is an int32, the second is a float32.
The first item is a scalar, with shape (); the second is a vector of unknown length, with shape (None,).
ds_series = tf.data.Dataset.from_generator(
    gen_series,
    output_types=(tf.int32, tf.float32),
    output_shapes=((), (None,)))
print(ds_series)
# <FlatMapDataset shapes: ((), (None,)), types: (tf.int32, tf.float32)>
Note that when batching a dataset with variable shapes, you need to use Dataset.padded_batch.
ds_series_batch = ds_series.shuffle(20).padded_batch(10)
ids, sequence_batch = next(iter(ds_series_batch))
print(ids.numpy())
print(sequence_batch.numpy())
# [ 6 15 16 19 12 21 20 14 1 18]
# [[ 1.3188 1.9454 -0.1828 0.702 -0.5487 -0.0621 0.3369 0.9619 0. ]
# [-0.0538 0. 0. 0. 0. 0. 0. 0. 0. ]
# [ 0. 0. 0. 0. 0. 0. 0. 0. 0. ]
# [-0.5753 -0.1205 1.2596 0.8157 0.3531 -0.5514 0.6236 0.698 0.788 ]
# [ 0.2008 1.0593 0.6685 -1.1157 1.2345 -0.8531 0. 0. 0. ]
# [ 0. 0. 0. 0. 0. 0. 0. 0. 0. ]
# [-0.7845 -0.4524 -0.0078 -1.3479 -0.8943 0.6126 -0.3543 0.4257 0. ]
# [-1.695 -0.2237 0. 0. 0. 0. 0. 0. 0. ]
# [-0.8229 0.1323 -0.3087 0. 0. 0. 0. 0. 0. ]
# [ 0.1945 -0.6999 0.3324 -1.1039 -1.8419 -0.1009 0. 0. 0. ]]
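By default, padded_batch pads each component to the longest element in the batch and uses zeros as the padding value. When that is not what you want, padded_shapes and padding_values can be set explicitly. A hedged sketch (the -1.0 padding value is purely illustrative):
# Explicit padded_shapes / padding_values; -1.0 marks the padded positions.
ds_series_batch = ds_series.shuffle(20).padded_batch(
    10,
    padded_shapes=((), (None,)),
    padding_values=(0, -1.0))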
The TFRecord file format is a simple record-oriented binary format that many TensorFlow applications use for training data. The tf.data.TFRecordDataset class lets you stream the contents of one or more TFRecord files as part of an input pipeline.
Here is an example using the test file from the French Street Name Signs (FSNS) dataset.
fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
dataset = tf.data.TFRecordDataset(filenames=[fsns_test_file])
raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())
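tf.train.Example.FromString is handy for inspecting a single record. To parse records inside the pipeline itself, one option (a sketch, not part of the original example) is Dataset.map with tf.io.parse_single_example; the feature keys below are placeholders and must match whatever is actually stored in the file:
# Placeholder feature spec -- adjust the keys/types to the actual TFRecord contents.
feature_description = {
    'image/encoded': tf.io.FixedLenFeature([], tf.string),
    'image/text': tf.io.FixedLenFeature([], tf.string),
}

def parse_example(serialized):
  return tf.io.parse_single_example(serialized, feature_description)

parsed_ds = dataset.map(parse_example)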
tf.data.TextLineDataset provides an easy way to extract lines from one or more text files. Given one or more filenames, a TextLineDataset will produce one string-valued element per line of those files.
1) Download the text files
directory_url = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'
file_names = ['cowper.txt', 'derby.txt', 'butler.txt']
file_paths = [
    tf.keras.utils.get_file(file_name, directory_url + file_name)
    for file_name in file_names
]
2) Create a TextLineDataset from the file paths, then print its first five lines:
dataset = tf.data.TextLineDataset(file_paths)
for line in dataset.take(5):
  print(line.numpy())
# b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
# b'His wrath pernicious, who ten thousand woes'
# b"Caused to Achaia's host, sent many a soul"
# b'Illustrious into Ades premature,'
# b'And Heroes gave (so stood the will of Jove)'
3) Use Dataset.interleave to alternate between the files line by line; this makes it easier to mix the files together. Here are the first, second, and third lines from each text:
files_ds = tf.data.Dataset.from_tensor_slices(file_paths)
lines_ds = files_ds.interleave(tf.data.TextLineDataset, cycle_length=3)
for i, line in enumerate(lines_ds.take(9)):
  if i % 3 == 0:
    print()
  print(line.numpy())
# First line of each text
# b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
# b"\xef\xbb\xbfOf Peleus' son, Achilles, sing, O Muse,"
# b'\xef\xbb\xbfSing, O goddess, the anger of Achilles son of Peleus, that brought'
#
# Second line of each text
# b'His wrath pernicious, who ten thousand woes'
# b'The vengeance, deep and deadly; whence to Greece'
# b'countless ills upon the Achaeans. Many a brave soul did it send'
#
# Third line of each text
# b"Caused to Achaia's host, sent many a soul"
# b'Unnumbered ills arose; which many a soul'
# b'hurrying down to Hades, and many a hero did it yield a prey to dogs and'
4) By default, a TextLineDataset yields every line of every file. If a file starts with a header line or contains comments, those lines can be removed with the Dataset.skip() or Dataset.filter() transformations. Here we download a file that starts with a header row, skip that header, use filter() to drop the rows where survived equals 0, and then display the first ten filtered rows.
titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
for line in titanic_lines.take(10):
  print(line.numpy())
# b'survived,sex,age,n_siblings_spouses,parch,fare,class,deck,embark_town,alone'
# b'0,male,22.0,1,0,7.25,Third,unknown,Southampton,n'
# b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n'
# b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y'
# b'1,female,35.0,1,0,53.1,First,C,Southampton,n'
# b'0,male,28.0,0,0,8.4583,Third,unknown,Queenstown,y'
# b'0,male,2.0,3,1,21.075,Third,unknown,Southampton,n'
# b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n'
# b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n'
# b'1,female,4.0,1,1,16.7,Third,G,Southampton,n'
def survived(line):
  return tf.not_equal(tf.strings.substr(line, 0, 1), "0")

survivors = titanic_lines.skip(1).filter(survived)
for line in survivors.take(10):
  print(line.numpy())
# b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n'
# b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y'
# b'1,female,35.0,1,0,53.1,First,C,Southampton,n'
# b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n'
# b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n'
# b'1,female,4.0,1,1,16.7,Third,G,Southampton,n'
# b'1,male,28.0,0,0,13.0,Second,unknown,Southampton,y'
# b'1,female,28.0,0,0,7.225,Third,unknown,Cherbourg,y'
# b'1,male,28.0,0,0,35.5,First,A,Southampton,y'
# b'1,female,38.0,1,5,31.3875,Third,unknown,Southampton,n'
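Working with raw text lines quickly gets awkward for CSV data. As a rough sketch (not part of the original text), each line can be turned into typed tensors with tf.io.decode_csv inside Dataset.map; the record_defaults below mirror the column types visible in the titanic file:
def parse_line(line):
  # One default value per column; the defaults also fix each column's dtype.
  return tf.io.decode_csv(
      line, record_defaults=[0, "", 0.0, 0, 0, 0.0, "", "", "", ""])

parsed_lines = titanic_lines.skip(1).map(parse_line)
for fields in parsed_lines.take(1):
  print([f.numpy() for f in fields])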
The CSV file format is a popular format for storing tabular data in plain text.
For example:
import pandas as pd
titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
df = pd.read_csv(titanic_file, index_col=None)
print(df.head())
# survived sex age ... deck embark_town alone
# 0 0 male 22.0 ... unknown Southampton n
# 1 1 female 38.0 ... C Cherbourg n
# 2 1 female 26.0 ... unknown Southampton y
# 3 1 female 35.0 ... C Southampton n
# 4 0 male 28.0 ... unknown Queenstown y
Import the data using the Dataset.from_tensor_slices method, which also accepts a dictionary of columns:
titanic_slices = tf.data.Dataset.from_tensor_slices(dict(df))
for feature_batch in titanic_slices.take(1):
  for key, value in feature_batch.items():
    print(" {!r:20s}: {}".format(key, value))
# 'survived' : 0
# 'sex' : b'male'
# 'age' : 22.0
# 'n_siblings_spouses': 1
# 'parch' : 0
# 'fare' : 7.25
# 'class' : b'Third'
# 'deck' : b'unknown'
# 'embark_town' : b'Southampton'
# 'alone' : b'n'
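A small follow-up sketch: the dictionary dataset behaves like any other dataset, so it can be batched directly, and each feature then becomes a batched tensor.
# Batch the dictionary dataset; batch size 4 is arbitrary.
for feature_batch in titanic_slices.batch(4).take(1):
  for key, value in feature_batch.items():
    print(" {!r:20s}: {}".format(key, value))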
The experimental.make_csv_dataset function is the high-level interface for reading CSV files. It supports column type inference and many other features, such as batching.
titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived")
for feature_batch, label_batch in titanic_batches.take(1):
  print("'survived': {}".format(label_batch))
  print("features:")
  for key, value in feature_batch.items():
    print(" {!r:20s}: {}".format(key, value))
# 'survived': [0 0 0 1]
# features:
# 'sex' : [b'female' b'male' b'male' b'male']
# 'age' : [18. 33. 33. 26.]
# 'n_siblings_spouses': [1 0 1 0]
# 'parch' : [0 0 1 0]
# 'fare' : [17.8 8.6542 20.525 18.7875]
# 'class' : [b'Third' b'Third' b'Third' b'Third']
# 'deck' : [b'unknown' b'unknown' b'unknown' b'unknown']
# 'embark_town' : [b'Southampton' b'Southampton' b'Southampton' b'Cherbourg']
# 'alone' : [b'n' b'y' b'n' b'y']
If you only need a subset of the columns, use the select_columns argument.
titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived", select_columns=['class', 'fare', 'survived'])
for feature_batch, label_batch in titanic_batches.take(1):
  print("'survived': {}".format(label_batch))
  for key, value in feature_batch.items():
    print(" {!r:20s}: {}".format(key, value))
# 'survived': [0 0 1 0]
# 'fare' : [10.5 26.55 49.5042 73.5 ]
# 'class' : [b'Second' b'First' b'First' b'Second']
The lower-level experimental.CsvDataset class provides finer-grained control. It does not support column type inference; instead, you must specify the type of each column.
titanic_types = [tf.int32, tf.string, tf.float32, tf.int32, tf.int32,
                 tf.float32, tf.string, tf.string, tf.string, tf.string]
dataset = tf.data.experimental.CsvDataset(titanic_file, titanic_types, header=True)
for line in dataset.take(10):
  print([item.numpy() for item in line])
# [0, b'male', 22.0, 1, 0, 7.25, b'Third', b'unknown', b'Southampton', b'n']
# [1, b'female', 38.0, 1, 0, 71.2833, b'First', b'C', b'Cherbourg', b'n']
# [1, b'female', 26.0, 0, 0, 7.925, b'Third', b'unknown', b'Southampton', b'y']
# [1, b'female', 35.0, 1, 0, 53.1, b'First', b'C', b'Southampton', b'n']
# [0, b'male', 28.0, 0, 0, 8.4583, b'Third', b'unknown', b'Queenstown', b'y']
# [0, b'male', 2.0, 3, 1, 21.075, b'Third', b'unknown', b'Southampton', b'n']
# [1, b'female', 27.0, 0, 2, 11.1333, b'Third', b'unknown', b'Southampton', b'n']
# [1, b'female', 14.0, 1, 0, 30.0708, b'Second', b'unknown', b'Cherbourg', b'n']
# [1, b'female', 4.0, 1, 1, 16.7, b'Third', b'G', b'Southampton', b'n']
# [0, b'male', 20.0, 0, 0, 8.05, b'Third', b'unknown', b'Southampton', b'y']
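CsvDataset can also read only a subset of columns via its select_cols argument, in which case record_defaults describes just the selected columns. A hedged sketch, assuming the column order shown above (survived is column 0, fare column 5, class column 6):
# select_cols takes column indices; the defaults list must match them in order.
dataset = tf.data.experimental.CsvDataset(
    titanic_file,
    [tf.int32, tf.float32, tf.string],  # survived, fare, class
    header=True,
    select_cols=[0, 5, 6])
for line in dataset.take(3):
  print([item.numpy() for item in line])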