TensorFlow2.0 Guide 官方教程 学习笔记14-'tf.data: Build TensorFlow input pipelines'

本笔记参照TensorFlow官方教程,主要是对‘tf.data: Build TensorFlow input pipelines’教程内容翻译和内容结构编排,原文链接:tf.data: Build TensorFlow input pipelines

目录
一、基本结构(Basic mechanics)
1.1 数据集结构
二、读取输入数据
2.1 处理Numpy 数组(Consuming Numpy arrays)
2.2 处理Python生成器(Consuming Python generators)
2.3 处理TFRecord数据(Consuming TFRecord data)
2.4 处理文本数据(Consuming text data)
2.5 处理CSV数据(Consuming CSV data)
2.6 处理文件集(Consuming sets of files)
三、批处理数据集元素(Batching dataset elements)
3.1 简单批处理(simple batching)
3.2 用‘填充’来批处理张量(Batching tensors with padding)
四、训练工作流(training workflows)
4.1 处理多纪元(Processing multiple epochs)
4.2 随机打乱输入数据(Randomly shuffling input data)
五、预处理数据(Preprocessing data)
5.1 图像数据解码和调整大小(Decoding image data and resizing it)
5.2 应用任意Python逻辑(Applying arbitrary Python logic)
5.3 解析tf.Example协议缓冲区消息示例(Parsing tf.Example protocol buffer messages)
5.4 时间序列窗口(Time series windowing)
5.5 重采样(resampling)
六、使用高阶API
6.1 tf.keras
6.2 tf.estimator


我们可以用tf.data API从简单可重用的片数据中创建负责的输入流水线。例如:图像模型的管道可以聚合来自分布式文件系统中的文件数据,对每个图像应用随机扰动,并将随机选择的图像合并成一批进行培训。文本模型的管道可能涉及从原始文本数据中提取符号,将其转换为带有查找表的嵌入标识符,并将不同长度的序列组合在一起。tf.data API使处理大量数据、从不同的数据格式读取数据和执行复杂的转换成为可能。

tf.data API 引入一个tf.data.Dataset抽象,它表示元素序列,其中每个元素由一个或多个组件组成。例如,在一个图像管道中,一个元素可能是一个单一的训练示例,它有一对张量分量表示图像及其标签。

创建数据集有两种不同的方法:
- 从存储在内存的一个或多个文件中的数据构建数据集
- 从一个或多个tf.data.Dataset对象转换创建数据集

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import pathlib
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

np.set_printoptions(precision=4)

一、基本结构(Basic mechanics)

为了创建一个输入流水线,我们必须从数据源开始。例如,从内存中创建一个数据集,我们可以使用tf.data.Dataset.from_tensors()或者tf.data.Dataset.from_tensor_slices()。或者,如果我们的输入数据以推荐的形式(TFRecord)存储在文件里,我们可以用tf.data.TFRecordDataset()。
一旦我们有了一个数据集对象,我们可以通过调用tf.data.Dataset对象中的链接方法将它转换为一个新的数据集。例如,我们可以应用预执行转换比如Dataset.map(),多元素转换如Dataset.batch()。详情参考:tf.data.Dataset
Dataset对象是一个可迭代的Python对象。这使得使用for循环来消费它的元素成为可能:

dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
dataset


for elem in dataset:
  print(elem.numpy())
8
3
0
8
2
1

或者使用iter显式地创建一个Python迭代器,然后使用next消费它的元素:

it = iter(dataset)

print(next(it).numpy())
8

或者,可以使用‘reduce’转换使用数据集元素,该转换减少所有元素以生成单个结果。下面的示例说明如何使用‘reduce’转换来计算整数集的和。

print(dataset.reduce(0, lambda state, value: state + value).numpy())
22

1.1 数据集结构
数据集包含的元素具有相同的(嵌套的)结构,结构的各个组件可以是tf.TypeSpec来表示的任何类型,包括张量、稀疏张量、不规则张量、张量阵列或数据集。
‘Dataset.element_spec’属性让我们可以检查每个元素组件的类型。该属性返回一个tf.TypeSpec对象的嵌套结构,匹配元素的结构,元素可以是单个组件、组件的元组或组件的嵌套元组。例如:

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10]))

dataset1.element_spec
TensorSpec(shape=(10,), dtype=tf.float32, name=None)
dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random.uniform([4]),
    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))

dataset2.element_spec
(TensorSpec(shape=(), dtype=tf.float32, name=None),
 TensorSpec(shape=(100,), dtype=tf.int32, name=None))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

dataset3.element_spec
(TensorSpec(shape=(10,), dtype=tf.float32, name=None),
 (TensorSpec(shape=(), dtype=tf.float32, name=None),
  TensorSpec(shape=(100,), dtype=tf.int32, name=None)))
# Dataset containing a sparse tensor.
dataset4 = tf.data.Dataset.from_tensors(tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]))

dataset4.element_spec
SparseTensorSpec(TensorShape([3, 4]), tf.int32)
# Use value_type to see the type of value represented by the element spec
dataset4.element_spec.value_type
tensorflow.python.framework.sparse_tensor.SparseTensor

数据集转换支持任何结构的数据集。在使用将函数应用于每个元素的Dataset.map()和Dataset.filter()转换时,元素结构决定了函数的参数:

dataset1 = tf.data.Dataset.from_tensor_slices(
    tf.random.uniform([4, 10], minval=1, maxval=10, dtype=tf.int32))

dataset1

for z in dataset1:
  print(z.numpy())
[1 4 5 3 5 8 8 3 9 6]
[2 2 1 6 5 8 7 7 2 9]
[5 6 4 3 7 4 9 5 6 6]
[8 7 8 5 7 2 2 6 5 4]
dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random.uniform([4]),
    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))

dataset2

dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

dataset3

for a, (b,c) in dataset3:
  print('shapes: {a.shape}, {b.shape}, {c.shape}'.format(a=a, b=b, c=c))
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)

二、读取输入数据
2.1 处理Numpy 数组(Consuming Numpy arrays)
更多示例请参考(需备梯子):Loading Numpy arrays
如果我们所有的输入数据存在内存里,那创建数据集最简单的方法是使用‘Dataset.from_tensor_slices()’方法将它们转换为‘tf.Tensor’。

train, test = tf.keras.datasets.fashion_mnist.load_data()

TensorFlow2.0 Guide 官方教程 学习笔记14-'tf.data: Build TensorFlow input pipelines'_第1张图片

images, labels = train
images = images/255

dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset

	注意:上面的代码片段把特性和标签数组作为‘tf.constant()’操作嵌入到TensorFlow图中,
	这对于小数据集来说工作的很好,但是会浪费内存---因为数组的内容会被复制多次---并且
	可能会达到‘tf.GraphDef’协议缓冲区的2GB限制。

2.2 处理Python生成器(Consuming Python generators)
另外一个可以很容易被整合成‘tf.data.Dataset’的通用数据源是Python生成器。

		注意:虽然这是一种方便的方法,但它的可移植性和可靠性有限。它必须在与创建生成器相同的python进程中运行,并且仍然受python GIL的约束。
def count(stop):
  i = 0
  while i<stop:
    yield i
    i += 1
for n in count(5):
  print(n)   
0
1
2
3
4

‘Dataset.from_generator’构造函数将python生成器转换为全功能的‘tf.data.Dataset’。构造函数需要一个可调用对象最为输入,而不是迭代器。这样当生成器结束时可以让它重启。它(constructor)也有一个可选参数‘args’,作为一个可调用参数传递。
output_types参数是必需的因为‘tf.data’在内部创建‘tf.Graph’,而图边界需要‘tf.dtype’。

ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )
for count_batch in ds_counter.repeat().batch(10).take(10):
  print(count_batch.numpy())
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]

‘output_shapes’参数虽然不是必需的,但是强烈建议添加,因为许多TensorFlow操作不支持秩未知的张量。如果某个轴的长度未知或可变,则在output_shapes中将其设置为None。
还需要注意的是,output_shapes和output_types作为其他数据集方法时遵循相同的嵌套规则。
下面是一个演示这两个方面的生成器示例,它返回数组的元组,其中第二个数组是长度未知的向量。

def gen_series():
  i = 0
  while True:
    size = np.random.randint(0, 10)
    yield i, np.random.normal(size=(size,))
    i += 1
for i, series in gen_series():
  print(i, ":", str(series))
  if i > 5:
    break
0 : [-1.1226  1.6132 -0.0095  0.8728]
1 : [0.6396 0.4688 0.2611 0.9847 0.1679 0.0287]
2 : [-0.2065  0.2807 -1.1219  0.0603]
3 : [ 1.137   0.0087 -0.774  -0.321  -2.0574 -0.4246]
4 : [ 1.0536 -1.0681 -0.8049  0.5107  1.2738 -0.1986 -0.5262  0.7247 -0.1688]
5 : [-1.7257 -0.4691  0.418   1.7976  1.863   0.3992]
6 : [-0.3747 -1.2524  0.525  -0.6958  0.4991 -0.5964 -1.7148]

第一个输出是‘int32’,第二个输出是‘float32’。第一个条目是标量,shape(),第二个是向量,长度未知,shape(None)。

ds_series = tf.data.Dataset.from_generator(
    gen_series, 
    output_types=(tf.int32, tf.float32), 
    output_shapes=((), (None,)))

ds_series

现在,它可以被用作正常的tf.data.Dataset了。注意:当批处理一个形状可变的数据集时,我们需使用‘Dataset.padded_batch’。

ds_series_batch = ds_series.shuffle(20).padded_batch(10, padded_shapes=([], [None]))

ids, sequence_batch = next(iter(ds_series_batch))
print(ids.numpy())
print()
print(sequence_batch.numpy())
[12 19  3  0  9  7  5  6  4 16]

[[ 0.      0.      0.      0.      0.      0.      0.      0.      0.    ]
 [-0.9723 -0.4083 -0.0498 -0.9856  0.      0.      0.      0.      0.    ]
 [-1.2603  0.8078 -0.6713  0.0692  0.1462 -0.7181 -0.0713  0.801   0.    ]
 [ 1.6164  0.5583  1.0472  1.5479  0.4733  0.2503 -0.5349  1.0763 -0.385 ]
 [ 1.1203 -0.7176  0.3693 -0.2975 -1.5206  1.297   0.5356 -1.2834 -0.7963]
 [-1.7671 -0.723  -0.3565  1.2658  0.6733  0.106  -0.5957  0.      0.    ]
 [ 0.      0.      0.      0.      0.      0.      0.      0.      0.    ]
 [ 0.422   0.4992 -1.5497 -0.6262 -0.1558 -0.2029  0.      0.      0.    ]
 [ 0.5232  0.8569  0.0893 -0.3251 -0.9755  1.0572 -1.5325 -1.1672  0.    ]
 [ 1.7647  0.0114  2.0847  0.6158  0.      0.      0.      0.      0.    ]]

对于更实际的示例,请尝试包装预处理image.ImageDataGenerator作为tf.data.Dataset。
先下载数据:

flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)

在这里插入图片描述
创建‘image.ImageDataGenerator’

img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)
images, labels = next(img_gen.flow_from_directory(flowers))
Found 3670 images belonging to 5 classes.
print(images.dtype, images.shape)
print(labels.dtype, labels.shape)
float32 (32, 256, 256, 3)
float32 (32, 5)
ds = tf.data.Dataset.from_generator(
    img_gen.flow_from_directory, args=[flowers], 
    output_types=(tf.float32, tf.float32), 
    output_shapes=([32,256,256,3], [32,5])
)

ds

2.3 处理TFRecord数据(Consuming TFRecord data)
有关端到端的例子,请参考:‘Loading TFRecords’。(官方链接丢失)
‘tf.data’API支持多种文件格式,因此可以处理内存中不适合的大型数据集。例如,TFRecord文件格式是一个简单的面向记录的二进制格式,许多TensorFlow应用程序将其用于训练数据。‘tf.data.TFRecordDataset’类允许我们将一个或多个TFRecord文件的内容作为输入管道的一部分。
下面是一个使用来自法国街道名标牌(FSNS)的测试文件的示例。

# Creates a dataset that reads all of the examples from two files.
fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")

TFRecordDataset初始化器的文件名参数可以是字符串、字符串列表或字符串张量。如果我们有两组用于训练和验证的文件,我们可以创建一个工厂方法来生产数据集,将文件名作为输入参数:

dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset

许多TensorFlow项目在TFRecord文件里使用序列化‘tf.train.Example’记录。这些需要在检查前进行解码:

raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())

parsed.features.feature['image/text']
bytes_list {
  value: "Rue Perreyon"
}

2.4 处理文本数据(Consuming text data)
有关端到端的例子,请参考:Loading text。
许多数据集作为一个或多个文本文件分发。‘tf.data.TextLineDataset’提供了一种从一个或多个文本文件中提取行的简单方法。给定一个或多个文件名,TextLineDataset将为这些文件的每行生成一个字符串元素。

