TensorFlow 获取张量形状的操作 tf.shape()、属性shape 及 方法get_shape() 的基本用法及实例代码

一、环境

TensorFlow API r1.12

CUDA 9.2 V9.2.148

cudnn64_7.dll

Python 3.6.3

Windows 10

 

二、官方说明

1、tf.shape(tensor)

获取输入张量 input 的形状,以 1 维整数张量形式表示

https://tensorflow.google.cn/api_docs/python/tf/shape

tf.shape(
    input,
    name=None,
    out_type=tf.int32
)

参数:

input:张量或稀疏张量

name:可选参数,操作的名称

out_type:可选参数,指定输出张量的数据类型(int32 或 int64),默认是 tf.int32

返回:

指定 out_type 数据类型的张量

 

2、tensor.shape

张量的形状属性

返回一个表示该张量的形状 tf.TensorShape

对于每个操作,通过注册在 Op 中的形状推断函数来计算该张量的形状,形状表示的更多信息请参考 tf.TensorShape

不需要在会话中启动图 Graph 的情况下,张量的推断形状用来表示形状信息。该信息可以用来调试,提供早期的错误信息

在某些情况下,推断出的形状可能存在未知的维度。如果调用者有关于这些维度值的额外信息,就可以使用 Tensor.set_shape() 来拓展该推断的形状 

https://tensorflow.google.cn/api_docs/python/tf/Tensor

 

3、tensor.get_shape()

tensor.shape 的别名,即同样返回一个表示该张量的形状 tf.TensorShape 的方法

https://tensorflow.google.cn/api_docs/python/tf/Tensor

注意:tensor.get_shape(),不是get_shapes

 

三、实例

1、操作 tf.shape()、属性shape 及 方法get_shape() 的基本用法

>>> import tensorflow as tf
>>> v = tf.Variable(initial_value=tf.truncated_normal([100,100]))
>>> v


# tf.shape() 方法
>>> tf.shape(v)


# shape 属性
>>> v.shape
TensorShape([Dimension(100), Dimension(100)])

# get_shape() 方法
>>> v.get_shape()
TensorShape([Dimension(100), Dimension(100)])

# 错误的用法举例
# 将属性当成方法
>>> v.shape()
Traceback (most recent call last):
  File "", line 1, in 
TypeError: 'TensorShape' object is not callable

# 将方法当成属性
>>> v.get_shape
>

2、操作 tf.shape() 及属性shape 与 方法get_shape() 的区别

(1)操作 tf.shape() 则返回一个形状张量,必须在会话 Session 中才能打印输出

(2)方法 get_shape() 和 属性 shape 都返回一个表示该张量形状的 tf.TensorShape,tf.TensorShape 可以通过 as_list() 方法将形状转换为列表形式

https://tensorflow.google.cn/api_docs/python/tf/TensorShape

# 操作 tf.shape() 则返回一个形状张量,必须在会话 Session 中才能打印输出
# 方法 get_shape() 和 属性 shape 都返回一个表示该张量形状的 tf.TensorShape

>>> import tensorflow as tf
>>> import tensorflow as tf
>>> v = tf.Variable(initial_value=tf.truncated_normal([100,100]))
>>> v




# tf.shape() 方法
>>> tensor_shape = tf.shape(v)
>>> tensor_shape


# tf.shape() 返回的 Tensor 没有 as_list() 方法,所以报错
>>> tensor_shape.as_list()
Traceback (most recent call last):
  File "", line 1, in 
AttributeError: 'Tensor' object has no attribute 'as_list'

# tf.shape() 则返回一个形状张量,必须在会话 Session 中才能打印输出
>>> with tf.Session() as sess:
...     print(sess.run(tensor_shape))
...
[100 100]




# shape 属性返回一个 tf.TensorShape,可以通过 as_list() 方法将形状转换为列表形式
>>> shapes_1 = v.shape
>>> shapes_1
TensorShape([Dimension(100), Dimension(100)])
>>> shapes_list_1 = shapes_1.as_list()
>>> shapes_list_1
[100, 100]




# get_shape() 方法返回一个 tf.TensorShape,可以通过 as_list() 方法将形状转换为列表形式
>>> shapes_2 = v.get_shape()
>>> shapes_2
TensorShape([Dimension(100), Dimension(100)])
>>> shapes_list_2 = shapes_2.as_list()
>>> shapes_list_2
[100, 100]

 

你可能感兴趣的:(TensorFlow基础)