Tensorflow2.0学习笔记

tf.data API使用

  • 1 tf.data
    • 1.1 tf.data.Dataset
      • 1.1.1 Dataset的基础API
      • 1.1.2 从csv文件中创建Dataset数据集
      • 1.1.3 解析csv文件
    • 1.2 tf.data.TFRecordDataset
      • 1.2.1 创建TFRecord文件
      • 1.2.2 从TFRecord文件中创建TFRecordDataset
      • 1.2.3 解析TFRecord文件

1 tf.data

tf.data是tenosrflow2.0中用于创建数据集的模块,包含了多种类型的数据集类,其中DatasetTfRecordDataset是常用于封装数据集的类。Dataset包含了大量存在的数据。TFRecord是tf中特殊的数据格式,可提升数据读写速度。Tensorflow2.0学习笔记_第1张图片

1.1 tf.data.Dataset

1.1.1 Dataset的基础API

  1. from_tensor_slices :该方法用于创建Dataset,传入参数可以是列表、元组、字典、numpy数组,但各元素的第一维size必须相等,该操作将传入参数以切片的方式封装为一个数据集。
# 创建新的Dataset
# 1、从列表中创建Dataset
dataset = tf.data.Dataset.from_tensor_slices([1,2,3])
print("查看dataset元素值、形状和类型")
for element in dataset:
    print(element)
print("仅查看元素值")
for element in dataset.as_numpy_iterator():
    print(element)
# 2、从numpy数组中创建Dataset
arr = np.asarray(([1,2],[3,4],[5,6]))
dataset = tf.data.Dataset.from_tensor_slices(arr)
for element in dataset.as_numpy_iterator():
    print(element)
# 3、从元组中创建Dataset
dataset = tf.data.Dataset.from_tensor_slices(([[1, 2],[1,2]], [3, 4], [5, 6]))
for element in dataset.as_numpy_iterator():
    print(element)

# 元素的第一维的形状必须相同
dataset = tf.data.Dataset.from_tensor_slices(([1,2,3], [3, 4], [5, 6])) 
# ValueError: Dimensions 3 and 2 are not compatible

# 4、从字典中创建Dataset
dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
for element in dataset.as_numpy_iterator():
    print(element)
 

2.repeat/batch:用于数据集的重复和批量化,可嵌套使用repeat和batch

dataset = tf.data.Dataset.from_tensor_slices([1,2,3])
dataset1 = dataset.repeat(4)
print(list(dataset1.as_numpy_iterator()))
dataset2 = dataset1.batch(5)
print(list(dataset2.as_numpy_iterator()))
# 如果希望输出的形状相同,drop_remainder=True
dataset3 = dataset.repeat(4).batch(5,drop_remainder=True)
print(list(dataset3.as_numpy_iterator()))
  1. interleave:用于生成和处理数据集,并且可以并行处理多个数据集。该方法的具体操作是:对cycle_length个输入元素应用map_func后产生新的Dataset,每个element作为容器并迭代,生成block_length个连续元素直至用尽。

interleave(
map_func, cycle_length=None, block_length=None, num_parallel_calls=None,
deterministic=None, name=None
)

dataset = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
# NOTE: New lines indicate "block" boundaries.
dataset = dataset.interleave(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(6),
    cycle_length=3, block_length=4)
list(dataset.as_numpy_iterator())

interleave还可以用于从csv文件中创建Dataset数据集,具体操作方法是:1.先创建包含文件名的文件名数据集(见1.1.2) 2.使用interleave方法,读取文件名数据集中的文件,将多个文件合并形成一个完整的数据集。

1.1.2 从csv文件中创建Dataset数据集

1.list_files:list_files用于匹配路径中的文件并生成文件名数据集,如果已经整理出匹配后的文件列表,可以直接使用from_tensor_slices创建文件名数据集。

# 1.使用from_tensor_slices创建文件名数据集
## 需要导入csv文件的目录并读取文件名,generate_csv文件中包含了trian,test,valid三种csv文件,提取train文件并存储在列表中。
filename_list = os.listdir('D:\\Projects_File\\Jupyter projects\\tensorflow2.0_course\\chapter_4\\generate_csv')
train_filename = []
for filename in filename_list:
    if not filename.find('train'):
        train_filename.append(filename)
filename_dataset = tf.data.Dataset.from_tensor_slices(train_filename)
for filename in filename_dataset:
    print(filename)
