tf.split 和 tf.concat

import tensorflow as tf
import numpy as np
# Build a random float64 array with values in [0, 1) and shape (1, 3, 3, 4).
# No seed is set, so the values change from run to run. Axis 3 (size 4) is
# the axis that gets split and re-joined later in this walkthrough.
X = np.random.rand(1, 3, 3, 4)
X
array([[[[ 0.82959287,  0.97123702,  0.28140139,  0.27116128],
         [ 0.17657325,  0.95732474,  0.69869441,  0.68558369],
         [ 0.27456733,  0.75242884,  0.00578983,  0.36427501]],

        [[ 0.55055599,  0.27293508,  0.58177528,  0.60010759],
         [ 0.49096017,  0.03448037,  0.77094952,  0.72902519],
         [ 0.72496438,  0.57176329,  0.9313365 ,  0.81825572]],

        [[ 0.35645042,  0.79323193,  0.08155452,  0.75811829],
         [ 0.24662546,  0.20411053,  0.19005582,  0.72657277],
         [ 0.84135906,  0.77598372,  0.26645642,  0.69704092]]]])
# Cut X into 2 equal pieces along axis 3 (size 4 -> two pieces of size 2).
# Each piece has shape (1, 3, 3, 2). Under the TF 1.x Session workflow used
# in this walkthrough, these are graph tensors that hold no values until a
# session runs them.
splits = tf.split(X, 2, axis=3)
splits
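[<tf.Tensor 'split:0' shape=(1, 3, 3, 2) dtype=float64>,
 <tf.Tensor 'split:1' shape=(1, 3, 3, 2) dtype=float64>]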
# Evaluate both split tensors and print each resulting NumPy piece,
# followed by two blank lines.
# NOTE(review): tf.Session is TF 1.x API. In TF 2.x the equivalent is
# tf.compat.v1.Session, which also requires eager execution to be disabled.
# The session is intentionally left open because it is reused for the concat.
sess = tf.Session()
splits_res = sess.run(splits)
for piece in splits_res:
    print(piece, end="\n\n\n")
[[[[ 0.82959287  0.97123702]
   [ 0.17657325  0.95732474]
   [ 0.27456733  0.75242884]]

  [[ 0.55055599  0.27293508]
   [ 0.49096017  0.03448037]
   [ 0.72496438  0.57176329]]

  [[ 0.35645042  0.79323193]
   [ 0.24662546  0.20411053]
   [ 0.84135906  0.77598372]]]]


[[[[ 0.28140139  0.27116128]
   [ 0.69869441  0.68558369]
   [ 0.00578983  0.36427501]]

  [[ 0.58177528  0.60010759]
   [ 0.77094952  0.72902519]
   [ 0.9313365   0.81825572]]

  [[ 0.08155452  0.75811829]
   [ 0.19005582  0.72657277]
   [ 0.26645642  0.69704092]]]]
# Stitch the evaluated NumPy pieces back together along axis 3, then
# evaluate. The printed result matches the original X element for element.
splits_concat = tf.concat(splits_res, 3)
splits_concat_res = sess.run(splits_concat)
splits_concat_res
array([[[[ 0.82959287,  0.97123702,  0.28140139,  0.27116128],
         [ 0.17657325,  0.95732474,  0.69869441,  0.68558369],
         [ 0.27456733,  0.75242884,  0.00578983,  0.36427501]],

        [[ 0.55055599,  0.27293508,  0.58177528,  0.60010759],
         [ 0.49096017,  0.03448037,  0.77094952,  0.72902519],
         [ 0.72496438,  0.57176329,  0.9313365 ,  0.81825572]],

        [[ 0.35645042,  0.79323193,  0.08155452,  0.75811829],
         [ 0.24662546,  0.20411053,  0.19005582,  0.72657277],
         [ 0.84135906,  0.77598372,  0.26645642,  0.69704092]]]])
# Shape check: joining the two (1, 3, 3, 2) halves restores (1, 3, 3, 4).
np.shape(splits_concat_res)
(1, 3, 3, 4)

你可能感兴趣的:(tf.split 和 tf.concat)