f.train.batch与tf.train.shuffle_batch的作用都是从队列中读取数据.
tf.train.batch() 按顺序读取队列中的数据
队列中的数据始终是一个有序的队列.队头一直按顺序补充,队尾一直按顺序出队.
参数:
若enqueue_many为False,则认为tensors代表一个示例.输入张量形状为[x, y, z]时,则输出张量形状为[batch_size, x, y, z].
若enqueue_many为True,则认为tensors代表一批示例,其中第一个维度为示例的索引,并且所有成员tensors在第一维中应具有相同大小.若输入张量形状为[*, x, y, z],则输出张量的形状为[batch_size, x, y, z].
tf.train.batch()示例
#!/usr/bin/python
# coding:utf-8
import tensorflow as tf
import numpy as np
images = np.random.random([5,2])
label = np.asarray(range(0, 5))
images = tf.cast(images, tf.float32)
label = tf.cast(label, tf.int32)
# 切片
input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
# 按顺序读取队列中的数据
image_batch, label_batch = tf.train.batch(input_queue, batch_size=10, num_threads=1, capacity=64)
with tf.Session() as sess:
# 线程的协调器
coord = tf.train.Coordinator()
# 开始在图表中收集队列运行器
threads = tf.train.start_queue_runners(sess, coord)
image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
for j in range(5):
print(image_batch_v[j]),
print(label_batch_v[j])
# 请求线程结束
coord.request_stop()
# 等待线程终止
coord.join(threads)
按顺序读取队列中的数据,输出:
[ 0.05013787 0.53446019] 0
[ 0.91189879 0.69153142] 1
[ 0.39966023 0.86109054] 2
[ 0.85078746 0.05766034] 3
[ 0.71261722 0.60514599] 4
tf.train.shuffle_batch() 将队列中数据打乱后再读取出来.
函数是先将队列中数据打乱,然后再从队列里读取出来,因此队列中剩下的数据也是乱序的.
其他与tf.train.batch()类似.
tf.train.shuffle_batch示例
#!/usr/bin/python
# coding:utf-8
import tensorflow as tf
import numpy as np
images = np.random.random([5,2])
label = np.asarray(range(0, 5))
images = tf.cast(images, tf.float32)
label = tf.cast(label, tf.int32)
input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
# 将队列中数据打乱后再读取出来
image_batch, label_batch = tf.train.shuffle_batch(input_queue, batch_size=10, num_threads=1, capacity=64, min_after_dequeue=1)
with tf.Session() as sess:
# 线程的协调器
coord = tf.train.Coordinator()
# 开始在图表中收集队列运行器
threads = tf.train.start_queue_runners(sess, coord)
image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
for j in range(5):
# print(image_batch_v.shape, label_batch_v[j])
print(image_batch_v[j]),
print(label_batch_v[j])
# 请求线程结束
coord.request_stop()
# 等待线程终止
coord.join(threads)
将队列中数据打乱后再读取出来,输出:
[ 0.08383977 0.75228119] 1
[ 0.03610427 0.53876138] 0
[ 0.33962703 0.47629601] 3
[ 0.21824744 0.84182823] 4
[ 0.8376292 0.52254623] 2