TensorFlow 2: Reading Input Data with tf.data (Part 2)

Contents

  • 1. NumPy arrays
  • 2. Python generators
  • 3. TFRecord data
  • 4. Text data
  • 5. CSV data

 

1. NumPy arrays

If all of your input data fits in memory, the simplest way to create a dataset from it is to convert it to tf.Tensor objects and use Dataset.from_tensor_slices().

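import tensorflow as tf

train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train
images = images / 255  # scale pixel values to [0, 1]; the result is float64
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
print(dataset)
# <TensorSliceDataset shapes: ((28, 28), ()), types: (tf.float64, tf.uint8)>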
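As a smaller, offline illustration of the same call, the sketch below slices a tuple of two hypothetical NumPy arrays (toy data, not Fashion-MNIST). It shows that from_tensor_slices splits along the first axis, so each element is one (feature, label) pair.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: four 2x2 "images" and four integer labels.
features = np.arange(16, dtype=np.float32).reshape(4, 2, 2)
labels_toy = np.array([0, 1, 0, 1], dtype=np.int64)

# Slicing happens along the first axis, so each element is one
# (feature, label) pair with shapes ((2, 2), ()).
toy_ds = tf.data.Dataset.from_tensor_slices((features, labels_toy))
print(toy_ds.element_spec)

toy_labels = [int(label) for _, label in toy_ds]
print(toy_labels)  # [0, 1, 0, 1]
```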

 

2. Python generators

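Note: While a Python generator is convenient, it has limited portability and scalability. It must run in the same Python process that created the generator, and it is still subject to the Python GIL.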

def count(stop):
    i = 0
    while i < stop:
        yield i
        i += 1

for n in count(5):
    print(n)
# 0
# 1
# 2
# 3
# 4

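The Dataset.from_generator constructor converts a Python generator into a tf.data.Dataset.

The constructor takes a callable as input, not an iterator. This allows it to restart the generator when it reaches the end. It also takes an optional args argument, whose values are passed to the callable as its arguments.

The output_types argument is required because tf.data builds a tf.Graph internally, and graph edges require a tf.dtype.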

ds_counter = tf.data.Dataset.from_generator(count,
                                            args=[25],
                                            output_types=tf.int32,
                                            output_shapes=())
for count_batch in ds_counter.repeat().batch(10).take(10):
    print(count_batch.numpy())
# [0 1 2 3 4 5 6 7 8 9]
# [10 11 12 13 14 15 16 17 18 19]
# [20 21 22 23 24  0  1  2  3  4]
# [ 5  6  7  8  9 10 11 12 13 14]
# [15 16 17 18 19 20 21 22 23 24]
# [0 1 2 3 4 5 6 7 8 9]
# [10 11 12 13 14 15 16 17 18 19]
# [20 21 22 23 24  0  1  2  3  4]
# [ 5  6  7  8  9 10 11 12 13 14]
# [15 16 17 18 19 20 21 22 23 24]
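In TensorFlow 2.4 and later, output_types and output_shapes are deprecated in favor of a single output_signature argument built from tf.TensorSpec. The sketch below assumes such a version and redefines the count generator so it runs on its own; it also shows that, because from_generator receives a callable, repeat() can restart the generator from the beginning.

```python
import tensorflow as tf

def count(stop):
    # Same generator as above: yields 0, 1, ..., stop - 1.
    i = 0
    while i < stop:
        yield i
        i += 1

# output_signature bundles dtype and shape into one tf.TensorSpec.
ds_counter_v2 = tf.data.Dataset.from_generator(
    count,
    args=[5],
    output_signature=tf.TensorSpec(shape=(), dtype=tf.int32))

values = [int(x) for x in ds_counter_v2]
print(values)  # [0, 1, 2, 3, 4]

# repeat() calls the callable again, so the generator restarts from 0.
twice = [int(x) for x in ds_counter_v2.repeat(2)]
print(twice)  # [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
```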

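If the length of a particular axis is unknown or variable, set it to None in output_shapes. Note also that output_shapes and output_types follow the same nesting rules as other dataset methods.

Here is an example generator that yields tuples of arrays, where the second array is a vector of unknown length.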

import numpy as np

def gen_series():
    i = 0
    while True:
        size = np.random.randint(0, 10)
        yield i, np.random.normal(size=(size,))
        i += 1

for i, series in gen_series():
    print(i, ":", str(series))
    if i > 5:
        break
# 0 : [-0.4164  1.7885 -1.1574]
# 1 : [ 0.336  -0.7567  1.1686  0.737 ]
# 2 : [ 1.4963e+00 -1.9522e-03  1.2167e+00 -2.3682e+00  8.8495e-01 -4.0644e-01 -1.1557e+00]
# 3 : [ 1.7163 -0.4952  0.6011  0.627 ]
# 4 : []
# 5 : [ 0.5659  0.7346  1.2605 -0.3007 -1.7873  0.5895 -0.9043 -0.0809]
# 6 : [-0.9592  0.7137 -0.6669 -0.2512 -0.6094 -0.3598  0.4001  0.5433]

The first output is an int32 and the second is a float32.
The first item is a scalar with shape (); the second is a vector of unknown length with shape (None,).

ds_series = tf.data.Dataset.from_generator(
    gen_series,
    output_types=(tf.int32, tf.float32),
    output_shapes=((), (None,)))
print(ds_series)
# <FlatMapDataset shapes: ((), (None,)), types: (tf.int32, tf.float32)>

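Note that batching a dataset whose elements have variable shapes requires Dataset.padded_batch, which pads the shorter elements of each batch with zeros up to the length of the longest one.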

ds_series_batch = ds_series.shuffle(20).padded_batch(10)

ids, sequence_batch = next(iter(ds_series_batch))
print(ids.numpy())
print(sequence_batch.numpy())
# [ 6 15 16 19 12 21 20 14  1 18]
# [[ 1.3188  1.9454 -0.1828  0.702  -0.5487 -0.0621  0.3369  0.9619  0.    ]
#  [-0.0538  0.      0.      0.      0.      0.      0.      0.      0.    ]
#  [ 0.      0.      0.      0.      0.      0.      0.      0.      0.    ]
#  [-0.5753 -0.1205  1.2596  0.8157  0.3531 -0.5514  0.6236  0.698   0.788 ]
#  [ 0.2008  1.0593  0.6685 -1.1157  1.2345 -0.8531  0.      0.      0.    ]
#  [ 0.      0.      0.      0.      0.      0.      0.      0.      0.    ]
#  [-0.7845 -0.4524 -0.0078 -1.3479 -0.8943  0.6126 -0.3543  0.4257  0.    ]
#  [-1.695  -0.2237  0.      0.      0.      0.      0.      0.      0.    ]
#  [-0.8229  0.1323 -0.3087  0.      0.      0.      0.      0.      0.    ]
#  [ 0.1945 -0.6999  0.3324 -1.1039 -1.8419 -0.1009  0.      0.      0.    ]]
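To see exactly what padded_batch does without random data, here is a sketch with three hypothetical sequences of lengths 1, 3 and 2. Each batch is padded only up to its own longest element, so different batches can have different widths. Like the example above, it relies on padded_shapes being optional.

```python
import tensorflow as tf

# Hypothetical variable-length sequences of lengths 1, 3 and 2.
def gen_sequences():
    for seq in ([1], [2, 3, 4], [5, 6]):
        yield seq

ds_seq = tf.data.Dataset.from_generator(
    gen_sequences, output_types=tf.int32, output_shapes=(None,))

# One batch of 3: every row is padded with 0 to the longest length (3).
padded = next(iter(ds_seq.padded_batch(3)))
print(padded.numpy())
# [[1 0 0]
#  [2 3 4]
#  [5 6 0]]

# Batches of 2: each batch is padded only to its own longest element.
batch_shapes = [tuple(b.shape.as_list()) for b in ds_seq.padded_batch(2)]
print(batch_shapes)  # [(2, 3), (1, 2)]
```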

 

3. TFRecord data

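The TFRecord file format is a simple record-oriented binary format that many TensorFlow applications use for training data. The tf.data.TFRecordDataset class lets you stream the contents of one or more TFRecord files as part of an input pipeline.

Below is an example using the test file from the French Street Name Signs (FSNS) dataset.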

fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
dataset = tf.data.TFRecordDataset(filenames=[fsns_test_file])

raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())
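The FSNS example needs a download. The offline sketch below writes its own two-record TFRecord file to a temporary directory, with a hypothetical bytes feature named "text", then reads it back with TFRecordDataset. It shows two ways to decode each record: tf.train.Example.FromString in Python (as above), and tf.io.parse_single_example inside Dataset.map, which keeps the parsing inside the tf.data pipeline.

```python
import os
import tempfile
import tensorflow as tf

# Write a tiny TFRecord file; the feature name "text" and its values are made up.
record_path = os.path.join(tempfile.mkdtemp(), "toy.tfrec")
with tf.io.TFRecordWriter(record_path) as writer:
    for value in [b"hello", b"world"]:
        example = tf.train.Example(features=tf.train.Features(feature={
            "text": tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])),
        }))
        writer.write(example.SerializeToString())

# Each element of a TFRecordDataset is one serialized record (a scalar tf.string).
toy_records = tf.data.TFRecordDataset(filenames=[record_path])

# Option 1: decode in Python with the protobuf API, as in the FSNS example.
texts = []
for raw in toy_records:
    example = tf.train.Example.FromString(raw.numpy())
    texts.append(example.features.feature["text"].bytes_list.value[0])
print(texts)  # [b'hello', b'world']

# Option 2: decode inside the pipeline with a feature spec.
feature_spec = {"text": tf.io.FixedLenFeature([], tf.string)}
parsed_ds = toy_records.map(lambda r: tf.io.parse_single_example(r, feature_spec))
parsed_texts = [x["text"].numpy() for x in parsed_ds]
print(parsed_texts)  # [b'hello', b'world']
```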

 

4. Text data

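tf.data.TextLineDataset provides an easy way to extract lines from one or more text files. Given one or more filenames, a TextLineDataset produces one string-valued element per line of those files.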

1) Download the text files (three English translations of the Iliad)

directory_url = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'
file_names = ['cowper.txt', 'derby.txt', 'butler.txt']

file_paths = [
    tf.keras.utils.get_file(file_name, directory_url + file_name)
    for file_name in file_names
]

2) Read the first five lines. Because the files are read one after another, these lines come from the first file (cowper.txt):

dataset = tf.data.TextLineDataset(file_paths)
for line in dataset.take(5):
    print(line.numpy())
# b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
# b'His wrath pernicious, who ten thousand woes'
# b"Caused to Achaia's host, sent many a soul"
# b'Illustrious into Ades premature,'
# b'And Heroes gave (so stood the will of Jove)'
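A self-contained sketch of the same behavior, using two small hypothetical files written to a temporary directory instead of the Iliad translations: TextLineDataset reads the files in the order given and yields one element per line, with the line terminator removed.

```python
import os
import tempfile
import tensorflow as tf

# Two hypothetical text files in a temporary directory.
tmp_dir = tempfile.mkdtemp()
text_paths = []
for name, body in [("a.txt", "a1\na2"), ("b.txt", "b1\nb2")]:
    path = os.path.join(tmp_dir, name)
    with open(path, "w") as f:
        f.write(body)
    text_paths.append(path)

# Files are read one after another; each element is one line without "\n".
lines = [line.numpy() for line in tf.data.TextLineDataset(text_paths)]
print(lines)  # [b'a1', b'a2', b'b1', b'b2']
```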

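3) Use Dataset.interleave to alternate between the lines of the files, which makes it easier to mix the files together. Here are the first, second, and third lines of each translation: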

files_ds = tf.data.Dataset.from_tensor_slices(file_paths)
lines_ds = files_ds.interleave(tf.data.TextLineDataset, cycle_length=3)

for i, line in enumerate(lines_ds.take(9)):
    if i % 3 == 0:
        print()
    print(line.numpy())
# First line of each text
# b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
# b"\xef\xbb\xbfOf Peleus' son, Achilles, sing, O Muse,"
# b'\xef\xbb\xbfSing, O goddess, the anger of Achilles son of Peleus, that brought'
# 
# Second line of each text
# b'His wrath pernicious, who ten thousand woes'
# b'The vengeance, deep and deadly; whence to Greece'
# b'countless ills upon the Achaeans. Many a brave soul did it send'
# 
# Third line of each text
# b"Caused to Achaia's host, sent many a soul"
# b'Unnumbered ills arose; which many a soul'
# b'hurrying down to Hades, and many a hero did it yield a prey to dogs and'
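The sketch below uses three tiny hypothetical files to make the interleave order visible. cycle_length sets how many files are read concurrently, and block_length (an additional interleave argument not used above) sets how many consecutive lines are taken from one file before moving on to the next.

```python
import os
import tempfile
import tensorflow as tf

# Three hypothetical files x.txt, y.txt, z.txt with three lines each (x1..x3, ...).
tmp_dir = tempfile.mkdtemp()
paths = []
for name in ["x", "y", "z"]:
    path = os.path.join(tmp_dir, name + ".txt")
    with open(path, "w") as f:
        f.write("\n".join(f"{name}{i}" for i in range(1, 4)))
    paths.append(path)

files = tf.data.Dataset.from_tensor_slices(paths)

# block_length=1: take one line from each open file in turn.
round_robin = files.interleave(tf.data.TextLineDataset, cycle_length=3, block_length=1)
rr_lines = [line.numpy() for line in round_robin.take(6)]
print(rr_lines)  # [b'x1', b'y1', b'z1', b'x2', b'y2', b'z2']

# block_length=2: take two consecutive lines from a file before moving on.
in_pairs = files.interleave(tf.data.TextLineDataset, cycle_length=3, block_length=2)
pair_lines = [line.numpy() for line in in_pairs.take(6)]
print(pair_lines)  # [b'x1', b'x2', b'y1', b'y2', b'z1', b'z2']
```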

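4) By default, a TextLineDataset yields every line of each file, which may not be what you want if, for example, the file starts with a header line or contains comments. Such lines can be removed with the Dataset.skip() or Dataset.filter() transformations. Here we download a CSV file that has a header row, skip the header, use filter() to drop the rows whose survived field is 0, and finally display the first ten remaining rows.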

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
for line in titanic_lines.take(10):
    print(line.numpy())
# b'survived,sex,age,n_siblings_spouses,parch,fare,class,deck,embark_town,alone'
# b'0,male,22.0,1,0,7.25,Third,unknown,Southampton,n'
# b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n'
# b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y'
# b'1,female,35.0,1,0,53.1,First,C,Southampton,n'
# b'0,male,28.0,0,0,8.4583,Third,unknown,Queenstown,y'
# b'0,male,2.0,3,1,21.075,Third,unknown,Southampton,n'
# b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n'
# b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n'
# b'1,female,4.0,1,1,16.7,Third,G,Southampton,n'

def survived(line):
    # The first character of each line is the "survived" field; keep rows where it is not "0".
    return tf.not_equal(tf.strings.substr(line, 0, 1), "0")

survivors = titanic_lines.skip(1).filter(survived)

for line in survivors.take(10):
    print(line.numpy())
# b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n'
# b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y'
# b'1,female,35.0,1,0,53.1,First,C,Southampton,n'
# b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n'
# b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n'
# b'1,female,4.0,1,1,16.7,Third,G,Southampton,n'
# b'1,male,28.0,0,0,13.0,Second,unknown,Southampton,y'
# b'1,female,28.0,0,0,7.225,Third,unknown,Cherbourg,y'
# b'1,male,28.0,0,0,35.5,First,A,Southampton,y'
# b'1,female,38.0,1,5,31.3875,Third,unknown,Southampton,n'
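The same skip/filter pattern also handles comment lines. In this sketch (a hypothetical file whose comment lines start with "#"), skip(1) drops the header and filter() keeps only the lines that do not fully match the regular expression "#.*". tf.strings.regex_full_match is used here instead of substr purely as an alternative way to write the predicate.

```python
import os
import tempfile
import tensorflow as tf

# Hypothetical file: one header row, one "#" comment line, two data rows.
path = os.path.join(tempfile.mkdtemp(), "data.csv")
with open(path, "w") as f:
    f.write("id,value\n# a comment\n1,10\n2,20")

def not_comment(line):
    # regex_full_match returns a scalar bool tensor: True if the whole line matches.
    return tf.logical_not(tf.strings.regex_full_match(line, "#.*"))

clean = tf.data.TextLineDataset(path).skip(1).filter(not_comment)
rows = [line.numpy() for line in clean]
print(rows)  # [b'1,10', b'2,20']
```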

 

5. CSV data

CSV is a common format for storing tabular data as plain text.
For example:

import pandas as pd
titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
df = pd.read_csv(titanic_file, index_col=None)
print(df.head())
#    survived     sex   age  ...     deck  embark_town  alone
# 0         0    male  22.0  ...  unknown  Southampton      n
# 1         1  female  38.0  ...        C    Cherbourg      n
# 2         1  female  26.0  ...  unknown  Southampton      y
# 3         1  female  35.0  ...        C  Southampton      n
# 4         0    male  28.0  ...  unknown   Queenstown      y

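Because this DataFrame fits in memory, it can be loaded with Dataset.from_tensor_slices after converting it to a dict of columns. Each element of the resulting dataset is one row, given as a dict of features: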

titanic_slices = tf.data.Dataset.from_tensor_slices(dict(df))

for feature_batch in titanic_slices.take(1):
    for key, value in feature_batch.items():
        print("  {!r:20s}: {}".format(key, value))
# 'survived'          : 0
# 'sex'               : b'male'
# 'age'               : 22.0
# 'n_siblings_spouses': 1
# 'parch'             : 0
# 'fare'              : 7.25
# 'class'             : b'Third'
# 'deck'              : b'unknown'
# 'embark_town'       : b'Southampton'
# 'alone'             : b'n'
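A minimal offline sketch of the same idea, assuming a hypothetical dict of three Python lists shaped like dict(df) (column name to column values). Each dataset element is one row as a dict, and batch() keeps the dict structure while stacking every column.

```python
import tensorflow as tf

# Hypothetical columns shaped like dict(df): column name -> column values.
columns = {
    "survived": [0, 1, 1],
    "sex": ["male", "female", "female"],
    "fare": [7.25, 71.2833, 7.925],
}

rows_ds = tf.data.Dataset.from_tensor_slices(columns)

# Each element is one row: a dict mapping column name -> scalar tensor.
first = next(iter(rows_ds))
print({key: value.numpy() for key, value in first.items()})

# batch() keeps the dict structure and stacks each column.
batch = next(iter(rows_ds.batch(2)))
print(batch["survived"].numpy())  # [0 1]
```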

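The experimental.make_csv_dataset function is the high-level interface for reading CSV files. It supports column type inference and many other features, such as batching and shuffling.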

titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived")

for feature_batch, label_batch in titanic_batches.take(1):
    print("'survived': {}".format(label_batch))
    print("features:")
    for key, value in feature_batch.items():
        print("  {!r:20s}: {}".format(key, value))
# 'survived': [0 0 0 1]
# features:
#   'sex'               : [b'female' b'male' b'male' b'male']
#   'age'               : [18. 33. 33. 26.]
#   'n_siblings_spouses': [1 0 1 0]
#   'parch'             : [0 0 1 0]
#   'fare'              : [17.8     8.6542 20.525  18.7875]
#   'class'             : [b'Third' b'Third' b'Third' b'Third']
#   'deck'              : [b'unknown' b'unknown' b'unknown' b'unknown']
#   'embark_town'       : [b'Southampton' b'Southampton' b'Southampton' b'Cherbourg']
#   'alone'             : [b'n' b'y' b'n' b'y']

If you only need a subset of the columns, use the select_columns argument.

titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived", select_columns=['class', 'fare', 'survived'])

for feature_batch, label_batch in titanic_batches.take(1):
    print("'survived': {}".format(label_batch))
    for key, value in feature_batch.items():
        print("  {!r:20s}: {}".format(key, value))
# 'survived': [0 0 1 0]
#   'fare'              : [10.5    26.55   49.5042 73.5   ]
#   'class'             : [b'Second' b'First' b'First' b'Second']

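The lower-level experimental.CsvDataset class provides finer-grained control. It does not support column type inference; instead, you must specify the type of each column.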

titanic_types = [tf.int32, tf.string, tf.float32, tf.int32, tf.int32, tf.float32, tf.string, tf.string, tf.string,
                 tf.string]
dataset = tf.data.experimental.CsvDataset(titanic_file, titanic_types, header=True)

for line in dataset.take(10):
    print([item.numpy() for item in line])
# [0, b'male', 22.0, 1, 0, 7.25, b'Third', b'unknown', b'Southampton', b'n']
# [1, b'female', 38.0, 1, 0, 71.2833, b'First', b'C', b'Cherbourg', b'n']
# [1, b'female', 26.0, 0, 0, 7.925, b'Third', b'unknown', b'Southampton', b'y']
# [1, b'female', 35.0, 1, 0, 53.1, b'First', b'C', b'Southampton', b'n']
# [0, b'male', 28.0, 0, 0, 8.4583, b'Third', b'unknown', b'Queenstown', b'y']
# [0, b'male', 2.0, 3, 1, 21.075, b'Third', b'unknown', b'Southampton', b'n']
# [1, b'female', 27.0, 0, 2, 11.1333, b'Third', b'unknown', b'Southampton', b'n']
# [1, b'female', 14.0, 1, 0, 30.0708, b'Second', b'unknown', b'Cherbourg', b'n']
# [1, b'female', 4.0, 1, 1, 16.7, b'Third', b'G', b'Southampton', b'n']
# [0, b'male', 20.0, 0, 0, 8.05, b'Third', b'unknown', b'Southampton', b'y']
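record_defaults can also hold default values instead of dtypes: a scalar value marks a column as optional, and empty fields in that column are filled with it. The sketch below assumes a small headerless hypothetical CSV file with missing fields; it also uses select_cols to read only some of the columns, in which case record_defaults needs one entry per selected column.

```python
import os
import tempfile
import tensorflow as tf

# Hypothetical headerless CSV file with some empty fields.
csv_path = os.path.join(tempfile.mkdtemp(), "missing.csv")
with open(csv_path, "w") as f:
    f.write("1,2,3\n,5,\n7,,9")

# Scalar defaults (not dtypes) make every column optional; empty fields get 999.
filled = tf.data.experimental.CsvDataset(csv_path, record_defaults=[999, 999, 999])
filled_rows = [[int(v) for v in row] for row in filled]
print(filled_rows)  # [[1, 2, 3], [999, 5, 999], [7, 999, 9]]

# select_cols reads only columns 0 and 2, so record_defaults has two entries.
first_last = tf.data.experimental.CsvDataset(
    csv_path, record_defaults=[999, 999], select_cols=[0, 2])
selected_rows = [[int(v) for v in row] for row in first_last]
print(selected_rows)  # [[1, 3], [999, 999], [7, 9]]
```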