directory_url = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'
file_names = ['cowper.txt', 'derby.txt', 'butler.txt']

file_paths = [
    tf.keras.utils.get_file(file_name, directory_url + file_name)
    for file_name in file_names
]

TensorFlow2.0 Guide 官方教程 学习笔记14-'tf.data: Build TensorFlow input pipelines'_第2张图片

dataset = tf.data.TextLineDataset(file_paths)

下面是一个文件的前几行:

for line in dataset.take(5):
  print(line.numpy())
b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
b'His wrath pernicious, who ten thousand woes'
b"Caused to Achaia's host, sent many a soul"
b'Illustrious into Ades premature,'
b'And Heroes gave (so stood the will of Jove)'

要在文件之间交替行,请使用‘Dataset.interleave’。这使得将文件混合在一起变得更加容易。以下是每个译本的第一行、第二行和第三行:

files_ds = tf.data.Dataset.from_tensor_slices(file_paths)
lines_ds = files_ds.interleave(tf.data.TextLineDataset, cycle_length=3)

for i, line in enumerate(lines_ds.take(9)):
  if i % 3 == 0:
    print()
  print(line.numpy())
b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
b"\xef\xbb\xbfOf Peleus' son, Achilles, sing, O Muse,"
b'\xef\xbb\xbfSing, O goddess, the anger of Achilles son of Peleus, that brought'

b'His wrath pernicious, who ten thousand woes'
b'The vengeance, deep and deadly; whence to Greece'
b'countless ills upon the Achaeans. Many a brave soul did it send'

b"Caused to Achaia's host, sent many a soul"
b'Unnumbered ills arose; which many a soul'
b'hurrying down to Hades, and many a hero did it yield a prey to dogs and'

默认情况下,TextLineDataset会生成每个文件的每一行,这可能是不需要的,例如:如果文件以标题行开始,或者包含注释。可以使用Dataset.skip()或Dataset.filter()转换删除这些行。在这里,我们跳过第一行,然后过滤直到找到正文。

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
for line in titanic_lines.take(10):
  print(line.numpy())
b'survived,sex,age,n_siblings_spouses,parch,fare,class,deck,embark_town,alone'
b'0,male,22.0,1,0,7.25,Third,unknown,Southampton,n'
b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n'
b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y'
b'1,female,35.0,1,0,53.1,First,C,Southampton,n'
b'0,male,28.0,0,0,8.4583,Third,unknown,Queenstown,y'
b'0,male,2.0,3,1,21.075,Third,unknown,Southampton,n'
b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n'
b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n'
b'1,female,4.0,1,1,16.7,Third,G,Southampton,n'
def survived(line):
  return tf.not_equal(tf.strings.substr(line, 0, 1), "0")

survivors = titanic_lines.skip(1).filter(survived)
for line in survivors.take(10):
  print(line.numpy())
b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n'
b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y'
b'1,female,35.0,1,0,53.1,First,C,Southampton,n'
b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n'
b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n'
b'1,female,4.0,1,1,16.7,Third,G,Southampton,n'
b'1,male,28.0,0,0,13.0,Second,unknown,Southampton,y'
b'1,female,28.0,0,0,7.225,Third,unknown,Cherbourg,y'
b'1,male,28.0,0,0,35.5,First,A,Southampton,y'
b'1,female,38.0,1,5,31.3875,Third,unknown,Southampton,n'

2.5 处理CSV数据(Consuming CSV data)
更多示例请参考:Loading CSV Files和‘Loading Pandas DataFrames’(官方链接丢失)。
CSV文件格式是一种以纯文本形式存储表格数据的流行格式。

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
df = pd.read_csv(titanic_file, index_col=None)
df.head()

TensorFlow2.0 Guide 官方教程 学习笔记14-'tf.data: Build TensorFlow input pipelines'_第3张图片
如果我们的数据在内存中,‘Dataset.from_tensor_slices’方法同样作用于字典,使这些数据很轻易地被使用。

titanic_slices = tf.data.Dataset.from_tensor_slices(dict(df))

for feature_batch in titanic_slices.take(1):
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
'survived'          : 0
  'sex'               : b'male'
  'age'               : 22.0
  'n_siblings_spouses': 1
  'parch'             : 0
  'fare'              : 7.25
  'class'             : b'Third'
  'deck'              : b'unknown'
  'embark_town'       : b'Southampton'
  'alone'             : b'n'

一种更具可伸缩性的方法是根据需要从磁盘加载。tf.data模块提供了从一个或多个符合RFC 4180的CSV文件中提取记录的方法。它支持列类型推断和许多其他特性,比如批处理和变换,以简化使用。
‘experimental.make_csv_dataset’是一个高级别交互函数,用来读取CSV文件集。

titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived")
for feature_batch, label_batch in titanic_batches.take(1):
  print("'survived': {}".format(label_batch))
  print("features:")
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
'survived': [0 0 0 1]
features:
  'sex'               : [b'male' b'male' b'male' b'male']
  'age'               : [28. 31. 38. 28.]
  'n_siblings_spouses': [3 0 0 0]
  'parch'             : [1 0 0 0]
  'fare'              : [25.4667 50.4958  7.05   30.5   ]
  'class'             : [b'Third' b'First' b'Third' b'First']
  'deck'              : [b'unknown' b'A' b'unknown' b'C']
  'embark_town'       : [b'Southampton' b'Southampton' b'Southampton' b'Southampton']
  'alone'             : [b'n' b'y' b'y' b'y']

如果只需要列的一个子集,那么可以使用select_columns参数。

titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived", select_columns=['class', 'fare', 'survived'])
for feature_batch, label_batch in titanic_batches.take(1):
  print("'survived': {}".format(label_batch))
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
'survived': [0 0 0 0]
  'fare'              : [27.7208 13.     20.2125  9.5   ]
  'class'             : [b'First' b'Second' b'Third' b'Third']

还有一个低级别‘experimenta.CsvDataset’类,它不支持列类型推断。相反,必须指定每个列的类型。

titanic_types  = [tf.int32, tf.string, tf.float32, tf.int32, tf.int32, tf.float32, tf.string, tf.string, tf.string, tf.string] 
dataset = tf.data.experimental.CsvDataset(titanic_file, titanic_types , header=True)

for line in dataset.take(10):
  print([item.numpy() for item in line])
