一、环境
TensorFlow API r1.12
CUDA 9.2 V9.2.148
cudnn64_7.dll
Python 3.6.3
Windows 10
二、官方说明
1、tf.shape(tensor)
获取输入张量 input 的形状,以 1 维整数张量形式表示
https://tensorflow.google.cn/api_docs/python/tf/shape
tf.shape(
input,
name=None,
out_type=tf.int32
)
参数:
input:张量或稀疏张量
name:可选参数,操作的名称
out_type:可选参数,指定输出张量的数据类型(int32 或 int64),默认是 tf.int32
返回:
指定 out_type 数据类型的张量
2、tensor.shape
张量的形状属性
返回一个表示该张量的形状 tf.TensorShape
对于每个操作,通过注册在 Op 中的形状推断函数来计算该张量的形状,形状表示的更多信息请参考 tf.TensorShape
不需要在会话中启动图 Graph 的情况下,张量的推断形状用来表示形状信息。该信息可以用来调试,提供早期的错误信息
在某些情况下,推断出的形状可能存在未知的维度。如果调用者有关于这些维度值的额外信息,就可以使用 Tensor.set_shape() 来拓展该推断的形状
https://tensorflow.google.cn/api_docs/python/tf/Tensor
3、tensor.get_shape()
tensor.shape 的别名,即同样返回一个表示该张量的形状 tf.TensorShape 的方法
https://tensorflow.google.cn/api_docs/python/tf/Tensor
注意:tensor.get_shape(),不是get_shapes
三、实例
1、操作 tf.shape()、属性shape 及 方法get_shape() 的基本用法
>>> import tensorflow as tf
>>> v = tf.Variable(initial_value=tf.truncated_normal([100,100]))
>>> v
# tf.shape() 方法
>>> tf.shape(v)
# shape 属性
>>> v.shape
TensorShape([Dimension(100), Dimension(100)])
# get_shape() 方法
>>> v.get_shape()
TensorShape([Dimension(100), Dimension(100)])
# 错误的用法举例
# 将属性当成方法
>>> v.shape()
Traceback (most recent call last):
File "", line 1, in
TypeError: 'TensorShape' object is not callable
# 将方法当成属性
>>> v.get_shape
>
2、操作 tf.shape() 及属性shape 与 方法get_shape() 的区别
(1)操作 tf.shape() 则返回一个形状张量,必须在会话 Session 中才能打印输出
(2)方法 get_shape() 和 属性 shape 都返回一个表示该张量形状的 tf.TensorShape,tf.TensorShape 可以通过 as_list() 方法将形状转换为列表形式
https://tensorflow.google.cn/api_docs/python/tf/TensorShape
# 操作 tf.shape() 则返回一个形状张量,必须在会话 Session 中才能打印输出
# 方法 get_shape() 和 属性 shape 都返回一个表示该张量形状的 tf.TensorShape
>>> import tensorflow as tf
>>> import tensorflow as tf
>>> v = tf.Variable(initial_value=tf.truncated_normal([100,100]))
>>> v
# tf.shape() 方法
>>> tensor_shape = tf.shape(v)
>>> tensor_shape
# tf.shape() 返回的 Tensor 没有 as_list() 方法,所以报错
>>> tensor_shape.as_list()
Traceback (most recent call last):
File "", line 1, in
AttributeError: 'Tensor' object has no attribute 'as_list'
# tf.shape() 则返回一个形状张量,必须在会话 Session 中才能打印输出
>>> with tf.Session() as sess:
... print(sess.run(tensor_shape))
...
[100 100]
# shape 属性返回一个 tf.TensorShape,可以通过 as_list() 方法将形状转换为列表形式
>>> shapes_1 = v.shape
>>> shapes_1
TensorShape([Dimension(100), Dimension(100)])
>>> shapes_list_1 = shapes_1.as_list()
>>> shapes_list_1
[100, 100]
# get_shape() 方法返回一个 tf.TensorShape,可以通过 as_list() 方法将形状转换为列表形式
>>> shapes_2 = v.get_shape()
>>> shapes_2
TensorShape([Dimension(100), Dimension(100)])
>>> shapes_list_2 = shapes_2.as_list()
>>> shapes_list_2
[100, 100]