使用自己的数据训练网络时,需要用到两个函数:tf.train.slice_input_producer、tf.train.batch和两个类tf.train.Coordinator和tf.QueueRunner。
参考博客:tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数和tensorflow中协调器 tf.train.Coordinator 和入队线程启动器 tf.train.start_queue_runners
具体流程如下:
1、调用 tf.train.slice_input_producer,从本地文件里抽取tensor,准备放入Filename Queue(文件名队列)
2、调用 tf.train.batch,从文件名队列中提取tensor,使用单个或多个线程,准备放入文件队列
3、调用 tf.train.Coordinator() 来创建一个线程协调器,用来管理之后在Session中启动的所有线程
4、调用tf.train.start_queue_runners, 启动入队线程,把文件读入Filename Queue中,一般情况下,系统有多少个核,就会启动多少个入队线程(入队具体使用多少个线程在tf.train.batch中定义)
5、文件从 Filename Queue中读入内存队列的操作不用手动执行,由tf自动完成
6、调用sess.run 执行计算
7、使用 coord.should_stop()查询是否应该终止所有线程,当文件队列中的所有文件都已经读取出列的时候,会抛出一个 OutofRangeError 的异常,这时候就应该停止Sesson中的所有线程了
8、使用coord.request_stop()来发出终止所有线程的命令,使用coord.join(threads)把线程加入主线程,等待threads结束
def slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None)
tf.train.slice_input_producer是一个tensor生成器,作用是按照设定,每次从一个tensor列表中按顺序或者随机抽取出一个tensor放入文件名队列
需要指出的是,该函数需要与入队线程启动器tf.train.start_queue_runners配合使用,只有通过start_queue_runners启动线程队列才能输出数据,否则一直处于阻塞状态,程序无法继续运行(由于本人对于其内部机制的理解并不透彻,目前只能解释到这里…但二者必须配合使用,否则队列中的数据无法输出,程序会出现阻塞不能正常运行!)
def batch(tensors, batch_size, num_threads=1, capacity=32, enqueue_many=False, shapes=None, dynamic_pad=False,
allow_smaller_final_batch=False, shared_name=None, name=None)
tf.train.batch是一个tensor队列生成器,作用是按照给定的tensor顺序,把batch_size个tensor推送到文件队列,作为训练一个batch的数据,等待tensor出队执行计算
TensorFlow的Session对象是支持多线程的,可以在同一个会话(Session)中创建多个线程,并行执行。在Session中的所有线程都必须能被同步终止,异常必须能被正确捕获并报告,会话终止的时候, 队列必须能被正确地关闭。
TensorFlow提供了两个类来实现对Session中多线程的管理:tf.Coordinator和 tf.QueueRunner,这两个类往往一起使用。
Coordinator类用来管理在Session中的多个线程,可以用来同时停止多个工作线程并且向那个在等待所有工作线程终止的程序报告异常,该线程捕获到这个异常之后就会终止所有线程。使用 tf.train.Coordinator()来创建一个线程管理器(协调器)对象。
QueueRunner类用来启动tensor的入队线程,可以用来启动多个工作线程同时将多个tensor(训练数据)推送入文件名称队列中,具体执行函数是 tf.train.start_queue_runners , 只有调用 tf.train.start_queue_runners 之后,才会真正把tensor推入内存序列中,供计算单元调用,否则会由于内存序列为空,数据流图会处于一直等待状态
下面直接copy代码:
1、生成数据batch
def get_batch(data_root, img_size, batch_size, n_classes):
# 1 获取图像路径列表和标签列表
img_list, lab_list = [], []
for dir in os.listdir(data_root):
file_dir = os.path.join(data_root, dir)
for file in os.listdir(file_dir):
img_list.append(os.path.join(file_dir, file))
lab_list.append(int(dir[0: 3]) - 1)
# 2 使用tf.cast转换为tensor
img_tensor, lab_tensor = tf.cast(img_list, tf.string), tf.cast(lab_list, tf.int32)
# lab_tensor = tf.one_hot(lab_tensor, depth=n_classes)
# 3 [image_tensor lab_tensor]组成tensor_list,并调用tf.train.slice_input_producer将tensor_list送入队列
input_queue = tf.train.slice_input_producer([img_tensor, lab_tensor], shuffle=False)
# 4 利用tf.read_file和tf.image读取图像数据
x = input_queue[0]
x = tf.read_file(x)
x = tf.image.decode_jpeg(x, channels=3)
x = tf.image.resize_images(x, size=(img_size, img_size))
y = input_queue[1]
# 5 调用tf.train.batch生成batch
x_batch, y_batch = tf.train.batch([x, y], batch_size=batch_size)
return x_batch, y_batch
2、定义网络结构、损失函数、优化器
3、训练
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# get_batch函数的调用一定要在tf.train.start_queue_runners队列开启之前,否则会出现程序阻塞
x_batch, y_batch = get_batch(train_root, img_size=IMG_SIZE, batch_size=BATCH_SIZE, n_classes=N_CLASSES)
# 开启协调器,用于管理多线程,在TF中tf.Coordinator和 tf.QueueRunner通常一起使用
coord = tf.train.Coordinator()
# 一定要调用入队线程启动器tf.train.start_queue_runners启动队列填充!!!
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
for step in range(MAX_ITER):
if coord.should_stop():
break
# 获取batch_size个训练数据和标签
data, label = sess.run([x_batch, y_batch])
# 进行训练,train_op需自行定义
sess.run(train_op, feed_dict={xs: data, ys: label, keep_rate: 0.5})
t_loss, t_accu = sess.run([loss, accu], feed_dict={xs: data, ys: label, keep_rate: 1})
print('step:%d loss:%f accu:%f'%(step, t_loss, t_accu))
if step % 10 == 0:
print('val acc:%f'%get_v_acc(val_root))
except tf.errors.OutOfRangeError:
print('Done Training')
finally:
# 协调器coord发出所有线程终止信号
coord.request_stop()
# 把开启的线程加入主线程,等待threads结束
coord.join(threads)
# 关闭session
sess.close()
可以通过下面的代码进一步理解这一过程:
data, label = sess.run([x_batch, y_batch])
for i in range(BATCH_SIZE):
print(data[i], label[i])
input('pause')
可以看出,经过上述几个函数,TF每次可以自动从队列中取出batch_size个数据供调用者使用,从而可以达到批量训练自己数据的目的