[0, b'male', 22.0, 1, 0, 7.25, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 38.0, 1, 0, 71.2833, b'First', b'C', b'Cherbourg', b'n']
[1, b'female', 26.0, 0, 0, 7.925, b'Third', b'unknown', b'Southampton', b'y']
[1, b'female', 35.0, 1, 0, 53.1, b'First', b'C', b'Southampton', b'n']
[0, b'male', 28.0, 0, 0, 8.4583, b'Third', b'unknown', b'Queenstown', b'y']
[0, b'male', 2.0, 3, 1, 21.075, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 27.0, 0, 2, 11.1333, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 14.0, 1, 0, 30.0708, b'Second', b'unknown', b'Cherbourg', b'n']
[1, b'female', 4.0, 1, 1, 16.7, b'Third', b'G', b'Southampton', b'n']
[0, b'male', 20.0, 0, 0, 8.05, b'Third', b'unknown', b'Southampton', b'y']

如果某些列是空的,则此低级接口允许我们提供默认值,而不是列类型。

%%writefile missing.csv
1,2,3,4
,2,3,4
1,,3,4
1,2,,4
1,2,3,
,,,
Writing missing.csv
# Creates a dataset that reads all of the records from two CSV files, each with
# four float columns which may have missing values.

record_defaults = [999,999,999,999]
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults)
dataset = dataset.map(lambda *items: tf.stack(items))
dataset

for line in dataset:
  print(line.numpy())
[1 2 3 4]
[999   2   3   4]
[  1 999   3   4]
[  1   2 999   4]
[  1   2   3 999]
[999 999 999 999]

默认情况下,CsvDataset会生成文件每行中的每一列,这可能是不需要的,例如,如果文件以一个应该忽略的标题行开始,或者输入中不需要某些列。可以分别使用header和select_cols参数删除这些行和字段。

# Creates a dataset that reads all of the records from two CSV files with
# headers, extracting float data from columns 2 and 4.
record_defaults = [999, 999] # Only provide defaults for the selected columns
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults, select_cols=[1, 3])
dataset = dataset.map(lambda *items: tf.stack(items))
dataset

for line in dataset:
  print(line.numpy())
[2 4]
[2 4]
[999   4]
[2 4]
[  2 999]
[999 999]

2.6 处理文件集(Consuming sets of files)
有许多数据集分布在一个文件集里,其中,每个文件是一个示例。

flowers_root = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
flowers_root = pathlib.Path(flowers_root)

根目录包含每个类的路径:

for item in flowers_root.glob("*"):
  print(item.name)
daisy
sunflowers
roses
LICENSE.txt
tulips
dandelion

要从文件中加载数据,请使用tf.io.read_file功能:

list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))

for f in list_ds.take(5):
  print(f.numpy())
b'/root/.keras/datasets/flower_photos/tulips/16680998737_6f6225fe36.jpg'
b'/root/.keras/datasets/flower_photos/roses/8437935944_aab997560a_n.jpg'
b'/root/.keras/datasets/flower_photos/dandelion/808239968_318722e4db.jpg'
b'/root/.keras/datasets/flower_photos/dandelion/3502447188_ab4a5055ac_m.jpg'
b'/root/.keras/datasets/flower_photos/roses/3145692843_d46ba4703c.jpg'

将文件路径转换为(图像,标签)对:

def process_path(file_path):
  parts = tf.strings.split(file_path, '/')
  return tf.io.read_file(file_path), parts[-2]

labeled_ds = list_ds.map(process_path)
for image_raw, label_text in labeled_ds.take(1):
  print(repr(image_raw.numpy()[:100]))
  print()
  print(label_text.numpy())
b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x00H\x00H\x00\x00\xff\xe2\x0cXICC_PROFILE\x00\x01\x01\x00\x00\x0cHLino\x02\x10\x00\x00mntrRGB XYZ \x07\xce\x00\x02\x00\t\x00\x06\x001\x00\x00acspMSFT\x00\x00\x00\x00IEC sRGB\x00\x00\x00\x00\x00\x00'

b'tulips'

三、批处理数据集元素(Batching dataset elements)
3.1 简单批处理(simple batching)
最简单的批处理方式是将数据集的n个连续元素堆叠成单个元素。Dataset.batch()变换就是这样做的,它有和tf.stack()运算符相同的约束,适用于元素的每个分量:也就是说,对于每个分量i,所有的元素都必须有一个完全相同形状的张量。

inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)

for batch in batched_dataset.take(4):
  print([arr.numpy() for arr in batch])
[array([0, 1, 2, 3]), array([ 0, -1, -2, -3])]
[array([4, 5, 6, 7]), array([-4, -5, -6, -7])]
[array([ 8,  9, 10, 11]), array([ -8,  -9, -10, -11])]
[array([12, 13, 14, 15]), array([-12, -13, -14, -15])]

当tf.data试图传播形状信息时,默认情况下,Dataset.batch的结果批大小是未知的,因为最后一批可能没有满。注意‘shape’中的Nones参数:

batched_dataset

使用‘drop_remainder’参数来忽略最后一个批,得到完整的形状传播(shape propagation)

batched_dataset = dataset.batch(7, drop_remainder=True)
batched_dataset

3.2 用‘填充’来批处理张量(Batching tensors with padding)
以上方法适用于具有相同大小的张量。然而,许多模型(例如序列模型)处理的输入数据可能具有不同的大小(例如不同长度的序列)。为了处理这种情况,Dataset.padded_batch()转换允许我们通过指定一个或多个维度来填充不同形状的张量。

dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=(None,))

for batch in dataset.take(2):
  print(batch.numpy())
  print()
[[0 0 0]
 [1 0 0]
 [2 2 0]
 [3 3 3]]

[[4 4 4 4 0 0 0]
 [5 5 5 5 5 0 0]
 [6 6 6 6 6 6 0]
 [7 7 7 7 7 7 7]]

Dataset.padded_batch()转换允许我们为每个组件的每个维度设置不同的填充,它可以是可变长度(在上面的示例中由None表示)或固定长度。还可以覆盖padding值,该值默认为0。

四、训练工作流(training workflows)
4.1 处理多纪元(Processing multiple epochs)
tf.data API提供两种主要方法来处理同一数据的多个纪元。在多个epoch中遍历数据集的最简单方法是使用dataset.repeat()转换。
首先,我们创建一个包含‘titanic data’的数据集:

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
def plot_batch_sizes(ds):
  batch_sizes = [batch.shape[0] for batch in ds]
  plt.bar(range(len(batch_sizes)), batch_sizes)
  plt.xlabel('Batch number')
  plt.ylabel('Batch size')

