在tensorflow中经常用到shape函数
例如
import tensorflow as tf
a = tf.constant([[1.,2.,3.],[4.,5.,6.],[7.,8.,9.]],shape = [3,3])
b = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(b)
print('data=[[1.,2.,3.],[4.,5.,6.],[7.,8.,9.]] , shape = [3,3] ')
print('display:\n',sess.run(a))
print('\n')
其中 shape = [3,3]表示一个3*3的二维数据,其组织形式为
[[1. 2. 3.]
[4. 5. 6.]
[7. 8. 9.]]
也可以表示为[[1. 2. 3.] [4. 5. 6.] [7. 8. 9.]],实质是一样的。
再复杂一点的,例如
a = tf.constant([[1.,2.,3.],[4.,5.,6.],[7.,8.,9.]],shape = [3,3,1])
其组织形式为
[[[1.] [2.] [3.]] [[4.] [5.] [6.]] [[7.] [8.] [9.]]]
可以看到每个数据自成一组,这就是 shape = [3,3,1]中最右边的1带来的效果。
对于通用的形式 shape = [s1,s2,s3,s4,s5],如何理解呢?
规则:从右到左分组
举例:
a = tf.constant([[1.,2.,3.],[4.,5.,6.],[7.,8.,9.]],shape = [1,3,3,2])
其组织形式为
[ [ [[1. 2.] [3. 4.] [5. 6.]] [[7. 8.] [9. 9.] [9. 9.]] [[9. 9.] [9. 9.] [9. 9.]] ] ]
根据shape=[1,3,3,2],从右到左,首先按照每2个数编一组形成group1,再每3个group1组编一组成group2,再每3个group2便一组成group3,最后整个group3又编一个大组group4。所以最外层的方括号有4个。
当数据不够,后续的分组中都以尾数据9.填充,强制形成[1,3,3,2]格式的数组,1*3*3*2共18个数据。
reshape函数也可以参照shape函数进行理解。