tf.gather 2018-06-06

1、当indices=[0,2],axis=0

input =[ [[[1, 1, 1], [2, 2, 2]],
         [[3, 3, 3], [4, 4, 4]],
         [[5, 5, 5], [6, 6, 6]]],

         [[[7, 7, 7], [8, 8, 8]],
         [[9, 9, 9], [10, 10, 10]],
         [[11, 11, 11], [12, 12, 12]]],

        [[[13, 13, 13], [14, 14, 14]],
         [[15, 15, 15], [16, 16, 16]],
         [[17, 17, 17], [18, 18, 18]]]
         ]

print(tf.shape(input))
with tf.Session() as sess:
    output=tf.gather(input, [0,2],axis=0)#其实默认axis=0
    print(sess.run(output))

输出结果
[[[[ 1 1 1]
[ 2 2 2]]

[[ 3 3 3]
[ 4 4 4]]

[[ 5 5 5]
[ 6 6 6]]]

[[[13 13 13]
[14 14 14]]

[[15 15 15]
[16 16 16]]

[[17 17 17]
[18 18 18]]]]

解释:

右中括号就暂时不理会他先了。
第一个[ 是列表语法需要的括号,剩下的最里面的三个[[[是axis=0需要搜寻的中括号。这里一共有3个[[[。
indices的[0,2]即取第0个[[[和第2个[[[,也就是第0个和第2个三维立体。

2、当indices=[0,2],axis=1

input =[ [[[1, 1, 1], [2, 2, 2]],
         [[3, 3, 3], [4, 4, 4]],
         [[5, 5, 5], [6, 6, 6]]],

         [[[7, 7, 7], [8, 8, 8]],
         [[9, 9, 9], [10, 10, 10]],
         [[11, 11, 11], [12, 12, 12]]],

        [[[13, 13, 13], [14, 14, 14]],
         [[15, 15, 15], [16, 16, 16]],
         [[17, 17, 17], [18, 18, 18]]]
         ]
print(tf.shape(input))
with tf.Session() as sess:
    output=tf.gather(input, [0,2],axis=1)#默认axis=0
    print(sess.run(output))

输出结果
[[[[ 1 1 1]
[ 2 2 2]]

[[ 5 5 5]
[ 6 6 6]]]

[[[ 7 7 7]
[ 8 8 8]]

[[11 11 11]
[12 12 12]]]

[[[13 13 13]
[14 14 14]]

[[17 17 17]
[18 18 18]]]]

解释:

第一个[ 是列表语法需要的括号,先把这个干扰去掉,剩下的所有内侧的 [[ 是axis=1搜索的中括号。
然后[0,2]即再取每个[[[体内的第0个[[和第2个[[,也就是去每个三维体的第0个面和第2个面。

你可能感兴趣的:(tf.gather 2018-06-06)