使用不带参数的Dataset.repeat()转换将无限期地重复输入
Dataset.repeat转换连接它的参数不需要发出一个纪元结束和一个纪元开始的信号。因为随后使用的Dataset.batch()将产出可以跨纪元边界的批:

titanic_batches = titanic_lines.repeat(3).batch(128)
plot_batch_sizes(titanic_batches)

TensorFlow2.0 Guide 官方教程 学习笔记14-'tf.data: Build TensorFlow input pipelines'_第4张图片
如果我们需要明确的纪元分离,把Dataset.batch放在‘repeat’前面:

titanic_batches = titanic_lines.batch(128).repeat(3)

plot_batch_sizes(titanic_batches)

TensorFlow2.0 Guide 官方教程 学习笔记14-'tf.data: Build TensorFlow input pipelines'_第5张图片
如果你想在每个epoch结束时执行一个自定义计算(例如收集统计数据),那么最简单的方法是在每个epoch上重新启动数据集迭代:

epochs = 3
dataset = titanic_lines.batch(128)

for epoch in range(epochs):
  for batch in dataset:
    print(batch.shape)
  print("End of epoch: ", epoch)
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch:  0
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch:  1
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch:  2

4.2 随机打乱输入数据(Randomly shuffling input data)
Dataset.shuffle()转换维持一个固定大小的缓冲,并从该缓冲区均匀随机地选择下一个元素。

	注意:虽然大的buffer_size会更彻底地洗牌,但是它们会占用大量内存,并且需要大量时间来填充。如果出现问题,可以考虑在文件之间使用Dataset.interleave。

添加一个索引到数据集,这样我们可以看到效果:

lines = tf.data.TextLineDataset(titanic_file)
counter = tf.data.experimental.Counter()

dataset = tf.data.Dataset.zip((counter, lines))
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.batch(20)
dataset

因为buffer_size是100,而批处理大小是20,所以第一批不包含索引超过120的元素。

n,line_batch = next(iter(dataset))
print(n.numpy())
[ 36   7  48  88  60  52  61  59  97 100  45  62  87  34  90  83  55  14
  32 114]

由于有Dataset.batch,Dataset.repeat的顺序很重要。
在转移缓冲区为空之前,Dataset.shuffle并不会发出一个纪元结束的信号。因此,重复前的转移将在转移到下一个纪元之前显示一个纪元的所有元素(So a shuffle placed before a repeat will show every element of one epoch before moving to the next):

dataset = tf.data.Dataset.zip((counter, lines))
shuffled = dataset.shuffle(buffer_size=100).batch(10).repeat(2)

print("Here are the item ID's near the epoch boundary:\n")
for n, line_batch in shuffled.skip(60).take(5):
  print(n.numpy())
Here are the item ID's near the epoch boundary:

[565 571 381 557 594 627 503 383 562 513]
[556 603 407 604 613 496 165 252 624 623]
[484 464 572 602 569 605 574 610]
[25 52 94  2 43 85  7 44 26 74]
[ 41  49  95  61  76  16 104   0  67 105]
shuffle_repeat = [n.numpy().mean() for n, line_batch in shuffled]
plt.plot(shuffle_repeat, label="shuffle().repeat()")
plt.ylabel("Mean item ID")
plt.legend()

TensorFlow2.0 Guide 官方教程 学习笔记14-'tf.data: Build TensorFlow input pipelines'_第6张图片
但在洗牌之前(shuffle),‘重复’(repeat)将纪元的界限混合在了一起:

dataset = tf.data.Dataset.zip((counter, lines))
shuffled = dataset.repeat(2).shuffle(buffer_size=100).batch(10)

print("Here are the item ID's near the epoch boundary:\n")
for n, line_batch in shuffled.skip(55).take(15):
  print(n.numpy())
Here are the item ID's near the epoch boundary:

[545 487   5 583 133 418 609 384   6 620]
[458 597 477  28 595 586 400   3  36  33]
[623 535 491 587 589  42  21  41 601 399]
[ 40 361 607 498 416 575  34  53  50  27]
[520 599  38  56 619  54  17  52  59 314]
[ 15  72  30 574 582  35  31  26  13  61]
[ 29 549  39  69  11  23   9  37 530  46]
[526  20 521  45  57 566 270 555  86  95]
[ 70  12 544  66   2 573  76  82 553 110]
[617 107 608 624  49 104  75  67  92 611]
[ 24 101 614 102 618 105 602 109 111 112]
[482 578  77  88 598 464  10 579  79 106]
[ 97 605  64  71  85 108  58  74 141  19]
[145 140 129  22 475 120  99 446 137  32]
[150  84 557 147 588 130  48 136  51 163]
repeat_shuffle = [n.numpy().mean() for n, line_batch in shuffled]

plt.plot(shuffle_repeat, label="shuffle().repeat()")
plt.plot(repeat_shuffle, label="repeat().shuffle()")
plt.ylabel("Mean item ID")
plt.legend()


TensorFlow2.0 Guide 官方教程 学习笔记14-'tf.data: Build TensorFlow input pipelines'_第7张图片
五、预处理数据(Preprocessing data)
Dataset.map(f)转换通过对输入数据集的每个元素应用一个给定函数f来生成一个新的数据集。它基于map()函数,该函数通常应用于函数式编程语言中的列表(和其他结构)。f函数以tf.Tensor对象来表示输入中的单个元素,并且也是返回tf.Tensor,表示新数据集中的单个元素。它的实现使用标准的TensorFlow操作将一个元素转换成另一个元素。
下面的内容就教我们如何使用‘Dataset.map()’
5.1 图像数据解码和调整大小(Decoding image data and resizing it)
当用真实图像数据来训练神经网络时,将不同大小的图像转换成统一大小的操作通常是必要,因此它们可能被批处理成固定大小。
重建‘flower filenames’数据集:

list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))

编写一个控制数据元素的函数:

# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def parse_image(filename):
  parts = tf.strings.split(file_path, '/')
  label = parts[-2]

  image = tf.io.read_file(filename)
  image = tf.image.decode_jpeg(image)
  image = tf.image.convert_image_dtype(image, tf.float32)
  image = tf.image.resize(image, [128, 128])
  return image, label

测试它是否有效:

file_path = next(iter(list_ds))
image, label = parse_image(file_path)

def show(image, label):
  plt.figure()
  plt.imshow(image)
  plt.title(label.numpy().decode('utf-8'))
  plt.axis('off')

show(image, label)

TensorFlow2.0 Guide 官方教程 学习笔记14-'tf.data: Build TensorFlow input pipelines'_第8张图片
将它映射到数据集里:

images_ds = list_ds.map(parse_image)

for image, label in images_ds.take(2):
  show(image, label)

TensorFlow2.0 Guide 官方教程 学习笔记14-'tf.data: Build TensorFlow input pipelines'_第9张图片
TensorFlow2.0 Guide 官方教程 学习笔记14-'tf.data: Build TensorFlow input pipelines'_第10张图片
5.2 应用任意Python逻辑(Applying arbitrary Python logic)
出于性能的原因,谷歌鼓励我们如果可以使用TensorFlow操作来预处理我们的数据。然而,有时候调用外部的Python库来解析我们的输入数据也是有帮助的。我们可以在在Dataset.map()转换中使用tf.py_function()操作。
例如,如果我们想用个随机旋转,‘tf.image’模块仅仅只有‘tf.image.rot90’,这样对图像增强不是很有用。

	注意:tensorflow_addons中有一个与TensorFlow兼容的旋转,在tensorflow_addons.image.rotate里。为了演示‘tf.py_function’,我们可以尝试使用‘scipy.ndimage.rotate'函数替代:
import scipy.ndimage as ndimage

def random_rotate_image(image):
  image = ndimage.rotate(image, np.random.uniform(-30, 30), reshape=False)
  return image
image, label = next(iter(images_ds))
image = random_rotate_image(image)
show(image, label)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

TensorFlow2.0 Guide 官方教程 学习笔记14-'tf.data: Build TensorFlow input pipelines'_第11张图片
为了在Dataset.map中使用这个函数,跟使用Dataset.from_generator相同的注意事项,在应用函数时需要描述返回的形状和类型:

def tf_random_rotate_image(image, label):
  im_shape = image.shape
  [image,] = tf.py_function(random_rotate_image, [image], [tf.float32])
  image.set_shape(im_shape)
  return image, label
rot_ds = images_ds.map(tf_random_rotate_image)

for image, label in rot_ds.take(2):
  show(image, label)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

TensorFlow2.0 Guide 官方教程 学习笔记14-'tf.data: Build TensorFlow input pipelines'_第12张图片
TensorFlow2.0 Guide 官方教程 学习笔记14-'tf.data: Build TensorFlow input pipelines'_第13张图片
5.3 解析tf.Example协议缓冲区消息示例(Parsing tf.Example protocol buffer messages)
许多输入流水线从TFRecord格式提取‘tf.train.Example’协议缓冲区消息。每个‘tf.train.Example’record包含一个或多个‘特征’,输入流水线通常将这些特征转换成张量。

fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset

我们可以在‘tf.data.Dataset’外部,用‘tf.train.Example’模型来工作,以理解数据:

raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())

feature = parsed.features.feature
raw_img = feature['image/encoded'].bytes_list.value[0]
img = tf.image.decode_png(raw_img)
plt.imshow(img)
plt.axis('off')
_ = plt.title(feature["image/text"].bytes_list.value[0])

TensorFlow2.0 Guide 官方教程 学习笔记14-'tf.data: Build TensorFlow input pipelines'_第14张图片

raw_example = next(iter(dataset))
def tf_parse(raw_examples):
  example = tf.io.parse_example(
      raw_example[tf.newaxis], {
          'image/encoded': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
          'image/text': tf.io.FixedLenFeature(shape=(), dtype=tf.string)
      })
  return example['image/encoded'][0], example['image/text'][0]
img, txt = tf_parse(raw_example)
print(txt.numpy())
print(repr(img.numpy()[:20]), "...")
b'Rue Perreyon'
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x02X' ...
decoded = dataset.map(tf_parse)
decoded

image_batch, text_batch = next(iter(decoded.batch(10)))
image_batch.shape
TensorShape([10])

5.4 时间序列窗口(Time series windowing)
端到端的时间序列示例请参考:Time series forecasting(官方链接丢失)
时间序列数据通常以完整的时间轴来组织。先用一个简单的Dataset.range来演示:

range_ds = tf.data.Dataset.range(100000)

通常,基于这类数据的模型需要一个连续的时间片。最简单的方法是批处理这些数据:
5.4.1 使用批处理(batch)

batches = range_ds.batch(10, drop_remainder=True)

for batch in batches.take(5):
  print(batch.numpy())
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]
[40 41 42 43 44 45 46 47 48 49]

或者为了对未来做一个密集的预测,你可以将特征和标签相对地移动一步:

def dense_1_step(batch):
  # Shift features and labels one step relative to each other.
  return batch[:-1], batch[1:]

predict_dense_1_step = batches.map(dense_1_step)

for features, label in predict_dense_1_step.take(3):
  print(features.numpy(), " => ", label.numpy())
[0 1 2 3 4 5 6 7 8]  =>  [1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18]  =>  [11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28]  =>  [21 22 23 24 25 26 27 28 29]

要预测整个窗口而不是一个固定的偏移量,你可以把批次分成两部分:

batches = range_ds.batch(15, drop_remainder=True)

def label_next_5_steps(batch):
  return (batch[:-5],   # Take the first 5 steps
          batch[-5:])   # take the remainder

predict_5_steps = batches.map(label_next_5_steps)

for features, label in predict_5_steps.take(3):
  print(features.numpy(), " => ", label.numpy())
[0 1 2 3 4 5 6 7 8 9]  =>  [10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]  =>  [25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]  =>  [40 41 42 43 44]

为了让一批数据的特性和另一批数据的标签有一些重叠,可以使用Dataset.zip:

feature_length = 10
label_length = 5

features = range_ds.batch(feature_length, drop_remainder=True)
labels = range_ds.batch(feature_length).skip(1).map(lambda labels: labels[:-5])

predict_5_steps = tf.data.Dataset.zip((features, labels))

for features, label in predict_5_steps.take(3):
  print(features.numpy(), " => ", label.numpy())
[0 1 2 3 4 5 6 7 8 9]  =>  [10 11 12 13 14]
[10 11 12 13 14 15 16 17 18 19]  =>  [20 21 22 23 24]
[20 21 22 23 24 25 26 27 28 29]  =>  [30 31 32 33 34]

5.4.2 使用‘窗口’(window)
当使用Dataset.batch时,有些情况我们也许需要更好的控制。Dataset.window方法可以提供完全的控制,但需要注意:它返回的是数据集的数据集。有关详细信息,请参见:‘Dataset structure’(官方链接丢失)

window_size = 5

