tf.train.batch与tf.train.shuffle_batch的作用都是从队列中读取数据,它们的区别是是否随机打乱数据来读取。
tf.train.batch 是按顺序读取队列中的数据
tf.train.batch(
tensors,
batch_size,
num_threads=1,
capacity=32,
enqueue_many=False,
shapes=None,
dynamic_pad=False,
allow_smaller_final_batch=False,
shared_name=None,
name=None
)
tensors:一个列表或字典的tensor用来进行入队
batch_size:每次从队列中获取出队数据的数量
num_threads:用来控制入队tensors线程的数量,如果num_threads大于1,则batch操作将是非确定性的,输出的batch可能会乱序
capacity:一个整数,用来设置队列中元素的最大数量
enqueue_many:在tensors中的张量是否是单个样本,若为False,则认为tensors代表一个样本.输入张量形状为[x, y, z]时,则输出张量形状为[batch_size, x, y, z],若为True,则认为tensors代表一批样本,其中第一个维度为样本的索引,并且所有成员tensors在第一维中应具有相同大小.若输入张量形状为[*, x, y, z],则输出张量的形状为[batch_size, x, y, z]
shapes:每个样本的shape,默认是tensors的shape
dynamic_pad:为True时允许输入变量的shape,出队后会自动填补维度,来保持与batch内的shapes相同
allow_smaller_final_batch:为True队列中的样本数量小于batch_size时,出队的数量会以最终遗留下来的样本进行出队,如果为Flalse,小于batch_size的样本不会做出队处理
shared_name:如果设置,则队列将在多个会话中以给定名称共享
name:操作的名字
tf.train.batch示例:
#!/usr/bin/python
# coding:utf-8
import tensorflow as tf
import numpy as np
images = np.random.random([5,2])
label = np.asarray(range(0, 5))
images = tf.cast(images, tf.float32)
label = tf.cast(label, tf.int32)
# 切片
input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
# 按顺序读取队列中的数据
image_batch, label_batch = tf.train.batch(input_queue, batch_size=10, num_threads=1, capacity=64)
with tf.Session() as sess:
# 线程的协调器
coord = tf.train.Coordinator()
# 开始在图表中收集队列运行器
threads = tf.train.start_queue_runners(sess, coord)
image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
for j in range(5):
print(image_batch_v[j]),
print(label_batch_v[j])
# 请求线程结束
coord.request_stop()
# 等待线程终止
coord.join(threads)
按顺序读取队列中的数据,输出如下:
[ 0.12363787 0.53772059] 0
[ 0.92259879 0.59163142] 1
[ 0.43266023 0.86109054] 2
[ 0.56078746 0.06636034] 3
[ 0.76695322 0.60522699] 4
tf.train.shuffle_batch是将队列中的数据随机打乱后再读取出来
tf.train.shuffle_batch(
tensors,
batch_size,
capacity,
min_after_dequeue,
num_threads=1,
seed=None,
enqueue_many=False,
shapes=None,
allow_smaller_final_batch=False,
shared_name=None,
name=None
)
可以看出,跟tf.train.batch的参数是一样的,只是这里多了个seed和min_after_dequeue,其中seed表示随机数的种子,min_after_dequeue是出队后队列中元素的最小数量,用于确保元素的混合级别,这个参数一定要比capacity小。
tf.train.shuffle_batch示例:
#!/usr/bin/python
# coding:utf-8
import tensorflow as tf
import numpy as np
images = np.random.random([5,2])
label = np.asarray(range(0, 5))
images = tf.cast(images, tf.float32)
label = tf.cast(label, tf.int32)
input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
# 将队列中数据打乱后再读取出来
image_batch, label_batch = tf.train.shuffle_batch(input_queue, batch_size=10, num_threads=1, capacity=64, min_after_dequeue=1)
with tf.Session() as sess:
# 线程的协调器
coord = tf.train.Coordinator()
# 开始在图表中收集队列运行器
threads = tf.train.start_queue_runners(sess, coord)
image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
for j in range(5):
# print(image_batch_v.shape, label_batch_v[j])
print(image_batch_v[j]),
print(label_batch_v[j])
# 请求线程结束
coord.request_stop()
# 等待线程终止
coord.join(threads)
将队列中数据随机打乱后再读取出来,输出如下:
[ 0.66230287 0.54226019] 0
[ 0.92299829 0.39165142] 1
[ 0.32025623 0.86109054] 2
[ 0.95208746 0.09522334] 3
[ 0.32601722 0.65002599] 4
参考:https://blog.csdn.net/akadiao/article/details/79645221