一起来用tf.data API!(6)——批处理数据集元素

一起来用tf.data API!(6)——批处理数据集元素

  • (一)前 言
  • (二)简单的批处理
    • (1)创建Dataset
    • (2)实现批处理
  • (三)使用填充批处理张量
  • (四)总 结

(一)前 言

在上一节中我们介绍了如何使用tf.data API读取TFRecords文件,在这一节中我们将介绍如何对数据集元素进行批处理。

(二)简单的批处理

(1)创建Dataset

以range函数为例:

import tensorflow as tf

dataset = tf.data.Dataset.range(100)
# 创建一个单次迭代器
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(2):
        value = sess.run(next_element)
        print(value)
# 输出:
0
1

(2)实现批处理

import tensorflow as tf

dataset = tf.data.Dataset.range(100)
# 设置每批次数据个数为4
batched_dataset = dataset.batch(4)

iterator = batched_dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(2):
        value = sess.run(next_element)
        print(value)
# 输出:
[0 1 2 3]
[4 5 6 7]

(三)使用填充批处理张量

上述方法适用于具有相同大小的张量。不过,很多模型(例如序列模型)处理的输入数据可能具有不同的大小(例如序列的长度不同)。为了解决这种情况,可以通过 Dataset.padded_batch() 转换来指定一个或多个会被填充的维度,从而批处理不同形状的张量。

import tensorflow as tf

dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int64)], x))
# 将数据长度变为10,并相应进行补0
dataset = dataset.padded_batch(4, padded_shapes=([10]))

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(2):
        value = sess.run(next_element)
        print(value)
# 输出:
[[0 0 0 0 0 0 0 0 0 0]
 [1 0 0 0 0 0 0 0 0 0]
 [2 2 0 0 0 0 0 0 0 0]
 [3 3 3 0 0 0 0 0 0 0]]
[[4 4 4 4 0 0 0 0 0 0]
 [5 5 5 5 5 0 0 0 0 0]
 [6 6 6 6 6 6 0 0 0 0]
 [7 7 7 7 7 7 7 0 0 0]]

(四)总 结

在这一节中,我们介绍了使用tf.data API对数据进行批处理的操作方法,在下一节中,我们将对tf.data API对训练工作的流程的操作进行介绍,有任何的问题可在评论区留言,我会尽快回复,谢谢支持!

你可能感兴趣的:(一起来用tf.data,API!)