2. 占位符与张量

关闭tf的版本警告

import osos.environ['TF_CPP_MIN_LOG_LEVEL']='2'

占位符 tf.placeholder(dtype, shape=None, name=None)

import tensorflow as tf

a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
c = tf.add(a, b)

with tf.Session() as sess:
    print(sess.run(c, feed_dict={a: 1.1, b: 2.2}))

调用占位符相关的op时,要指定 feed_dict 参数

定义张量并获取属性 tf.constant(value, dtype=None, shape=None, name="Const", verify_shape=False)

a = tf.constant(3.0)
b = tf.constant(4.0)
c = a + b

with tf.Session() as sess:
    print('c.graph: ', c.graph)  # 张量所属的默认图
    print('c.name: ', c.name)   # 张量的字符串描述
    print('c.shape: ', c.shape)   # 张量形状
    print('c.op: ', c.op)   # 张量的操作名    
    print('c.eval: ', c.eval())
    print('sess.run(c): ', sess.run(c))

张量的动态形状与静态形状

TensorFlow中,张量具有静态形状动态形状

  • 静态形状:
    创建一个张量或者由操作推导出一个张量时,初始状态的形状
    tf.Tensor.get_shape: 获取静态形状
    tf.Tensor.set_shape(): 更新Tensor对象的静态形状,通常用于在不能直接推断的情况下。不产生新的张量。只能从1D->1D, 2D->2D ......
    example: [None, 30, 4] ----> [20, 30, 4]

  • 动态形状:
    一种描述原始张量在执行过程中的一种形状
    tf.reshape: 创建一个具有不同动态形状的新张量。产生新的张量。可以跨维度转换。但是总数量必须一致。
    example: [20, 5] ----> [4, 5, 5]

  • 注意:
    转换静态形状的时候,1-D到1-D,2-D到2-D,不能跨阶数改变形状
    对于已经固定或者设置静态形状的张量/变量,不能再次设置静态形状
    tf.reshape()动态创建新张量时,元素个数不能不匹配

批量生成张量 tf.zeros() tf.ones() tf.random_normal ()

# tf.zeros(shape, dtype=dtypes.float32, name=None)
a = tf.zeros([4, 10])  # 四行十列

# tf.ones(shape, dtype=dtypes.float32, name=None)
b = tf.ones([4, 10])  # 四行十列

# 随机正态分布数据 random_normal(shape, mean=0.0, stddev=1.0, dtype=dtypes.float32, seed=None, name=None):
c = tf.random_normal([4, 10])  # 四行十列

改变张量里的数据类型 tf.cast(x, dtype, name=None) 原张量不变,会产生一个新的张量

a = tf.constant(1.0)
print(a)  # ----> Tensor("Const:0", shape=(), dtype=float32)

b = tf.cast(a, tf.int32)
print(b)  # ----> Tensor("Cast:0", shape=(), dtype=int32)

张量的切片和扩展

tf.slice(input_, begin, size, name=None) 从begin索引开始, 获取size个数据. 溢出会报错
tf.concat(values, axis, name="concat")

a = tf.constant([1.1, 2.2, 3.3, 8.8, 9.9, 10.1])

b = tf.slice(a, [2], [4])

with tf.Session() as sess:
    print(sess.run(b))  # ---->  [3.3  8.8  9.9 10.1]

你可能感兴趣的:(2. 占位符与张量)