我看别人都有转载请声明出处,我也写上,:)
转载请申明出处,https://blog.csdn.net/sinat_28704977/article/details/80626689
更重要的,如有错误,请批评指正,不胜感谢。
先上代码:
import numpy as np
import tensorflow as tf
a = tf.constant([
[ [ 1.0, 2.0, 3.0, 4.0 ],
[ 5.0, 6.0, 7.0, 8.0 ],
[ 8.0, 7.0, 6.0, 5.0 ],
[ 4.0, 3.0, 2.0, 1.0 ] ],
[ [ 4.0, 3.0, 2.0, 1.0 ],
[ 8.0, 7.0, 6.0, 5.0 ],
[ 1.0, 2.0, 3.0, 4.0 ],
[ 5.0, 6.0, 7.0, 8.0 ] ]
])
c=a
image_shape = c.get_shape()
b=tf.reshape(a,[4,4,2])
a = tf.reshape(a, [ 1, 4, 4, 2 ])
#这里reshape是强制转换,是直接从上面矩阵挨个取值并转换为目标shape,转换后的图片不是如上图一样的两个4*4数组即(2,4,4),而是(1,4,4,2)。
with tf.Session() as sess:
g,h=(c,image_shape[-1].value)
d = image_shape[ 1: ].as_list()#[4,4]一维列表
dim=1
for test in d:#image_shape为(2,4,4)tensor_shape类型,维数为3
print(test)
dim*=int(test)
e=tf.reshape(a,[-1,dim])
f=tf.reshape(e,(32,1))
print(e,'\n',e.get_shape())
image = sess.run(e)
image2 = sess.run(f)
print(image,image2)
代码中,a的shape为(2,4,4),并不是我们直观认为的图像格式4*4*2。这是需要注意的一点。
首先,tensor有get_shape()方法获得tensor的shape,类型为tensor_shape,然后,shape[1:]的意思是取从第二个维度开始的shape,例如返回的tensor_shape为(1,4,4,2),加上as_list()就成为[4,4,2]一个一维列表.
e=tf.reshape(a,[-1,dim])
这里不是转为列向量,因为如果是一个列向量,那么shape应该是(1,32,1),如果是列向量,则是(32,1)。这里转为:最后一个维度是32的一个向量,也就是(1,32)。要区分清楚。
结果图: