dataset batch使用小坑

背景

使用dataset进行数据管道化处理时,通常我们会加上batch(batch_size)来获取批量样本。这里有个容易忽视的点,batch本身还提供了一个参数drop_remaindar,用于标示是否对于最后一个batch如果数据量达不到batch_size时保留还是抛弃。本次的小坑就是由于这个参数导致的。

案例

show me the code:

  with tf.name_scope('input'):
    dataset = tf.data.Dataset.from_tensor_slices(files).interleave(lambda x: tf.data.TFRecordDataset(x).prefetch(10), cycle_length=num_preprocess_threads)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
    dataset = dataset.map(lambda x: _decode(x, type), num_parallel_calls=2)
    dataset = dataset.shuffle(buffer_size=batch_size * 10)
    dataset = dataset.prefetch(buffer_size=1000)

这是一段简单的使用dataset来解析tfrecord的代码,为了方便在创建dataset时,就将所有的数据集的batch_size设为了相同的值。那就导致在数据消费的时候,最后一个batch的数量达不到batch_size,所以这里我们将drop_remainder设为true,运行出错。
经过排查后发现,tf1.10之后的版本才支持这种方式,而之前的版本只能使用tf.contrib.data.batch_and_drop_remainder(batch_size)。

备注:小坑记录下,做留念

你可能感兴趣的:(dataset batch使用小坑)