tf.split() 和 tf.unstack()

import tensorflow as tf

# Demo: compare tf.split and tf.unstack on the same 2x3 matrix.
A = [[1, 2, 3], [4, 5, 6]]

# tf.split keeps the rank of its input: splitting along axis=1 into 3 pieces
# gives three 2x1 tensors (one per column, still 2-D).
a0 = tf.split(A, num_or_size_splits=3, axis=1)
# tf.unstack drops the unpacked axis: along axis=1 it gives three length-2
# vectors (the columns as 1-D tensors).
a1 = tf.unstack(A, num=3, axis=1)
# Splitting along axis=0 into 2 pieces gives two 1x3 tensors (still 2-D).
a2 = tf.split(A, num_or_size_splits=2, axis=0)
# Unstacking along axis=0 gives two length-3 vectors (the rows as 1-D tensors).
a3 = tf.unstack(A, num=2, axis=0)

# tf.Session is the TF1.x graph-mode API: the ops above only yield values
# when evaluated inside a session. Each result list is run and printed in turn.
with tf.Session() as sess:
    for pieces in (a0, a1, a2, a3):
        print(sess.run(pieces))
输出：

[array([[1],[4]]), array([[2],[5]]), array([[3],[6]])]

[array([1, 4]), array([2, 5]), array([3, 6])]

[array([[1, 2, 3]]), array([[4, 5, 6]])]

[array([1, 2, 3]), array([4, 5, 6])]

可见 tf.split 会保留被切分的维度（每一块仍是二维张量），而 tf.unstack 会移除该维度（每一块降为一维张量）。

你可能感兴趣的:(tensorflow)