tf.stack() 使用分析

 最近在看文章的时候,发现对stack方法有些迷惑,所以在这里填个坑:

import tensorflow as tf

vector = [[[tf.range(1, 6), tf.range(6, 11)], [tf.range(11, 16), tf.range(16, 21)]],
          [[tf.range(21, 26), tf.range(26, 31)], [tf.range(31, 36), tf.range(36, 41)]]]

print(tf.stack(vector, axis=0))
"""
[[[[ 1  2  3  4  5]
   [ 6  7  8  9 10]]

  [[11 12 13 14 15]
   [16 17 18 19 20]]]


 [[[21 22 23 24 25]
   [26 27 28 29 30]]

  [[31 32 33 34 35]
   [36 37 38 39 40]]]]
"""
print(tf.stack(vector, axis=1))
"""
[[[[ 1  2  3  4  5]
   [ 6  7  8  9 10]]

  [[21 22 23 24 25]
   [26 27 28 29 30]]]


 [[[11 12 13 14 15]
   [16 17 18 19 20]]

  [[31 32 33 34 35]
   [36 37 38 39 40]]]]
"""
print(tf.stack(vector, axis=2))
"""
[[[[ 1  2  3  4  5]
   [21 22 23 24 25]]

  [[ 6  7  8  9 10]
   [26 27 28 29 30]]]


 [[[11 12 13 14 15]
   [31 32 33 34 35]]

  [[16 17 18 19 20]
   [36 37 38 39 40]]]]
"""
print(tf.stack(vector, axis=3))
"""
[[[[ 1 21]
   [ 2 22]
   [ 3 23]
   [ 4 24]
   [ 5 25]]

  [[ 6 26]
   [ 7 27]
   [ 8 28]
   [ 9 29]
   [10 30]]]


 [[[11 31]
   [12 32]
   [13 33]
   [14 34]
   [15 35]]

  [[16 36]
   [17 37]
   [18 38]
   [19 39]
   [20 40]]]]
"""
print(tf.stack(vector, axis=-1))
"""
[[[[ 1 21]
   [ 2 22]
   [ 3 23]
   [ 4 24]
   [ 5 25]]

  [[ 6 26]
   [ 7 27]
   [ 8 28]
   [ 9 29]
   [10 30]]]


 [[[11 31]
   [12 32]
   [13 33]
   [14 34]
   [15 35]]

  [[16 36]
   [17 37]
   [18 38]
   [19 39]
   [20 40]]]]
"""

从例子可以看出,对于一个 2*2*2*5 的张量组,最浅显的理解:

  • 如果 stack 从0维合并,可以将输入看作 [ab] 进行合并
  • 如果 stack 从1维合并,可以将输入看作 [ [a, b] , [c, d] ] 进行合并
  • 如果 stack 从2维合并,可以将输入看作 [ [ [a, b], [c, d] ], [ [e, f], [g, h] ] ] 进行合并
  • 如果 stack 从3维合并,可以将输入看作
    [ [ [ [a1, a2, a3, a4, a5], [b1, b2, b3, b4, b5] ],
        [ [c1, c2, c3, c4, c5], [d1, d2, d3, d4, d5] ] ],
      [ [ [e1, e2, e3, e4, e5], [f1, f2, f3, f4, f5] ],
        [ [g1, g2, g3, g4, g5], [h1, h2, h3, h4, h5] ] ] ]  进行合并。

也就说,随着axis的增加,stack方法合并的维度就在增加,从表层维度合并到更深层次的维度。这样就比较好理解stack方法了。

你可能感兴趣的:(#,tensorflow学习,tensorflow,机器学习)