windows = range_ds.window(window_size, shift=1)
for sub_ds in windows.take(5):
  print(sub_ds)
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>

Dataset.flat_map方法可以采取一个数据集的数据集(a dataset of datasets),并把它变成一个单一的数据集:

 for x in windows.flat_map(lambda x: x).take(30):
   print(x.numpy(), end=' ')
WARNING:tensorflow:Entity  at 0x7fe0dbe37950> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: Failed to parse source code of  at 0x7fe0dbe37950>, which Python reported as:
for x in windows.flat_map(lambda x: x).take(30):

If this is a lambda function, the error may be avoided by creating the lambda in a standalone statement.
WARNING: Entity  at 0x7fe0dbe37950> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: Failed to parse source code of  at 0x7fe0dbe37950>, which Python reported as:
for x in windows.flat_map(lambda x: x).take(30):

If this is a lambda function, the error may be avoided by creating the lambda in a standalone statement.
0 1 2 3 4 1 2 3 4 5 2 3 4 5 6 3 4 5 6 7 4 5 6 7 8 5 6 7 8 9 

几乎所有的用例,我们需要先批处理数据集:

def sub_to_batch(sub):
  return sub.batch(window_size, drop_remainder=True)

for example in windows.flat_map(sub_to_batch).take(5):
  print(example.numpy())
[0 1 2 3 4]
[1 2 3 4 5]
[2 3 4 5 6]
[3 4 5 6 7]
[4 5 6 7 8]

现在我们可以通过‘shift’参数看到每个窗口移动多少。将上面这些整合起来,我们可以编写出这个函数:

def make_window_dataset(ds, window_size=5, shift=1, stride=1):
  windows = ds.window(window_size, shift=shift, stride=stride)

  def sub_to_batch(sub):
    return sub.batch(window_size, drop_remainder=True)

  windows = windows.flat_map(sub_to_batch)
  return windows
ds = make_window_dataset(range_ds, window_size=10, shift = 5, stride=3)

for example in ds.take(10):
  print(example.numpy())
[ 0  3  6  9 12 15 18 21 24 27]
[ 5  8 11 14 17 20 23 26 29 32]
[10 13 16 19 22 25 28 31 34 37]
[15 18 21 24 27 30 33 36 39 42]
[20 23 26 29 32 35 38 41 44 47]
[25 28 31 34 37 40 43 46 49 52]
[30 33 36 39 42 45 48 51 54 57]
[35 38 41 44 47 50 53 56 59 62]
[40 43 46 49 52 55 58 61 64 67]
[45 48 51 54 57 60 63 66 69 72]

想以前一样,提取标签是容易的:

dense_labels_ds = ds.map(dense_1_step)

for inputs,labels in dense_labels_ds.take(3):
  print(inputs.numpy(), "=>", labels.numpy())
[ 0  3  6  9 12 15 18 21 24] => [ 3  6  9 12 15 18 21 24 27]
[ 5  8 11 14 17 20 23 26 29] => [ 8 11 14 17 20 23 26 29 32]
[10 13 16 19 22 25 28 31 34] => [13 16 19 22 25 28 31 34 37]

5.5 重采样(resampling)

在处理非常‘类不平衡’(class-imbalanced)的数据集时,我们可能需要重新取数据集。tf.data为此提供了两种方法。信用卡欺诈数据集就是这类问题的一个很好的例子。
更多信息,请参考:Imbalanced Data(官方链接丢失)

zip_path = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/download.tensorflow.org/data/creditcard.zip',
    fname='creditcard.zip',
    extract=True)

csv_path = zip_path.replace('.zip', '.csv')
creditcard_ds = tf.data.experimental.make_csv_dataset(
    csv_path, batch_size=1024, label_name="Class",
    # Set the column types: 30 floats and an int.
    column_defaults=[float()]*30+[int()])

现在,检查类的分布,它是高度倾斜的:

def count(counts, batch):
  features, labels = batch
  class_1 = labels == 1
  class_1 = tf.cast(class_1, tf.int32)

  class_0 = labels == 0
  class_0 = tf.cast(class_0, tf.int32)

  counts['class_0'] += tf.reduce_sum(class_0)
  counts['class_1'] += tf.reduce_sum(class_1)

  return counts
counts = creditcard_ds.take(10).reduce(
    initial_state={'class_0': 0, 'class_1': 0},
    reduce_func = count)

counts = np.array([counts['class_0'].numpy(),
                   counts['class_1'].numpy()]).astype(np.float32)

fractions = counts/counts.sum()
print(fractions)
[0.9948 0.0052]

用不平和数据集进行训练的通用方法是让它平衡。‘tf.data’里有几个方法可以让这个工作流工作:
5.5.1 数据集重采样
数据集重采样的一种方法是‘sample_from_datasets’。当每个类
有个单独data.Dataset时更适用。
下面用滤波器(filter)从信用卡欺诈数据中生成一个重采样数据集:

negative_ds = creditcard_ds.unbatch().filter(lambda features,label: label==0).repeat()
positive_ds = creditcard_ds.unbatch().filter(lambda features,label: label==1).repeat()
for features, label in positive_ds.batch(10).take(1):
  print(label.numpy())
[1 1 1 1 1 1 1 1 1 1]

使用‘tf.data.experimental.sample_from_datasets’传递数据集以及权重值:

balanced_ds = tf.data.experimental.sample_from_datasets([negative_ds, positive_ds], [0.5, 0.5]).batch(10)

现在数据集产生每个类的例子的概率是50/50:

for features, labels in balanced_ds.take(10):
  print(labels.numpy())
[0 1 1 1 0 0 1 0 1 0]
[0 0 0 0 0 1 1 0 0 0]
[0 1 1 0 1 0 0 1 0 0]
[0 1 0 0 0 1 1 1 0 1]
[1 1 0 1 0 1 0 1 0 1]
[0 0 0 0 0 0 0 0 0 0]
[0 0 0 1 1 0 1 0 1 1]
[0 1 1 0 1 0 1 1 1 0]
[0 0 0 1 0 1 0 0 1 1]
[1 1 1 0 1 1 0 0 0 0]

5.5.2 拒绝重采样(Rejection resampling)
上面的‘experimental.sample_from_datasets’方法的缺点是它需要一个单独的tf.data.Dataset给每个类。使用Dataset.filter可以解决,但所有的数据将被加载两次。
‘data.experimental.rejection_resample ’方法可以用来将数据集重平衡,并且只需要加载一次数据。为了达到平衡元素将从数据集中丢弃。
‘data.experimental.rejection_resample ’需要一个‘class_func‘’参数。这个参数应用于每个元素,并且用于确定一个示例属于哪个类,以达到平衡的目的。
‘creditcard_ds’里的元素已经配对了。所以‘class_func’只需要返回这些标签就行:

