之前一直弄混张量的维度和shape的关系,认为通过tf.shape()获得的就是维度,现在发现错误,记下来。
tf.shap()用来获取的是张量的各个维度上的元素数目。
1 #维度为0的标量
[1, 2, 3] #维度为1,包含3个元素
[[1, 2], [3, 4]] #维度为2, shape=(2, 2)
[[[1, 2], [3, 4]], [[1, 2], [3, 4]]] #维度为3, shape=(2, 2, 2)
sess = tf.Session()
a = tf.constant([1])
b = tf.constant([2])
c = tf.concat([a, b], axis=0)
sess.run(c)
d = tf.constant([2, 3, 4])
print("a: ", sess.run(tf.shape(a)), a)
print("b: ", sess.run(tf.shape(b)), b)
print("c: ", sess.run(tf.shape(c)), c)
print("d: ", sess.run(tf.shape(d)), d)
e = d[0]
print("e: ", sess.run(tf.shape(e)), e) #这里表明是标量
f = tf.reshape(e, [1,])
print("f: ", sess.run(tf.shape(f)), f)
g = tf.concat([f, a], axis=0)
print("g: ", sess.run(g), g)
h = d[:1]
print("h: ", sess.run(tf.shape(h)), h)
output:
a: [1] Tensor("Const_95:0", shape=(1,), dtype=int32)
b: [1] Tensor("Const_96:0", shape=(1,), dtype=int32)
c: [2] Tensor("concat_29:0", shape=(2,), dtype=int32)
d: [3] Tensor("Const_97:0", shape=(3,), dtype=int32)
e: [] Tensor("strided_slice_39:0", shape=(), dtype=int32)
f: [1] Tensor("Reshape_6:0", shape=(1,), dtype=int32)
g: [2 1] Tensor("concat_30:0", shape=(2,), dtype=int32)
h: [1] Tensor("strided_slice_40:0", shape=(1,), dtype=int32)
注意:e的输出为标量,因为这里只是获得其中一个元素
python类似。
张量切片:
i = tf.slice(d, [0],[2]) #d是一维数据,[0]表示从该一维数据的第0个元素开始切片,[2]表示第一维元素保留2个。
print(sess.run(tf.shape(i)), i)
sess.run(i)
output:
[2] Tensor("Slice_9:0", shape=(2,), dtype=int32)
array([2, 3])
elem_tf = tf.constant([i+1 for i in range(30)], shape=[5, 6], name="elem")
sess = tf.Session()
print(sess.run(elem), type(elem_tf))
elem_np = elem_tf.eval(session=sess) #看这里eval用法
print("\n", elem_np, type(elem_np))
elem_tf_convert = tf.convert_to_tensor(elem_np)
print("\n", sess.run(elem_tf_convert[0][0]), type(elem_tf_convert))
sess.close()
output:
[[ 1 2 3 4 5 6]
[ 7 8 9 10 11 12]
[13 14 15 16 17 18]
[19 20 21 22 23 24]
[25 26 27 28 29 30]]
[[ 1 2 3 4 5 6]
[ 7 8 9 10 11 12]
[13 14 15 16 17 18]
[19 20 21 22 23 24]
[25 26 27 28 29 30]]
1
注意对于张量元素的读取可以效仿python,直接使用x[0]这样的方式,如下:
sess = tf.Session()
a = tf.constant([1, 2, 3])
#d = tf.add(a[0], a[1]) #或者
d = a[0] + a[1]
print(sess.run(d))
sess.close()
output:
3
将标量转换成张量用于计算,如下:
sess = tf.Session()
aa = tf.constant([1, 1])
part1 = aa[0]
print(type(part1))
print(tf.shape(part1))
bb = tf.constant([1])
print(tf.shape(bb))
part1 = tf.reshape(part1, [1]) #注意转换维度,不然不能用,这里part1是标量
cc = tf.concat([part1, bb], axis=0)
sess.run(cc)
output:
Tensor("Shape_95:0", shape=(0,), dtype=int32)
Tensor("Shape_96:0", shape=(1,), dtype=int32)
array([1, 1])