tf.concat

tf.concat = (tensor, axis=0)

tensor需要拼接的张量
axis维度,当axis=0时,在第0个维度拼接(按行拼接,即列拼接),当axis=1时,在第1个维度拼接(按列拼接,即行拼接)

t1 = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.float32)
t2 = tf.constant([[7, 8, 9], [10, 11, 12]], dtype=tf.float32)
t3 = [[1, 2, 3], [4, 5, 6]]
t4 = [[7, 8, 9], [10, 11, 12]]
T1 = tf.concat([t1, t2], 0)
T2 = tf.concat([t1, t2], 1)
T3 = tf.concat([t3, t4], 0)
T4 = tf.concat([t3, t4], 1)
with tf.Session() as sess:
  print('维度0拼接:\n',sess.run(T1))   # 浮点数
  print('='*30)
  print('维度1拼接:\n',sess.run(T2))  # 浮点数
  print('维度0拼接:\n',sess.run(T3))  # 整数
  print('='*30)
  print('维度1拼接:\n',sess.run(T4))  # 整数

维度0拼接:
 [[ 1.  2.  3.]
 [ 4.  5.  6.]
 [ 7.  8.  9.]
 [10. 11. 12.]]
==============================
维度1拼接:
 [[ 1.  2.  3.  7.  8.  9.]
 [ 4.  5.  6. 10. 11. 12.]]

维度0拼接:
[[ 1  2  3]
 [ 4  5  6]
 [ 7  8  9]
 [10 11 12]]
==============================
维度1拼接:
[[ 1  2  3  7  8  9]
 [ 4  5  6 10 11 12]]

你可能感兴趣的:(tf.concat)