# 2.使用list_files创建文件名数据集,使用通配符匹配train文件
filename_dataset = tf.data.Dataset.list_files(".\\tensorflow2.0_course\\chapter_4\\generate_csv\\train*.csv",shuffle=False)
for idx, filename in enumerate(filename_dataset):
    print(idx, filename)
# 使用interleave将多个文件合并为完整的数据集,n_readers代表合并的文件个数
n_readers = 5
dataset = filename_dataset.interleave(
    lambda filename: tf.data.TextLineDataset(filename).skip(1),
    cycle_length = n_readers,
    block_length = 2)
for line in dataset.take(15):
    print(line.numpy())

1.1.3 解析csv文件

csv是一种通用的、相对简单的逗号分隔值文件格式,是一种用来存储数据的纯文本文件;纯文本意味着CSV文件是一个字符序列,因此需将其解析为数值型数据才能用于深度学习。

  1. tf.io.decode_csv:将CSV文件转换为张量。每一列映射到一个张量。record_defaults是需要被解析的数据类型。列表长度需要和csv文件中的columns数量对应,否则报错。
def parse_csv_line(line, n_fields = 9):
    defs = [tf.constant(np.nan)] * n_fields
    parsed_fields = tf.io.decode_csv(line, record_defaults=defs)
    x = tf.stack(parsed_fields[0:-1])
    y = tf.stack(parsed_fields[-1:])
    return x, y

parse_csv_line(b'-0.9868720801669367,0.832863080552588,-0.18684708416901633,-0.14888949288707784,-0.4532302419670616,-0.11504995754593579,1.6730974284189664,-0.7465496877362412,1.138',
               n_fields=9)

1.2 tf.data.TFRecordDataset

1.2.1 创建TFRecord文件

TFRecord 和 tf.Example的介绍: 为了高效地读取数据,比较有帮助的一种做法是对数据进行序列化并将其存储在一组可线性读取的文件(每个文件 100-200MB)中。这尤其适用于通过网络进行流式传输的数据。这种做法对缓冲任何数据预处理也十分有用。 tensorflow官方帮助文档链接

tf.Example的结构如下:
->tf.train.Example: tfrecord文件存储的内容为Example
---->tf.train.Features: Example包含多个Features,Features的格式为dict {“key”: tf.train.Feature}
-------->tf.train.Feature: Feature的值具有特定的格式——tf.train.ByteList/FloatList/Int64List

favorite_books = [name.encode('utf-8')
                  for name in ["machine learning", "cc150"]]
favorite_books_bytelist = tf.train.BytesList(value = favorite_books)
print(favorite_books_bytelist)

hours_floatlist = tf.train.FloatList(value = [15.5, 9.5, 7.0, 8.0])
print(hours_floatlist)

age_int64list = tf.train.Int64List(value = [42])
print(age_int64list)

# 构建Features,格式为{"key": value}
features = tf.train.Features(
    feature = {
        "favorite_books": tf.train.Feature(
            bytes_list = favorite_books_bytelist),
        "hours": tf.train.Feature(
            float_list = hours_floatlist),
        "age": tf.train.Feature(int64_list = age_int64list),
    }
)
print(features)

# 序列化exmaple后写入tfrecord文件
example = tf.train.Example(features=features)
print(example)

serialized_example = example.SerializeToString()
print(serialized_example)

# 保存tfrecord文件
output_dir = 'tfrecord_basic'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
filename = "test.tfrecords"
filename_fullpath = os.path.join(output_dir, filename)
with tf.io.TFRecordWriter(filename_fullpath) as writer:
    for i in range(3):
        writer.write(serialized_example)

1.2.2 从TFRecord文件中创建TFRecordDataset

dataset = tf.data.TFRecordDataset([filename_fullpath])
for serialized_example_tensor in dataset:
    print(serialized_example_tensor)

1.2.3 解析TFRecord文件

使用方法:tf.io.parse_example,解析TFRecord文件和解析csv文件类似,都要先设置需要被解析的数据类型。

expected_features = {
    "favorite_books": tf.io.VarLenFeature(dtype = tf.string),
    "hours": tf.io.VarLenFeature(dtype = tf.float32),
    "age": tf.io.FixedLenFeature([], dtype = tf.int64),
}
dataset = tf.data.TFRecordDataset([filename_fullpath])
for serialized_example_tensor in dataset:
    example = tf.io.parse_single_example(
        serialized_example_tensor,
        expected_features)
    books = tf.sparse.to_dense(example["favorite_books"],
                               default_value=b"")
    for book in books:
        print(book.numpy().decode("UTF-8"))

你可能感兴趣的:(Tensorflow,学习,python,tensorflow)