tf.data是tensorflow提供的用来构建模型输入流水线的模块,集成了map,reduce,batch,shuffle等功能,使用起来比较方便,最佳的自然去看官网链接,这里只是我的学习记录。
tf.data.Dataset.from_tensor_slices
- 传入一维的list,输出的是scalar
>>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
>>> list(dataset.as_numpy_iterator())
[1, 2, 3]
传入二维的tensor,输出一维的tensor
>>> dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
>>> list(dataset.as_numpy_iterator())
[array([1, 2], dtype=int32), array([3, 4], dtype=int32)]
传入字典,字典的value对应的是tensor,输出的也是字典,key不变,value正常切割
>>> # Dictionary structure is also preserved.
>>> dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
>>> list(dataset.as_numpy_iterator()) == [{'a': 1, 'b': 3},
... {'a': 2, 'b': 4}]
True
传入tuple构成的数据,相当于tuple内部的元素依次切割,并然后在组合起来
>>> dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6]))
>>> list(dataset.as_numpy_iterator())
[(1, 3, 5), (2, 4, 6)]
# 这种格式在tf.keras的model.fit里面很常用,尤其是对于有多个输入和输出的时候,可以用key去指定,这里前一个字典代表了feature,后一个字典代表了label
>>> dataset = tf.data.Dataset.from_tensor_slices(({"a": [1, 2]}, {"b": [3, 4]}))
>>> list(dataset.as_numpy_iterator())
[({'a': 1}, {'b': 3}), ({'a': 2}, {'b': 4})]
注意:这里不支持non-rectangular形式的输入tensor,比如这种就不行,然而使用from_generator可以接受不一样的tensor
dataset = tf.data.Dataset.from_tensor_slices([[1], [2,3]])
print(list(dataset.as_numpy_iterator()))
tf.data.Dataset.from_generator
def gen_series1(): #生成不定长度
i = 0
while True:
size = np.random.randint(0, 10)
yield np.random.normal(size=(size,))
i += 1
ds_series = tf.data.Dataset.from_generator(gen_series1, output_types=tf.float32, output_shapes=None)
def gen_series2(): # 定长与不定长的组合tuple
i = 0
while True:
size = np.random.randint(0, 10)
yield i, np.random.normal(size=(size,))
i += 1
ds_series = tf.data.Dataset.from_generator(gen_series2, output_types=(tf.int32, tf.float32), output_shapes=((), (None,)))
这里的output_types是用来指定类型
output_shapes是指定shape
tf.data.TFRecordDataset
还有一种是从TFRecord文件里面读取数据的接口,TFRecords是tensorflow推荐的数据存取方式,里面每一个元素都是一个tf.train.Example,一般需要先解码才可以使用。
def parse_example(example):
feature_dict = {
"fixlen1": tf.io.FixedLenFeature([10], tf.int32),
"fixlen2": tf.io.FixedLenFeature([10], tf.int32),
"varlen": tf.io.VarLenFeature(tf.int32)
}
feature = tf.io.parse_single_example(example, feature_dict)
# 如果是要输入给tf.keras,假设fixlen1,fixlen2是feature,而varlen是label
# 这里可以做一下转换,变成tuple
return {"fixlen1": feature['fixlen1'], "fixlen2": feature["fixlen2"]}, {"varlen": feature["varlen"]}
dataset = tf.data.TFRecordDataset(["recordfile.records"])
dataset.map(parse_example)
tf.data.TextLineDataset
dataset = tf.data.TextLineDataset(file_paths)
一行行的读取与返回
tf.data.experimental.make_csv_dataset
titanic_batches = tf.data.experimental.make_csv_dataset(
titanic_file, batch_size=4,
label_name="survived", select_columns=['class', 'fare', 'survived'])
可以指定label列名,以及选取哪几列作为特征