tf.data模块

tf.data是tensorflow提供的用来构建模型输入流水线的模块,集成了map,reduce,batch,shuffle等功能,使用起来比较方便,最佳的自然去看官网链接,这里只是我的学习记录。

tf.data.Dataset.from_tensor_slices

  • 传入一维的list,输出的是scalar
>>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
>>> list(dataset.as_numpy_iterator())
[1, 2, 3]

传入二维的tensor,输出一维的tensor

>>> dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
>>> list(dataset.as_numpy_iterator())
[array([1, 2], dtype=int32), array([3, 4], dtype=int32)]

传入字典,字典的value对应的是tensor,输出的也是字典,key不变,value正常切割

>>> # Dictionary structure is also preserved.
>>> dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
>>> list(dataset.as_numpy_iterator()) == [{'a': 1, 'b': 3},
...                                       {'a': 2, 'b': 4}]
True

传入tuple构成的数据,相当于tuple内部的元素依次切割,并然后在组合起来

>>> dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6]))
>>> list(dataset.as_numpy_iterator())
[(1, 3, 5), (2, 4, 6)]

# 这种格式在tf.keras的model.fit里面很常用,尤其是对于有多个输入和输出的时候,可以用key去指定,这里前一个字典代表了feature,后一个字典代表了label
>>> dataset = tf.data.Dataset.from_tensor_slices(({"a": [1, 2]}, {"b": [3, 4]}))
>>> list(dataset.as_numpy_iterator())
[({'a': 1}, {'b': 3}), ({'a': 2}, {'b': 4})]

注意:这里不支持non-rectangular形式的输入tensor,比如这种就不行,然而使用from_generator可以接受不一样的tensor

dataset = tf.data.Dataset.from_tensor_slices([[1], [2,3]])
print(list(dataset.as_numpy_iterator()))

tf.data.Dataset.from_generator

def gen_series1(): #生成不定长度
    i = 0
    while True:
        size = np.random.randint(0, 10)
        yield np.random.normal(size=(size,))
        i += 1

ds_series = tf.data.Dataset.from_generator(gen_series1, output_types=tf.float32, output_shapes=None)

def gen_series2(): # 定长与不定长的组合tuple
    i = 0
    while True:
        size = np.random.randint(0, 10)
        yield i, np.random.normal(size=(size,))
        i += 1


ds_series = tf.data.Dataset.from_generator(gen_series2, output_types=(tf.int32, tf.float32), output_shapes=((), (None,)))


这里的output_types是用来指定类型
output_shapes是指定shape

tf.data.TFRecordDataset

还有一种是从TFRecord文件里面读取数据的接口,TFRecords是tensorflow推荐的数据存取方式,里面每一个元素都是一个tf.train.Example,一般需要先解码才可以使用。

def parse_example(example):
    feature_dict = {
        "fixlen1": tf.io.FixedLenFeature([10], tf.int32),
        "fixlen2": tf.io.FixedLenFeature([10], tf.int32),
        "varlen": tf.io.VarLenFeature(tf.int32)
    }
    feature = tf.io.parse_single_example(example, feature_dict)
    # 如果是要输入给tf.keras,假设fixlen1,fixlen2是feature,而varlen是label
    # 这里可以做一下转换,变成tuple
    return {"fixlen1": feature['fixlen1'], "fixlen2": feature["fixlen2"]}, {"varlen": feature["varlen"]}

dataset = tf.data.TFRecordDataset(["recordfile.records"])
dataset.map(parse_example)

tf.data.TextLineDataset

dataset = tf.data.TextLineDataset(file_paths)

一行行的读取与返回

tf.data.experimental.make_csv_dataset

titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived", select_columns=['class', 'fare', 'survived'])

可以指定label列名,以及选取哪几列作为特征

你可能感兴趣的:(tf.data模块)