

What is the proper use of TensorFlow dataset prefetch and cache options?


"When the GPU is working on forward/backward propagation on the current batch, we want the CPU to process the next batch of data so that it is immediately ready. Since the GPU is the most expensive part of the computer, we want it to be fully used all the time during training. We call this consumer/producer overlap, where the consumer is the GPU and the producer is the CPU.

With tf.data, you can do this with a simple call to dataset.prefetch(1) at the end of the pipeline (after batching). This will always prefetch one batch of data and make sure that there is always one ready.

In some cases, it can be useful to prefetch more than one batch. For instance, if the duration of the preprocessing varies a lot, prefetching 10 batches would average out the processing time over 10 batches, instead of sometimes waiting for longer batches.

To give a concrete example, suppose that 10% of the batches take 10s to compute, and 90% take 1s. If the GPU takes 2s to train on one batch, prefetching multiple batches ensures that the GPU never waits for these rare, longer batches."
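To make the producer/consumer overlap described in the quote concrete without TensorFlow, here is a rough pure-Python analogue of what prefetching does: a background thread (the CPU "producer") fills a bounded queue of size buffer_size while the caller (the GPU "consumer") drains it. This is an illustrative sketch under that assumption, not how tf.data implements prefetch; the names prefetch and slow_batches are hypothetical. In an actual tf.data pipeline you would simply call dataset.prefetch(1) after dataset.batch(...).

```python
import queue
import threading
import time


def prefetch(iterable, buffer_size=1):
    """Hypothetical pure-Python analogue of dataset.prefetch(buffer_size):
    a background thread produces items ahead of the consumer, keeping at
    most `buffer_size` finished items waiting in a bounded queue."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be >= 1")
    buf = queue.Queue(maxsize=buffer_size)

    def producer():
        try:
            for item in iterable:
                buf.put(("item", item))  # blocks while the buffer is full
        except Exception as exc:  # hand producer errors over to the consumer
            buf.put(("error", exc))
            return
        buf.put(("done", None))

    # The thread starts on the first next() call, since this is a generator.
    threading.Thread(target=producer, daemon=True).start()
    while True:
        kind, payload = buf.get()  # blocks until the producer has something ready
        if kind == "done":
            return
        if kind == "error":
            raise payload
        yield payload


def slow_batches(n, delay=0.01):
    # Hypothetical stand-in for CPU-side preprocessing of n batches.
    for i in range(n):
        time.sleep(delay)
        yield i


batches = list(prefetch(slow_batches(5), buffer_size=1))
print(batches)  # [0, 1, 2, 3, 4]
```

Items still arrive in order; the only change is that the next item is being prepared while the current one is in use. A larger buffer_size lets the producer run further ahead.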
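The arithmetic in the 10%/90% example can be checked with a small deterministic simulation. It assumes a single CPU producer, a fixed 2 s GPU step, a slow batch at every tenth position (instead of at random), and that a prefetch buffer of size k lets the CPU start batch i only once the GPU has taken batch i - k. These modelling choices and the function name gpu_idle_time are hypothetical, not taken from tf.data.

```python
def gpu_idle_time(prep_times, train_time, buffer_size):
    """Seconds the GPU (consumer) waits for data, not counting the first batch.

    Model (an assumption, not tf.data internals): one CPU producer prepares
    batches in order; with a prefetch buffer of `buffer_size`, it may start
    batch i only after the GPU has taken batch i - buffer_size.
    """
    if buffer_size < 1:
        raise ValueError("buffer_size must be >= 1")
    finished = []  # time the CPU finishes preparing batch i
    started = []   # time the GPU starts training on batch i
    idle = 0.0
    for i, prep in enumerate(prep_times):
        cpu_free = finished[-1] if finished else 0.0
        slot_free = started[i - buffer_size] if i >= buffer_size else 0.0
        done = max(cpu_free, slot_free) + prep
        finished.append(done)
        if started:
            gpu_ready = started[-1] + train_time  # GPU finishes the previous batch
            start = max(done, gpu_ready)
            idle += start - gpu_ready
        else:
            start = done  # first batch: unavoidable start-up wait, not counted
        started.append(start)
    return idle


# Deterministic stand-in for "10% of batches take 10 s, 90% take 1 s".
prep_times = [10.0 if i % 10 == 9 else 1.0 for i in range(100)]

for k in (1, 10):
    print(f"buffer_size={k}: GPU idle {gpu_idle_time(prep_times, 2.0, k)} s")
# buffer_size=1: GPU idle 80.0 s
# buffer_size=10: GPU idle 0.0 s
```

Under this model, a buffer of 1 leaves the GPU idle for 8 s on every slow batch (10 s of preprocessing minus the 2 s step it overlaps with), 80 s over 100 batches, while a buffer of 10 removes the waiting because the average preparation time (0.9 × 1 s + 0.1 × 10 s = 1.9 s) is below the 2 s training step. If the average preparation time exceeded the training step, no buffer size would prevent waiting.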

I’m not quite sure how to determine the processing time of each batch, but that’s the next step. If your batches take roughly the same amount of time to process, then I believe prefetch(buffer_size=1) should suffice, since your GPU wouldn’t be waiting for the CPU to finish processing an unusually expensive batch.
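One possible way to get per-batch processing times is to time each next() call on the pipeline before prefetch is added (with prefetch in place, you would mostly be measuring how long you wait on the buffer). The helper below is a hypothetical sketch using time.perf_counter; it works on any Python iterable, and iterating a tf.data.Dataset in eager mode should be one such case. A large gap between the median and the maximum suggests uneven batches, where a larger prefetch buffer may help.

```python
import statistics
import time


def time_batches(iterable, num_batches):
    """Wall-clock seconds spent producing each of the first `num_batches` items."""
    durations = []
    iterator = iter(iterable)
    for _ in range(num_batches):
        start = time.perf_counter()
        try:
            next(iterator)  # for an un-prefetched pipeline, the batch is built here
        except StopIteration:
            break
        durations.append(time.perf_counter() - start)
    return durations


def fake_pipeline():
    # Hypothetical stand-in for an input pipeline; batch 2 is deliberately slow.
    for i in range(4):
        if i == 2:
            time.sleep(0.05)
        yield i


durations = time_batches(fake_pipeline(), num_batches=4)
print(f"median {statistics.median(durations):.3f} s, max {max(durations):.3f} s")
```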





