TensorFlow的tf.concat实例详细介绍

tf.concat函数:函数功能比较简单,主要用于连接两个数组

参数:
values:需要连接的数组,注意数组的维度应该一致
axis:从哪个维度来连接数组

例子:

1.一维数组

 import tensorflow as tf

 if __name__ == "__main__":
    a = [1,2,3]
    b = [4,5,6]
    c = tf.concat([a,b],0)
    sess = tf.InteractiveSession()
    print(sess.run(c)) #[1 2 3 4 5 6]

注意:axis参数不能超过数组的维度。如果超过数组的维度,如下:

  c = tf.concat([a,b],1)

则会报,ValueError: Shape must be at least rank 2 but is rank 1 for 'concat',意思是数组至少是二维,axis才能为1。

2.二维数组

  a = [[1,1],[2,2],[3,3]]
  b = [[4,4],[5,5],[6,6]]
  c = tf.concat([a,b],0)
  print(sess.run(c))
"""
[[1 1]
 [2 2]
 [3 3]
 [4 4]
 [5 5]
 [6 6]]
"""
  c = tf.concat([a,b],1) #等价于tf.concat([a,b],-1)
  print(sess.run(c))
"""
[[1 1 4 4]
 [2 2 5 5]
 [3 3 6 6]]
"""

3.三维数组

  a = [[[1,1],[2,2]],[[3,3],[4,4]]]
  b = [[[5,5]],[[6,6]]]

  c = tf.concat([a,b],1)
  print(sess.run(c))
"""
[[[1 1]
  [2 2]
  [5 5]]

 [[3 3]
  [4 4]
  [6 6]]]
"""
  1. a = [[1, 2], [3, 4]]
    b = [[5, 6]]
    c = np.concatenate((a, b), axis=None)
    """
     [[1,2,3,4,5,6]]
    """
    

5.如何来判断数组是否在该个维度上的shape是相同的呢?
其实很简单,我们根据tf.concat的axis参数来去数组的[],0表示去掉最外面的一层,1去掉两层,以此类推,下面举例说明一下。
如:最后一个例子中的c = tf.concat([a,b],1),我们先将a去掉最外面两层[],变成了[1,1],[2,2]和[3,3],[4,4]],然后再将b去掉最外面两层[],变成了[5,5]和[6,6],此时再进行concat,可以发现此时的shape是相等的。
参考:
https://blog.csdn.net/sinat_29957455/article/details/86100641

你可能感兴趣的:(TensorFlow的tf.concat实例详细介绍)