Tensorflow-API: tf.stack() and tf.unstack()

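tf.stack(): stacks a list of N tensors that all have the same shape (rank R) into one tensor of rank R+1, inserting a new dimension of size N at position axis (unlike tf.concat, which joins along an existing axis)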

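tf.unstack(): the inverse of tf.stack; splits a rank-R tensor along axis into a Python list of rank-(R-1) tensors, one per index along that axis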
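The two definitions can be made concrete without TensorFlow. Below is a minimal pure-Python sketch, not TensorFlow's implementation: nested lists stand in for tensors, only non-negative axes are handled, and the helper names `shape`, `stack` and `unstack` are hypothetical. Stacking along axis 0 simply collects the inputs; stacking along a deeper axis pairs up matching sub-lists and recurses, and unstacking regroups elements by their index along the chosen axis.

```python
def shape(x):
    """Return the shape of a rectangular nested list, e.g. [3, 4]."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return dims


def stack(tensors, axis=0):
    """Hypothetical pure-Python analogue of tf.stack (non-negative axis only)."""
    if len({tuple(shape(t)) for t in tensors}) != 1:
        raise ValueError("all inputs must have the same shape")
    if not 0 <= axis <= len(shape(tensors[0])):
        raise ValueError("axis out of range")
    if axis == 0:
        # New leading axis: the result is simply the list of inputs.
        return list(tensors)
    # Pair up the i-th slices of every input and stack them one level deeper.
    return [stack([t[i] for t in tensors], axis - 1) for i in range(len(tensors[0]))]


def unstack(x, axis=0):
    """Hypothetical pure-Python analogue of tf.unstack (non-negative axis only)."""
    if not 0 <= axis < len(shape(x)):
        raise ValueError("axis out of range")
    if axis == 0:
        # Split off the leading axis: each element becomes one output.
        return list(x)
    # Unstack every sub-list one level deeper, then regroup by index along `axis`.
    parts = [unstack(sub, axis - 1) for sub in x]
    return [[p[j] for p in parts] for j in range(len(parts[0]))]


# Same data as the TensorFlow example: a holds 0..11, b holds 100..111, both [3, 4].
a = [list(range(4 * r, 4 * r + 4)) for r in range(3)]
b = [[v + 100 for v in row] for row in a]
print(shape(stack([a, b], axis=0)))  # [2, 3, 4]
print(shape(stack([a, b], axis=1)))  # [3, 2, 4]
print(unstack(stack([a, b], axis=1), axis=1) == [a, b])  # True
```

The last line shows the round trip: unstacking along the same axis that was used for stacking gives back the original inputs.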

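# Written against TensorFlow 1.x (graph mode, tf.Session).
# Under TensorFlow 2.x, call tf.compat.v1.disable_eager_execution() right after the import
# and replace tf.Session() with tf.compat.v1.Session().
import tensorflow as tf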

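# a: values 0..11 reshaped to [3, 4]; b: values 100..111 reshaped to [3, 4]
a = tf.reshape(tf.range(0, 12), [3, 4])
b = tf.reshape(tf.range(100, 112), [3, 4])
# Stack along axis 0 -> shape [2, 3, 4]
stack0 = tf.stack([a, b], axis=0)
# Stack along axis 1 -> shape [3, 2, 4]
stack1 = tf.stack([a, b], axis=1)
# Unstack stack0 along axis 0 -> list of 2 tensors, each of shape [3, 4]
unstack00 = tf.unstack(stack0, axis=0)
# Unstack stack0 along axis 1 -> list of 3 tensors, each of shape [2, 4]
unstack01 = tf.unstack(stack0, axis=1)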

with tf.Session() as sess:
    print(sess.run(stack0))
    print('-' * 30)
    print(sess.run(stack1))
    print('-' * 30)
    print(sess.run(unstack00))
    print('-' * 30)
    print(sess.run(unstack01))

Output:

[[[  0   1   2   3]
  [  4   5   6   7]
  [  8   9  10  11]]

 [[100 101 102 103]
  [104 105 106 107]
  [108 109 110 111]]]
------------------------------
[[[  0   1   2   3]
  [100 101 102 103]]

 [[  4   5   6   7]
  [104 105 106 107]]

 [[  8   9  10  11]
  [108 109 110 111]]]
------------------------------
[array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]]),
 array([[100, 101, 102, 103],
       [104, 105, 106, 107],
       [108, 109, 110, 111]])]
------------------------------
[array([[  0,   1,   2,   3],
       [100, 101, 102, 103]]),
 array([[  4,   5,   6,   7],
       [104, 105, 106, 107]]),
 array([[  8,   9,  10,  11],
       [108, 109, 110, 111]])]
