The relationship between input and output dimensions of TensorFlow's batching functions tf.train.batch and tf.train.batch_join

After TensorFlow reads data out of a TFRecord and runs it through preprocessing, note that the data is still a single example, while a network's input is usually organized in batches. We therefore need to combine individual examples into a batch to serve as the neural network's input.

TensorFlow provides four functions for batching training data: tf.train.batch(), tf.train.shuffle_batch(), tf.train.batch_join(), and tf.train.shuffle_batch_join().

I recently spent a long time working out the relationship between the input and output dimensions of tf.train.batch() and tf.train.batch_join(), and it was genuinely confusing until I studied the official TensorFlow documentation carefully and noticed some subtleties. Not a single word of the official documentation should be skimmed over; in particular, whether a parameter name is singular or plural often hides the key.

First, look at the tf.train.batch function:

tf.train.batch(
    tensors,
    batch_size,
    num_threads=1,
    capacity=32,
    enqueue_many=False,
    shapes=None,
    dynamic_pad=False,
    allow_smaller_final_batch=False,
    shared_name=None,
    name=None
)

Note that the first parameter of this function is tensors, plural: the function treats your input as a collection made up of many tensors. For example, if I pass in a list of shape [5,4,3,2], it is treated as 5 tensors, each of shape [4,3,2]. Once you grasp this, the output dimensions are not hard to understand.
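To make this concrete, here is a minimal sketch (the image and label names are hypothetical, not from any real dataset) showing that a list of tensors is batched in lockstep, producing one batched tensor per input tensor:

import tensorflow as tf

# Hypothetical single example: a 28x28 "image" and a scalar "label".
image = tf.zeros([28, 28])
label = tf.constant(1)

# tf.train.batch receives a *list* of tensors and batches each of them,
# returning one output tensor per input tensor.
batched = tf.train.batch([image, label], batch_size=32)

print(batched[0].shape)  # (32, 28, 28)
print(batched[1].shape)  # (32,)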

Understanding enqueue_many:

If enqueue_many is False, tensors is assumed to represent a single example. An input tensor with shape [x, y, z] will be output as a tensor with shape [batch_size, x, y, z].

If enqueue_many is True, tensors is assumed to represent a batch of examples, where the first dimension is indexed by example, and all members of tensors should have the same size in the first dimension. If an input tensor has shape [*, x, y, z], the output will have shape [batch_size, x, y, z]. The capacity argument controls how long the prefetching is allowed to grow the queues.

Note that the official documentation says "an input tensor", singular. If you instead regard the whole tensors argument as a single tensor, the output dimensions will be hard to reconcile with what you expect. My own understanding is as follows:

If enqueue_many is set to False, each tensor in tensors is taken to represent a single example, so an input tensor with shape [x,y,z] produces an output tensor with shape [batch_size,x,y,z].

If enqueue_many is set to True, each tensor in tensors is taken to represent a batch of examples, where the first dimension indexes the examples. An input tensor with shape [*,x,y,z] produces an output tensor with shape [batch_size,x,y,z].
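A quick way to verify this rule is to inspect the static shapes, which does not even require running a session. A minimal sketch, assuming a made-up input tensor of shape [7, 3, 2]:

import tensorflow as tf

t = tf.zeros([7, 3, 2])  # hypothetical input tensor of shape [7, 3, 2]

# enqueue_many=False: the whole [7,3,2] tensor is one example.
single = tf.train.batch([t], batch_size=4, enqueue_many=False)
# enqueue_many=True: the first dimension (7) indexes examples of shape [3,2].
many = tf.train.batch([t], batch_size=4, enqueue_many=True)

print(single[0].shape)  # (4, 7, 3, 2)
print(many[0].shape)    # (4, 3, 2)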

For example, with a tensors input of shape [5,4,3,2], each tensor has shape [4,3,2]. Set batch_size to 4. When enqueue_many is False, each tensor's output has shape [4,4,3,2], so the overall output is [5,4,4,3,2]. When enqueue_many is True, each tensor is treated as one batch of examples, so its output has shape [4,3,2] and the overall output is [5,4,3,2].

Code:

import tensorflow as tf

# tensors of shape [5,4,3,2]: a list of 5 tensors, each of shape [4,3,2]
tensors = [
    [[[1,2],[3,4],[5,6]],[[7,8],[9,10],[11,12]],[[13,14],[15,16],[17,18]],[[19,20],[21,22],[23,24]]],
    [[[25,26],[27,28],[29,30]],[[31,32],[33,34],[35,36]],[[37,38],[39,40],[41,42]],[[43,44],[45,46],[47,48]]],
    [[[49,50],[51,52],[53,54]],[[55,56],[57,58],[59,60]],[[61,62],[63,64],[65,66]],[[67,68],[69,70],[71,72]]],
    [[[73,74],[75,76],[77,78]],[[79,80],[81,82],[83,84]],[[85,86],[87,88],[89,90]],[[91,92],[93,94],[95,96]]],
    [[[97,98],[99,100],[101,102]],[[103,104],[105,106],[107,108]],[[109,110],[111,112],[113,114]],[[115,116],[117,118],[119,120]]],
]

with tf.Session() as sess:
    # enqueue_many=False: each of the 5 tensors is one example of shape [4,3,2]
    x1 = tf.train.batch(tensors, batch_size=4, enqueue_many=False)
    # enqueue_many=True: each of the 5 tensors is a batch of 4 examples of shape [3,2]
    x2 = tf.train.batch(tensors, batch_size=4, enqueue_many=True)

    # start the queue runners that feed the batching queues
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    print("x1:")
    print(sess.run(x1))
    print("x2:")
    print(sess.run(x2))

    coord.request_stop()
    coord.join(threads)

Output:

x1:
[array([[[[ 1,  2],
     [ 3,  4],
     [ 5,  6]],

    [[ 7,  8],
     [ 9, 10],
     [11, 12]],

    [[13, 14],
     [15, 16],
     [17, 18]],

    [[19, 20],
     [21, 22],
     [23, 24]]],


   [[[ 1,  2],
     [ 3,  4],
     [ 5,  6]],

    [[ 7,  8],
     [ 9, 10],
     [11, 12]],

    [[13, 14],
     [15, 16],
     [17, 18]],

    [[19, 20],
     [21, 22],
     [23, 24]]],


   [[[ 1,  2],
     [ 3,  4],
     [ 5,  6]],

    [[ 7,  8],
     [ 9, 10],
     [11, 12]],

    [[13, 14],
     [15, 16],
     [17, 18]],

    [[19, 20],
     [21, 22],
     [23, 24]]],


   [[[ 1,  2],
     [ 3,  4],
     [ 5,  6]],

    [[ 7,  8],
     [ 9, 10],
     [11, 12]],

    [[13, 14],
     [15, 16],
     [17, 18]],

    [[19, 20],
     [21, 22],
     [23, 24]]]]), array([[[[25, 26],
     [27, 28],
     [29, 30]],

    [[31, 32],
     [33, 34],
     [35, 36]],

    [[37, 38],
     [39, 40],
     [41, 42]],

    [[43, 44],
     [45, 46],
     [47, 48]]],


   [[[25, 26],
     [27, 28],
     [29, 30]],

    [[31, 32],
     [33, 34],
     [35, 36]],

    [[37, 38],
     [39, 40],
     [41, 42]],

    [[43, 44],
     [45, 46],
     [47, 48]]],


   [[[25, 26],
     [27, 28],
     [29, 30]],

    [[31, 32],
     [33, 34],
     [35, 36]],

    [[37, 38],
     [39, 40],
     [41, 42]],

    [[43, 44],
     [45, 46],
     [47, 48]]],


   [[[25, 26],
     [27, 28],
     [29, 30]],

    [[31, 32],
     [33, 34],
     [35, 36]],

    [[37, 38],
     [39, 40],
     [41, 42]],

    [[43, 44],
     [45, 46],
     [47, 48]]]]), array([[[[49, 50],
     [51, 52],
     [53, 54]],

    [[55, 56],
     [57, 58],
     [59, 60]],

    [[61, 62],
     [63, 64],
     [65, 66]],

    [[67, 68],
     [69, 70],
     [71, 72]]],


   [[[49, 50],
     [51, 52],
     [53, 54]],

    [[55, 56],
     [57, 58],
     [59, 60]],

    [[61, 62],
     [63, 64],
     [65, 66]],

    [[67, 68],
     [69, 70],
     [71, 72]]],


   [[[49, 50],
     [51, 52],
     [53, 54]],

    [[55, 56],
     [57, 58],
     [59, 60]],

    [[61, 62],
     [63, 64],
     [65, 66]],

    [[67, 68],
     [69, 70],
     [71, 72]]],


   [[[49, 50],
     [51, 52],
     [53, 54]],

    [[55, 56],
     [57, 58],
     [59, 60]],

    [[61, 62],
     [63, 64],
     [65, 66]],

    [[67, 68],
     [69, 70],
     [71, 72]]]]), array([[[[73, 74],
     [75, 76],
     [77, 78]],

    [[79, 80],
     [81, 82],
     [83, 84]],

    [[85, 86],
     [87, 88],
     [89, 90]],

    [[91, 92],
     [93, 94],
     [95, 96]]],


   [[[73, 74],
     [75, 76],
     [77, 78]],

    [[79, 80],
     [81, 82],
     [83, 84]],

    [[85, 86],
     [87, 88],
     [89, 90]],

    [[91, 92],
     [93, 94],
     [95, 96]]],


   [[[73, 74],
     [75, 76],
     [77, 78]],

    [[79, 80],
     [81, 82],
     [83, 84]],

    [[85, 86],
     [87, 88],
     [89, 90]],

    [[91, 92],
     [93, 94],
     [95, 96]]],


   [[[73, 74],
     [75, 76],
     [77, 78]],

    [[79, 80],
     [81, 82],
     [83, 84]],

    [[85, 86],
     [87, 88],
     [89, 90]],

    [[91, 92],
     [93, 94],
     [95, 96]]]]), array([[[[ 97,  98],
     [ 99, 100],
     [101, 102]],

    [[103, 104],
     [105, 106],
     [107, 108]],

    [[109, 110],
     [111, 112],
     [113, 114]],

    [[115, 116],
     [117, 118],
     [119, 120]]],


   [[[ 97,  98],
     [ 99, 100],
     [101, 102]],

    [[103, 104],
     [105, 106],
     [107, 108]],

    [[109, 110],
     [111, 112],
     [113, 114]],

    [[115, 116],
     [117, 118],
     [119, 120]]],


   [[[ 97,  98],
     [ 99, 100],
     [101, 102]],

    [[103, 104],
     [105, 106],
     [107, 108]],

    [[109, 110],
     [111, 112],
     [113, 114]],

    [[115, 116],
     [117, 118],
     [119, 120]]],


   [[[ 97,  98],
     [ 99, 100],
     [101, 102]],

    [[103, 104],
     [105, 106],
     [107, 108]],

    [[109, 110],
     [111, 112],
     [113, 114]],

    [[115, 116],
     [117, 118],
     [119, 120]]]])]
x2:
[array([[[ 1,  2],
    [ 3,  4],
    [ 5,  6]],

   [[ 7,  8],
    [ 9, 10],
    [11, 12]],

   [[13, 14],
    [15, 16],
    [17, 18]],

   [[19, 20],
    [21, 22],
    [23, 24]]]), array([[[25, 26],
    [27, 28],
    [29, 30]],

   [[31, 32],
    [33, 34],
    [35, 36]],

   [[37, 38],
    [39, 40],
    [41, 42]],

   [[43, 44],
    [45, 46],
    [47, 48]]]), array([[[49, 50],
    [51, 52],
    [53, 54]],

   [[55, 56],
    [57, 58],
    [59, 60]],

   [[61, 62],
    [63, 64],
    [65, 66]],

   [[67, 68],
    [69, 70],
    [71, 72]]]), array([[[73, 74],
    [75, 76],
    [77, 78]],

   [[79, 80],
    [81, 82],
    [83, 84]],

   [[85, 86],
    [87, 88],
    [89, 90]],

   [[91, 92],
    [93, 94],
    [95, 96]]]), array([[[ 97,  98],
    [ 99, 100],
    [101, 102]],

   [[103, 104],
    [105, 106],
    [107, 108]],

   [[109, 110],
    [111, 112],
    [113, 114]],

   [[115, 116],
    [117, 118],
    [119, 120]]])]

Now print x1 and x2 to inspect their shapes:

In [61]: x1
Out[61]:
[<tf.Tensor 'batch:0' shape=(4, 4, 3, 2) dtype=int32>,
 <tf.Tensor 'batch:1' shape=(4, 4, 3, 2) dtype=int32>,
 <tf.Tensor 'batch:2' shape=(4, 4, 3, 2) dtype=int32>,
 <tf.Tensor 'batch:3' shape=(4, 4, 3, 2) dtype=int32>,
 <tf.Tensor 'batch:4' shape=(4, 4, 3, 2) dtype=int32>]

In [62]: x2
Out[62]:
[<tf.Tensor 'batch_1:0' shape=(4, 3, 2) dtype=int32>,
 <tf.Tensor 'batch_1:1' shape=(4, 3, 2) dtype=int32>,
 <tf.Tensor 'batch_1:2' shape=(4, 3, 2) dtype=int32>,
 <tf.Tensor 'batch_1:3' shape=(4, 3, 2) dtype=int32>,
 <tf.Tensor 'batch_1:4' shape=(4, 3, 2) dtype=int32>]

The output dimensions match our expectations.

Next, look at the tf.train.batch_join function:

tf.train.batch_join(
    tensors_list,
    batch_size,
    capacity=32,
    enqueue_many=False,
    shapes=None,
    dynamic_pad=False,
    allow_smaller_final_batch=False,
    shared_name=None,
    name=None
)

This function is generally used in the multi-threaded, multi-file setting, so its first parameter, tensors_list, is a list made up of several tensors entries. For example, a tensors_list of shape [5,4,3,2] is a list of 5 tensors entries, each of shape [4,3,2], and each of those in turn consists of tensors of shape [3,2].

The output dimensions are determined by the dimensions of each tensors entry, in the same way as tf.train.batch. My understanding of tensors_list is that, since this function is intended for multiple files, each tensors entry corresponds to one file, and the final output dimensions are still determined by the dimensions of a single tensors entry.

For example, with a tensors_list of shape [5,4,3,2], each tensors entry has shape [4,3,2], so each tensor has shape [3,2]. Set batch_size to 4. When enqueue_many is False, the output dimensions are [4,4,3,2]; when enqueue_many is True, the output dimensions are [4,4,2].
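As with tf.train.batch, the static shapes can be checked without running a session. A minimal sketch, assuming two made-up "files" that each produce a single example tensor of shape [3, 2]:

import tensorflow as tf

# Hypothetical: each inner list plays the role of the tensors read from one file.
file_a = [tf.zeros([3, 2])]
file_b = [tf.ones([3, 2])]

# enqueue_many=False: each [3,2] tensor is a single example.
joined = tf.train.batch_join([file_a, file_b], batch_size=4, enqueue_many=False)
print(joined[0].shape)  # (4, 3, 2)

# enqueue_many=True: each [3,2] tensor is a batch of 3 examples of shape [2].
joined_many = tf.train.batch_join([file_a, file_b], batch_size=4, enqueue_many=True)
print(joined_many[0].shape)  # (4, 2)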

Code example:

import tensorflow as tf

# tensors_list of shape [5,4,3,2]: a list of 5 tensors entries,
# each containing 4 tensors of shape [3,2]
tensors_list = [
    [[[1,2],[3,4],[5,6]],[[7,8],[9,10],[11,12]],[[13,14],[15,16],[17,18]],[[19,20],[21,22],[23,24]]],
    [[[25,26],[27,28],[29,30]],[[31,32],[33,34],[35,36]],[[37,38],[39,40],[41,42]],[[43,44],[45,46],[47,48]]],
    [[[49,50],[51,52],[53,54]],[[55,56],[57,58],[59,60]],[[61,62],[63,64],[65,66]],[[67,68],[69,70],[71,72]]],
    [[[73,74],[75,76],[77,78]],[[79,80],[81,82],[83,84]],[[85,86],[87,88],[89,90]],[[91,92],[93,94],[95,96]]],
    [[[97,98],[99,100],[101,102]],[[103,104],[105,106],[107,108]],[[109,110],[111,112],[113,114]],[[115,116],[117,118],[119,120]]],
]

with tf.Session() as sess:
    # enqueue_many=False: each [3,2] tensor is a single example
    x3 = tf.train.batch_join(tensors_list, batch_size=4, enqueue_many=False)
    # enqueue_many=True: each [3,2] tensor is a batch of 3 examples of shape [2]
    x4 = tf.train.batch_join(tensors_list, batch_size=4, enqueue_many=True)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    print("x3:")
    print(sess.run(x3))
    print("x4:")
    print(sess.run(x4))

    coord.request_stop()
    coord.join(threads)

Output:

x3:
[array([[[ 97,  98],
        [ 99, 100],
        [101, 102]],

       [[  1,   2],
        [  3,   4],
        [  5,   6]],

       [[ 25,  26],
        [ 27,  28],
        [ 29,  30]],

       [[ 49,  50],
        [ 51,  52],
        [ 53,  54]]]), array([[[103, 104],
        [105, 106],
        [107, 108]],

       [[  7,   8],
        [  9,  10],
        [ 11,  12]],

       [[ 31,  32],
        [ 33,  34],
        [ 35,  36]],

       [[ 55,  56],
        [ 57,  58],
        [ 59,  60]]]), array([[[109, 110],
        [111, 112],
        [113, 114]],

       [[ 13,  14],
        [ 15,  16],
        [ 17,  18]],

       [[ 37,  38],
        [ 39,  40],
        [ 41,  42]],

       [[ 61,  62],
        [ 63,  64],
        [ 65,  66]]]), array([[[115, 116],
        [117, 118],
        [119, 120]],

       [[ 19,  20],
        [ 21,  22],
        [ 23,  24]],

       [[ 43,  44],
        [ 45,  46],
        [ 47,  48]],

       [[ 67,  68],
        [ 69,  70],
        [ 71,  72]]])]
x4:
[array([[ 1,  2],
       [ 3,  4],
       [ 5,  6],
       [25, 26]]), array([[ 7,  8],
       [ 9, 10],
       [11, 12],
       [31, 32]]), array([[13, 14],
       [15, 16],
       [17, 18],
       [37, 38]]), array([[19, 20],
       [21, 22],
       [23, 24],
       [43, 44]])]

Now look at the shapes of x3 and x4:

In [68]: x3
Out[68]:
[<tf.Tensor 'batch_join:0' shape=(4, 3, 2) dtype=int32>,
 <tf.Tensor 'batch_join:1' shape=(4, 3, 2) dtype=int32>,
 <tf.Tensor 'batch_join:2' shape=(4, 3, 2) dtype=int32>,
 <tf.Tensor 'batch_join:3' shape=(4, 3, 2) dtype=int32>]

In [69]: x4
Out[69]:
[<tf.Tensor 'batch_join_1:0' shape=(4, 2) dtype=int32>,
 <tf.Tensor 'batch_join_1:1' shape=(4, 2) dtype=int32>,
 <tf.Tensor 'batch_join_1:2' shape=(4, 2) dtype=int32>,
 <tf.Tensor 'batch_join_1:3' shape=(4, 2) dtype=int32>]

Again, this matches what we predicted.

This post has only discussed the relationship between the input and output dimensions of these two functions. For the specific values that come out, refer to the official documentation; I may also update this post when I have time.
