Notes on Dataset.shard

Dataset.shard official documentation

Distributed TensorFlow

Here is an example with 3 workers.

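# TensorFlow 1.x API (tf.Session, make_one_shot_iterator).
import tensorflow as tf

# FLAGS.num_workers and FLAGS.worker_index are assumed to be defined
# elsewhere, e.g. as command-line flags set per worker.
dataset = tf.data.Dataset.range(6)
# Keep only the elements whose position modulo num_workers equals worker_index.
dataset = dataset.shard(FLAGS.num_workers, FLAGS.worker_index)
iterator = dataset.make_one_shot_iterator()
res = iterator.get_next()

# Suppose you have 3 workers in total; each worker then receives 2 elements.
with tf.Session() as sess:
    for i in range(2):
        print(sess.run(res))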

We will have the output:

0, 3 on worker 0

1, 4 on worker 1

2, 5 on worker 2
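
The assignment above follows from how shard selects elements: a worker with index i keeps the elements at positions i, i + num_shards, i + 2 * num_shards, and so on. The sketch below reproduces that selection rule in plain Python so it can be checked without TensorFlow. The helper name `shard` and the loop over worker indices are illustrative assumptions, not part of the TensorFlow API.

```python
from itertools import islice


def shard(elements, num_shards, index):
    """Plain-Python stand-in for the selection rule of Dataset.shard.

    Keeps every num_shards-th element, starting at position index.
    This is a hypothetical helper for illustration only.
    """
    if not 0 <= index < num_shards:
        raise ValueError("index must satisfy 0 <= index < num_shards")
    return list(islice(elements, index, None, num_shards))


num_workers = 3
for worker_index in range(num_workers):
    # Each worker sees a disjoint slice of the same 6-element range.
    print(worker_index, shard(range(6), num_workers, worker_index))
```

Every element lands on exactly one worker, which is why sharding is usually applied before shuffling or repeating the dataset.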
