First, the key code:
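# PTB-style variant: no num_epochs, so the indices cycle forever
# i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
# Slicing a list (1-D): take BATCH_SIZE elements starting at i * BATCH_SIZE
inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])
# Slicing a matrix (2-D): all rows, columns i * num_steps up to (i + 1) * num_steps
x = tf.strided_slice(data, [0, i * num_steps],
                     [batch_size, (i + 1) * num_steps])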
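As a rough mental model in plain Python (no TensorFlow), the two slicing calls above select batch i as shown below. The helpers slice_1d and strided_slice_2d are hypothetical names. They cover only the begin/size form of tf.slice and the begin/end form of tf.strided_slice with unit strides, and the sizes are made up for illustration.

```python
def slice_1d(array, begin, size):
    """Plain-Python analogue of tf.slice(array, [begin], [size]) for a 1-D list."""
    return array[begin:begin + size]


def strided_slice_2d(matrix, begin, end):
    """Plain-Python analogue of tf.strided_slice(matrix, begin, end) with unit strides."""
    (row_begin, col_begin), (row_end, col_end) = begin, end
    return [row[col_begin:col_end] for row in matrix[row_begin:row_end]]


BATCH_SIZE = 4
array = list(range(1, 13))            # 12 items -> 3 batches of 4
i = 1                                 # index taken from the queue
inputs = slice_1d(array, i * BATCH_SIZE, BATCH_SIZE)
print(inputs)                         # [5, 6, 7, 8]

batch_size, num_steps = 2, 3
data = [list(range(0, 10)), list(range(10, 20))]   # 2 rows x 10 columns
x = strided_slice_2d(data, [0, i * num_steps], [batch_size, (i + 1) * num_steps])
print(x)                              # [[3, 4, 5], [13, 14, 15]]
```

tf.slice takes a size, while tf.strided_slice takes an end position. That is why the 2-D call spells out (i + 1) * num_steps.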
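The range_input_producer line creates a queue holding the integers 0 to NUM_EXPOCHES-1. Despite its name, NUM_EXPOCHES is the number of batches per epoch, not a number of epochs.

If num_epochs is given, each integer is produced only num_epochs times; otherwise the sequence repeats indefinitely. shuffle controls whether the order is randomized. Here shuffle=False, so the integers come out in order from 0 to NUM_EXPOCHES-1.

When the graph runs, each dequeue takes one integer i from the queue. The tf.slice line then cuts a batch out of array, starting at i * BATCH_SIZE.

For example, with NUM_EXPOCHES=3 and num_epochs=2, the queue contents are: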
0,1,2,0,1,2
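The queue holds only 6 elements, so training gets only 6 batches and ends after 6 iterations. The next dequeue raises tf.errors.OutOfRangeError.
If num_epochs is not specified, the queue contents look like this: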
0,1,2,0,1,2,0,1,2,0,1,2...
The queue keeps producing elements, so training can draw an unlimited number of batches, and you have to decide yourself when to stop.
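The index order described above can be sketched with a Python generator. range_queue is a hypothetical name, and this models only the order in which indices come out. It does not model TensorFlow's threads or queue capacity.

```python
import itertools
import random


def range_queue(limit, num_epochs=None, shuffle=False, seed=None):
    """Yield indices in the order a range producer would hand them out.

    Plain-Python sketch: each epoch is one pass over 0..limit-1; with
    num_epochs=None the passes never end.
    """
    rng = random.Random(seed)
    epochs = itertools.count() if num_epochs is None else range(num_epochs)
    for _ in epochs:
        order = list(range(limit))
        if shuffle:
            rng.shuffle(order)   # each pass is permuted independently
        yield from order


finite = list(range_queue(3, num_epochs=2))
print(finite)                    # [0, 1, 2, 0, 1, 2]

endless = range_queue(3)         # no num_epochs: never runs out
first_ten = list(itertools.islice(endless, 10))
print(first_ten)                 # [0, 1, 2, 0, 1, 2, 0, 1, 2, 0]
```

Exhausting the finite generator plays the role of the tf.errors.OutOfRangeError that a closed, empty queue raises.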
Data file test.txt contains the numbers 1 to 35, one per line (written on a single line here):
1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35
Contents of main.py:
import codecs

import tensorflow as tf  # TensorFlow 1.x API (in 2.x these calls live under tf.compat.v1)

BATCH_SIZE = 6      # elements per batch
NUM_EXPOCHES = 5    # despite the name: number of batches per epoch

def input_producer():
    with codecs.open("test.txt", encoding="utf-8") as f:
        # one number per line; a list (not a lazy map object, as in Python 3) is required by tf.slice
        array = [line.strip() for line in f]
    # queue of batch indices 0..NUM_EXPOCHES-1, emitted once (num_epochs=1), in order
    i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
    # batch i = array[i * BATCH_SIZE : (i + 1) * BATCH_SIZE]
    inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])
    return inputs

class Inputs(object):
    def __init__(self):
        self.inputs = input_producer()

def main(*args, **kwargs):
    inputs = Inputs()
    # num_epochs creates a local variable (the epoch counter), so local variables need initializing too
    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())
    sess = tf.Session()
    sess.run(init)  # initialize before the queue threads start using the variables
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        index = 0
        while not coord.should_stop() and index < 10:
            datalines = sess.run(inputs.inputs)
            index += 1
            print("step: %d, batch data: %s" % (index, str(datalines)))
    except tf.errors.OutOfRangeError:
        print("Done training:-------Epoch limit reached")
    except KeyboardInterrupt:
        print("keyboard interrupt detected, stop training")
    finally:
        coord.request_stop()
        coord.join(threads)
        sess.close()

if __name__ == "__main__":
    main()
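Output (under Python 3, TensorFlow returns string tensors as bytes, so the elements print as b'1' and so on):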
step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
Done training:-------Epoch limit reached
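If the num_epochs=1 argument is removed from range_input_producer, the queue never runs out and the loop stops only because of the index < 10 limit. The output is: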
step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
step: 6, batch data: ['1' '2' '3' '4' '5' '6']
step: 7, batch data: ['7' '8' '9' '10' '11' '12']
step: 8, batch data: ['13' '14' '15' '16' '17' '18']
step: 9, batch data: ['19' '20' '21' '22' '23' '24']
step: 10, batch data: ['25' '26' '27' '28' '29' '30']
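Both runs can be reproduced in plain Python to see where each batch comes from. This sketch replaces the TensorFlow queue with itertools and hard-codes the contents of test.txt. run is a hypothetical helper, and it returns Python lists rather than NumPy arrays, so the printed formatting differs from TensorFlow's.

```python
import itertools

BATCH_SIZE = 6
NUM_EXPOCHES = 5                            # batches per epoch
lines = [str(n) for n in range(1, 36)]      # stands in for test.txt: "1".."35"


def run(num_epochs=None, max_steps=10):
    """Return the (step, batch) pairs main.py's loop would see (sketch)."""
    if num_epochs is None:
        indices = itertools.cycle(range(NUM_EXPOCHES))          # endless queue
    else:
        indices = itertools.chain.from_iterable(
            itertools.repeat(range(NUM_EXPOCHES), num_epochs))  # finite queue
    steps = []
    for step, i in enumerate(itertools.islice(indices, max_steps), start=1):
        steps.append((step, lines[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]))
    return steps


for step, batch in run(num_epochs=1):
    print("step: %d, batch data: %s" % (step, batch))
print(len(run(num_epochs=1)), len(run()))   # 5 10
```

In both runs, lines 31 to 35 of test.txt are never used, because 5 batches of 6 cover only the first 30 lines.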
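Note that the file has 35 lines in total. BATCH_SIZE = 6 means each batch holds 6 of them, and NUM_EXPOCHES = 5 means 5 batches are produced, using 30 lines. With NUM_EXPOCHES = 6, 36 lines would be needed, and the following error is raised: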
InvalidArgumentError (see above for traceback): Expected size[0] in [0, 5], but got 6
[[Node: Slice = Slice[Index=DT_INT32, T=DT_STRING, _device="/job:localhost/replica:0/task:0/cpu:0"](Slice/input, Slice/begin/_5, Slice/size)]]
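The error comes from the last batch. For i = 5 the slice starts at 5 × 6 = 30, where only 35 − 30 = 5 elements remain, so tf.slice accepts a size between 0 and 5 but was asked for 6. In general, NUM_EXPOCHES × BATCH_SIZE must not exceed the number of lines, so NUM_EXPOCHES can be at most 35 // BATCH_SIZE = 5.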
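The bound can be checked in plain Python with a slice that performs the same size check. checked_slice is a hypothetical helper. Its error text mimics the TensorFlow message shown above but raises a ValueError rather than TensorFlow's InvalidArgumentError.

```python
def checked_slice(array, begin, size):
    """Slice like tf.slice(array, [begin], [size]), including its bounds check (sketch)."""
    available = len(array) - begin
    if not 0 <= size <= available:
        # mimics the wording of the TensorFlow error message
        raise ValueError("Expected size[0] in [0, %d], but got %d" % (available, size))
    return array[begin:begin + size]


NUM_LINES, BATCH_SIZE = 35, 6
array = [str(n) for n in range(1, NUM_LINES + 1)]
max_batches = NUM_LINES // BATCH_SIZE
print(max_batches)                              # 5

last_ok = checked_slice(array, (max_batches - 1) * BATCH_SIZE, BATCH_SIZE)
print(last_ok)                                  # ['25', '26', '27', '28', '29', '30']
try:
    checked_slice(array, max_batches * BATCH_SIZE, BATCH_SIZE)   # i = 5
except ValueError as err:
    error_message = str(err)
    print(error_message)                        # Expected size[0] in [0, 5], but got 6
```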
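tf.strided_slice slices a tensor (for example, a matrix); see https://www.jianshu.com/p/58aa9c1fb8a9
tf.strided_slice(input_, begin, end, strides=None, ...) extracts a slice of a tensor. If strides is omitted, every stride is 1.
With unit strides, the number of elements in the returned tensor is the product of the absolute differences between corresponding entries of end and begin.
The TensorFlow documentation illustrates this with three examples on t, a 3×2×3 tensor.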
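The three examples in the TensorFlow API documentation for tf.strided_slice use t = [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]]. The sketch below replays them with a hypothetical pure-Python helper, strided_slice, which applies Python's begin:end:stride slicing one dimension at a time and ignores the mask arguments. The expected results in the comments are the ones the TensorFlow documentation lists. The sketch also checks the element-count rule stated above.

```python
def strided_slice(t, begin, end, strides):
    """Pure-Python analogue of tf.strided_slice(t, begin, end, strides) on nested lists.

    Applies Python's begin:end:stride slicing one dimension at a time; the
    begin_mask/end_mask/ellipsis/new-axis/shrink options are not modelled.
    """
    if not begin:                      # all dimensions handled
        return t
    return [strided_slice(sub, begin[1:], end[1:], strides[1:])
            for sub in t[begin[0]:end[0]:strides[0]]]


def num_elements(x):
    """Count the scalar leaves of a nested list."""
    return sum(num_elements(v) for v in x) if isinstance(x, list) else 1


t = [[[1, 1, 1], [2, 2, 2]],
     [[3, 3, 3], [4, 4, 4]],
     [[5, 5, 5], [6, 6, 6]]]           # shape 3 x 2 x 3

r1 = strided_slice(t, [1, 0, 0], [2, 1, 3], [1, 1, 1])     # [[[3, 3, 3]]]
r2 = strided_slice(t, [1, 0, 0], [2, 2, 3], [1, 1, 1])     # [[[3, 3, 3], [4, 4, 4]]]
r3 = strided_slice(t, [1, -1, 0], [2, -3, 3], [1, -1, 1])  # [[[4, 4, 4], [3, 3, 3]]]

# Element count vs. the product of |end - begin| per dimension (strides of 1 or -1 here).
print(num_elements(r1), abs(2 - 1) * abs(1 - 0) * abs(3 - 0))       # 3 3
print(num_elements(r2), abs(2 - 1) * abs(2 - 0) * abs(3 - 0))       # 6 6
print(num_elements(r3), abs(2 - 1) * abs(-3 - (-1)) * abs(3 - 0))   # 6 6
```

The third example walks the middle dimension backwards (stride -1), which is why the two rows come out in reverse order.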
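The TensorFlow PTB example contains this code:
i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
# tf.strided_slice() cuts a [batch_size, num_steps] window out of data (shape [batch_size, batch_len])
x = tf.strided_slice(data, [0, i * num_steps],
                     [batch_size, (i + 1) * num_steps])
# i is only known at run time, so set_shape pins the static shape
x.set_shape([batch_size, num_steps])
# the targets y are the same window shifted one step to the right
y = tf.strided_slice(data, [0, i * num_steps + 1],
                     [batch_size, (i + 1) * num_steps + 1])
y.set_shape([batch_size, num_steps])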
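The logic around this code can be sketched in plain Python, modelled on reader.ptb_producer from the TensorFlow PTB tutorial. The token stream is cut into batch_size rows, and y is x shifted one column to the right. The sketch assumes epoch_size = (batch_len - 1) // num_steps, which is how the tutorial's reader computes it. That formula reserves one token so every x window has a target. The function name ptb_batches and the toy data are hypothetical, and the for loop over i stands in for the queue.

```python
def ptb_batches(raw_data, batch_size, num_steps):
    """Yield (x, y) pairs the way the PTB slicing code produces them (plain-Python sketch)."""
    batch_len = len(raw_data) // batch_size
    # Reshape the token stream into batch_size rows of batch_len tokens each.
    data = [raw_data[r * batch_len:(r + 1) * batch_len] for r in range(batch_size)]
    # Keep one spare token so the last x window still has a shifted y window.
    epoch_size = (batch_len - 1) // num_steps
    for i in range(epoch_size):          # stands in for the range_input_producer queue
        x = [row[i * num_steps:(i + 1) * num_steps] for row in data]
        y = [row[i * num_steps + 1:(i + 1) * num_steps + 1] for row in data]
        yield x, y


batches = list(ptb_batches(list(range(20)), batch_size=2, num_steps=3))
for x, y in batches:
    print("x:", x, "y:", y)
```

With 20 tokens, 2 rows and num_steps = 3, each row holds 10 tokens and epoch_size is 9 // 3 = 3.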
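Important:
tf.train.range_input_producer() creates a queue, and so does anything built on it, such as reader.ptb_producer. It must therefore be called before the threads are started with tf.train.start_queue_runners(). start_queue_runners only starts the queue runners that already exist in the graph at that moment. A queue created later has no thread filling it, and running its dequeue op blocks forever.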
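The ordering rule can be illustrated with a toy model in plain Python, in which starting the threads takes a snapshot of the runners registered so far. QUEUE_RUNNERS, make_range_producer and start_queue_runners below are hypothetical stand-ins, not TensorFlow APIs. A timeout replaces the indefinite block so the script terminates.

```python
import queue
import threading

QUEUE_RUNNERS = []   # hypothetical stand-in for the graph's queue-runner collection


def make_range_producer(limit):
    """Create a queue and register a runner that will fill it with 0..limit-1."""
    q = queue.Queue()

    def runner():
        for i in range(limit):
            q.put(i)

    QUEUE_RUNNERS.append(runner)
    return q


def start_queue_runners():
    """Start a thread for every runner registered so far; later ones are ignored."""
    threads = [threading.Thread(target=r, daemon=True) for r in list(QUEUE_RUNNERS)]
    for t in threads:
        t.start()
    return threads


early = make_range_producer(3)        # created before the threads start
threads = start_queue_runners()
late = make_range_producer(3)         # created afterwards: no thread fills it

first = early.get(timeout=1)
print(first)                          # 0
try:
    late.get(timeout=0.5)             # a real dequeue would block forever here
    late_filled = True
except queue.Empty:
    late_filled = False
print(late_filled)                    # False
for t in threads:
    t.join()
```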
Example code (reader is the reader.py module from the TensorFlow PTB tutorial):
import reader
import tensorflow as tf

# Data path
DATA_PATH = 'simple-examples/data/'
# Read the raw data
train_data, valid_data, test_data, _ = reader.ptb_raw_data(DATA_PATH)
# Organize the data into batches of size 4 with truncation length (num_steps) 5;
# this must come before the threads are started
batch = reader.ptb_producer(train_data, 4, 5)

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    # Start the queue threads
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    # Read the first two batches: the input at each time step and the matching target;
    # ptb_producer advances to the next batch automatically
    for i in range(2):
        x, y = sess.run(batch)
        print('x:', x)
        print('y:', y)
    # Stop the threads
    coord.request_stop()
    coord.join(threads)
Output:
x: [[9970 9971 9972 9974 9975]
[ 332 7147 328 1452 8595]
[1969 0 98 89 2254]
[ 3 3 2 14 24]]
y: [[9971 9972 9974 9975 9976]
[7147 328 1452 8595 59]
[ 0 98 89 2254 0]
[ 3 2 14 24 198]]
x: [[9976 9980 9981 9982 9983]
[ 59 1569 105 2231 1]
[ 0 312 1641 4 1063]
[ 198 150 2262 10 0]]
y: [[9980 9981 9982 9983 9984]
[1569 105 2231 1 895]
[ 312 1641 4 1063 8]
[ 150 2262 10 0 507]]
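The printed batches can be checked against the slicing rule with a few lines of plain Python. In every row, y is x shifted left by one token. Each row of the second x also starts with the token that ended the same row of the first y, because batch i + 1 begins at column (i + 1) * num_steps. The numbers below are copied from the output above, and shifted_by_one and continues_from are hypothetical helper names.

```python
# The two batches printed above, copied verbatim.
x1 = [[9970, 9971, 9972, 9974, 9975], [332, 7147, 328, 1452, 8595],
      [1969, 0, 98, 89, 2254], [3, 3, 2, 14, 24]]
y1 = [[9971, 9972, 9974, 9975, 9976], [7147, 328, 1452, 8595, 59],
      [0, 98, 89, 2254, 0], [3, 2, 14, 24, 198]]
x2 = [[9976, 9980, 9981, 9982, 9983], [59, 1569, 105, 2231, 1],
      [0, 312, 1641, 4, 1063], [198, 150, 2262, 10, 0]]
y2 = [[9980, 9981, 9982, 9983, 9984], [1569, 105, 2231, 1, 895],
      [312, 1641, 4, 1063, 8], [150, 2262, 10, 0, 507]]


def shifted_by_one(x, y):
    """True if every row of y equals the same row of x moved one step to the right."""
    return all(yr[:-1] == xr[1:] for xr, yr in zip(x, y))


def continues_from(y_prev, x_next):
    """True if each row of the next x starts where the same row of the previous y ended."""
    return all(yp[-1] == xn[0] for yp, xn in zip(y_prev, x_next))


print(shifted_by_one(x1, y1), shifted_by_one(x2, y2), continues_from(y1, x2))  # True True True
```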