Python学习(11):tf.train.shuffle_batch

参考:怎么理解tensorflow中tf.train.shuffle_batch()函数?

TensorFlow学习--tf.train.batch与tf.train.shuffle_batch

(1)背景:

(2)用法

tf.train.shuffle_batch() 将队列中数据打乱后再读取出来.

函数是先将队列中数据打乱,然后再从队列里读取出来,因此队列中剩下的数据也是乱序的.

tensors:排列的张量或词典.

batch_size:从队列中提取新的批量大小.

 capacity:队列中元素的最大数量.

 min_after_dequeue:出队后队列中元素的最小数量,用于确保元素的混合级别.

 num_threads:线程数量.

 seed:队列内随机乱序的种子值.

 enqueue_many:tensors中的张量是否都是一个例子.

  shapes:每个示例的形状.(可选项)

 allow_smaller_final_batch:为True时,若队列中没有足够的项目,则允许最终批次更小.(可选项)

  shared_name:如果设置,则队列将在多个会话中以给定名称共享.(可选项)

   name:操作的名称.(可选项)


(2)功能:

     Creates batches by randomly shuffling tensors,但需要注意的是它是一种图运算,要跑在sess.run()里。具体地,

This function adds the following to the current Graph:

在运行这个函数时它会在当前图上创建如下的东西:

A shuffling queue into which tensors from tensors are enqueued.

一个乱序的队列,进队的正是传入的tensors

A dequeue_many operation to create batches from the queue.

一个dequeue_many的操作从队列中推出成batch的tensor

A QueueRunner to QUEUE_RUNNER collection, to enqueue the tensors from tensors.

一个QueueRunner的线程,正是这个线程将传入的数据推进队列中.

把数据放在队列里有很多好处,可以完成训练数据和测试数据的解耦,同时有利于写成分布式训练(个人理解),但需要注意的是在取数据的时候,容易造成堵塞的情况.

这时候,应该需要截获超时异常来强制停止线程.

你可能感兴趣的:(Python学习(11):tf.train.shuffle_batch)