陈伟@航天科技智慧城市研究院 [email protected]
tf.train.batch与tf.train.shuffle_batch的作用都是从队列中读取数据,它们的区别是是否随机打乱数据来读取。
tf.train.batch(
tensors,
batch_size,
num_threads=1,
capacity=32,
enqueue_many=False,
shapes=None,
dynamic_pad=False,
allow_smaller_final_batch=False,
shared_name=None,
name=None
)
#!/usr/bin/python
# coding:utf-8
"""Demo: read samples from a queue in order with tf.train.batch (TF 1.x queue API)."""
import tensorflow as tf
import numpy as np

images = np.random.random([5, 2])  # 5x2 matrix of random floats
print(images)
label = np.asarray(range(0, 5))  # [0, 1, 2, 3, 4]
print(label)
# Convert the NumPy arrays to tensors.
images = tf.cast(images, tf.float32)
print(images)
label = tf.cast(label, tf.int32)
print(label)
# Slice the tensors into single samples; shuffle=False keeps queue order.
input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
# Read samples from the queue in order (no shuffling).
image_batch, label_batch = tf.train.batch(input_queue, batch_size=10, num_threads=1, capacity=64)
with tf.Session() as sess:
    # Coordinator for the queue-runner threads.
    coord = tf.train.Coordinator()
    # Start all queue runners collected in the graph.
    threads = tf.train.start_queue_runners(sess, coord)
    image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
    for j in range(5):
        # NOTE: original had a Python-2 style trailing comma after print(...),
        # which in Python 3 just builds a discarded tuple; removed.
        print(image_batch_v[j])
        print(label_batch_v[j])
    # Ask the threads to stop.
    coord.request_stop()
    # Wait for the threads to terminate.
    coord.join(threads)
输出
[0.15518834 0.4924818 ]
0
[0.3907916 0.05013292]
1
[0.41328526 0.802318 ]
2
[0.43541858 0.9412442 ]
3
[0.16782863 0.6347318 ]
4
参考代码移步Github
tf.train.shuffle_batch(
tensors,
batch_size,
capacity,
min_after_dequeue,
num_threads=1,
seed=None,
enqueue_many=False,
shapes=None,
allow_smaller_final_batch=False,
shared_name=None,
name=None
)
可以看出,跟tf.train.batch的参数是一样的,只是这里多了个seed和min_after_dequeue,其中seed表示随机数的种子,min_after_dequeue是出队后队列中元素的最小数量,用于确保元素的混合级别,这个参数一定要比capacity小。
#!/usr/bin/python
# coding:utf-8
"""Demo: read samples from a queue in shuffled order with tf.train.shuffle_batch (TF 1.x queue API)."""
import tensorflow as tf
import numpy as np

images = np.random.random([5, 2])
label = np.asarray(range(0, 5))
# Convert the arrays to tensors.
images = tf.cast(images, tf.float32)
label = tf.cast(label, tf.int32)
# Slice the input tensor list into single samples (queue order preserved here).
input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
# Shuffle the queued samples before reading them out.
# min_after_dequeue must be < capacity; it controls the mixing level.
image_batch, label_batch = tf.train.shuffle_batch(input_queue, batch_size=10, num_threads=1, capacity=64, min_after_dequeue=1)
with tf.Session() as sess:
    # Coordinator for the queue-runner threads.
    coord = tf.train.Coordinator()
    # Start all queue runners collected in the graph.
    threads = tf.train.start_queue_runners(sess, coord)
    image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
    for j in range(5):
        # NOTE: original had a Python-2 style trailing comma after print(...),
        # which in Python 3 just builds a discarded tuple; removed.
        print(image_batch_v[j])
        print(label_batch_v[j])
    # Ask the threads to stop.
    coord.request_stop()
    # Wait for the threads to terminate.
    coord.join(threads)
输出
[0.93350357 0.9003149 ]
0
[0.7407439 0.896775 ]
1
[0.6358515 0.69127077]
2
[0.96927387 0.7181145 ]
3
[0.93350357 0.9003149 ]
0
代码已经上传至Github