def class_func(features, label):
  return label

重采样器需要一个目标分布,以及一个可选的初始化分布估计:

resampler = tf.data.experimental.rejection_resample(
    class_func, target_dist=[0.5, 0.5], initial_dist=fractions)

重采样器处理单个示例,所以你必须在应用重采样器之前取消数据集:

resample_ds = creditcard_ds.unbatch().apply(resampler).batch(10)
WARNING:tensorflow:From /tensorflow-2.0.0/python3.6/tensorflow_core/python/data/experimental/ops/resampling.py:151: Print (from tensorflow.python.ops.logging_ops) is deprecated and will be removed after 2018-08-20.
Instructions for updating:
Use tf.print instead of tf.Print. Note that tf.print returns a no-output operator that directly prints the output. Outside of defuns or eager mode, this operator will not be executed unless it is directly specified in session.run or used as a control dependency for other operators. This is only a concern in graph mode. Below is an example of how to ensure tf.print executes in graph mode:

重采样器返回从‘class_func’输出中创建的(class,example)对。在这条用例中,‘example’已经是一个(feature,label)对了,所以使用‘map’丢弃多余的标签备份:

balanced_ds = resample_ds.map(lambda extra_label, features_and_label: features_and_label)

现在数据集产生每个类的例子的概率是50/50:

for features, labels in balanced_ds.take(10):
  print(labels.numpy())
[1 0 0 1 0 0 1 1 1 0]
[0 0 1 0 0 0 0 1 1 1]
[0 0 1 1 1 0 0 1 0 1]
[0 0 1 0 1 1 1 1 0 1]
[1 1 1 0 0 0 0 0 1 1]
[1 1 1 1 1 1 0 1 1 1]
[1 0 1 0 0 1 0 0 0 1]
[1 1 0 0 1 1 1 1 0 0]
[1 0 0 1 0 0 1 0 0 1]
[1 0 0 0 1 0 0 0 1 0]

六、使用高阶API
6.1 tf.keras
tf.keras API在创建和执行机器学习模型上简化了许多。它的‘.fit()’、‘.evaluate()’、‘.predict()’ API接口支持数据集作为输入。下面是一个快速的数据集和模型设置:

train, test = tf.keras.datasets.fashion_mnist.load_data()

images, labels = train
images = images/255.0
labels = labels.astype(np.int32)
fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))
fmnist_train_ds = fmnist_train_ds.shuffle(5000).batch(32)

model = tf.keras.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(), 
              metrics=['accuracy'])

传递(feature,label)数据集只需要Model.fit和Model.evaluate:

model.fit(fmnist_train_ds, epochs=2)

TensorFlow2.0 Guide 官方教程 学习笔记14-'tf.data: Build TensorFlow input pipelines'_第15张图片
如果我们传递一个无限大的数据集,比如调用‘Dataset.repeat()’,我们需要传递‘steps_per_epoch’参数:

model.fit(fmnist_train_ds.repeat(), epochs=2, steps_per_epoch=20)

TensorFlow2.0 Guide 官方教程 学习笔记14-'tf.data: Build TensorFlow input pipelines'_第16张图片
我们可以传递评价步骤数来进行评价:

loss, accuracy = model.evaluate(fmnist_train_ds)
print("Loss :", loss)
print("Accuracy :", accuracy)

在这里插入图片描述
对于长数据集,设置好步数来评价:

loss, accuracy = model.evaluate(fmnist_train_ds.repeat(), steps=10)
print("Loss :", loss)
print("Accuracy :", accuracy)

在这里插入图片描述
当调用Model.predict时不需要标签:

predict_ds = tf.data.Dataset.from_tensor_slices(images).batch(32)
result = model.predict(predict_ds, steps = 10)
print(result.shape)
(320, 10)

如果我们还是传递了标签,它将被忽略:

result = model.predict(fmnist_train_ds, steps = 10)
print(result.shape)
(320, 10)

6.2 tf.estimator
在‘tf.estimator.Estimator’中的‘input_fn’使用数据集,只需从input_fn返回数据集,框架将为我们处理它的元素。例如:

import tensorflow_datasets as tfds

def train_input_fn():
  titanic = tf.data.experimental.make_csv_dataset(
      titanic_file, batch_size=32,
      label_name="survived")
  titanic_batches = (
      titanic.cache().repeat().shuffle(500)
      .prefetch(tf.data.experimental.AUTOTUNE))
  return titanic_batches
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third']) 
age = tf.feature_column.numeric_column('age')
import tempfile
model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
    model_dir=model_dir,
    feature_columns=[embark, cls, age],
    n_classes=2
)

TensorFlow2.0 Guide 官方教程 学习笔记14-'tf.data: Build TensorFlow input pipelines'_第17张图片

model = model.train(input_fn=train_input_fn, steps=100)

TensorFlow2.0 Guide 官方教程 学习笔记14-'tf.data: Build TensorFlow input pipelines'_第18张图片

result = model.evaluate(train_input_fn, steps=10)

for key, value in result.items():
  print(key, ":", value)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2019-10-26T11:09:36Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp1rrmmxl3/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Finished evaluation at 2019-10-26-11:09:37
INFO:tensorflow:Saving dict for global step 100: accuracy = 0.728125, accuracy_baseline = 0.625, auc = 0.7681875, auc_precision_recall = 0.6506202, average_loss = 0.5628802, global_step = 100, label/mean = 0.375, loss = 0.5628802, precision = 0.72602737, prediction/mean = 0.33942837, recall = 0.44166666
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmp1rrmmxl3/model.ckpt-100
accuracy : 0.728125
accuracy_baseline : 0.625
auc : 0.7681875
auc_precision_recall : 0.6506202
average_loss : 0.5628802
label/mean : 0.375
loss : 0.5628802
precision : 0.72602737
prediction/mean : 0.33942837
recall : 0.44166666
global_step : 100
for pred in model.predict(train_input_fn):
  for key, value in pred.items():
    print(key, ":", value)
  break
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp1rrmmxl3/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
logits : [-0.4143]
logistic : [0.3979]
probabilities : [0.6021 0.3979]
class_ids : [0]
classes : [b'0']
all_class_ids : [0 1]
all_classes : [b'0' b'1']

你可能感兴趣的:(TensorFlow学习笔记)