>>> import tensorflow as tf
>>> data_tensor = tf.constant([1,2,3,4,5,6], shape=[2,3], dtype=tf.float32)
>>> data_tensor
<tf.Tensor 'Const:0' shape=(2, 3) dtype=float32>
>>> with tf.Session() as sess:
...     # 方法一: sess.run(tensor)
...     print("sess.run(tensor): {}".format(sess.run(data_tensor)))
...     # 方法二: tensor.eval(session=sess)
...     print("tensor.eval(session=sess): {}".format(data_tensor.eval(session=sess)))
...
# 打印结果:
sess.run(tensor): [[1. 2. 3.]
 [4. 5. 6.]]
tensor.eval(session=sess): [[1. 2. 3.]
 [4. 5. 6.]]
通过 Session 的 run() 方法的 feed_dict 参数将 tensor 中的数据传给 placeholder 时需要注意:feed_dict 的值不能是 tensor,因此要先用 sess.run() 把 tensor 计算成 numpy array,再把这个数组传给 feed_dict:
>>> import tensorflow as tf
>>> truncated_tensor = tf.random.truncated_normal([2,3],dtype=tf.float32)
>>> truncated_tensor
<tf.Tensor 'truncated_normal:0' shape=(2, 3) dtype=float32>
>>> X_holder = tf.placeholder(dtype=tf.float32, shape=[2,3])
>>> result = tf.split(X_holder,[1,2],-1)
>>> result
[<tf.Tensor 'split:0' shape=(2, 1) dtype=float32>, <tf.Tensor 'split:1' shape=(2, 2) dtype=float32>]
>>> with tf.Session() as sess:
...     data = sess.run(truncated_tensor)  # 先将 tensor 计算为 numpy array
...     matrix1, matrix2 = sess.run(fetches=result, feed_dict={X_holder:data})
...     print("matrix1: {}".format(matrix1))
...     print("matrix2: {}".format(matrix2))
...
matrix1: [[ 0.44706684]
 [-1.1646993 ]]
matrix2: [[-1.026939 0.25133875]
 [ 0.3521515 -0.21427773]]
错误方式举例:Session 的 run() 方法的 feed_dict 参数不能直接接收 tf.Tensor 类型的值
>>> truncated_tensor = tf.random.truncated_normal([2,3],dtype=tf.float32)
>>> truncated_tensor
<tf.Tensor 'truncated_normal:0' shape=(2, 3) dtype=float32>
>>> X_holder = tf.placeholder(dtype=tf.float32, shape=[2,3])
>>> result = tf.split(X_holder,[1,2],-1)
>>> result
[<tf.Tensor 'split:0' shape=(2, 1) dtype=float32>, <tf.Tensor 'split:1' shape=(2, 2) dtype=float32>]
>>> with tf.Session() as sess:
...     matrix1, matrix2 = sess.run(fetches=result, feed_dict={X_holder:truncated_tensor})
...     print("matrix1: {}".format(matrix1))
...     print("matrix2: {}".format(matrix2))
...
Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
  File "/******/Anaconda/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/******/Anaconda/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1124, in _run
    'feed with key ' + str(feed) + '.')
TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles. For reference, the tensor object was Tensor("truncated_normal:0", shape=(2, 3), dtype=float32) which was passed to the feed with key Tensor("Placeholder:0", shape=(2, 3), dtype=float32).
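(PS:根据报错提示可知,tf.Tensor 不能直接传给 feed_dict 参数,该参数只接受 Python 标量、字符串、列表、numpy 的 n 维数组(ndarray)或 TensorHandle 等类型。)
反过来,也可以用 tf.convert_to_tensor() 将 numpy 数组转换为 tensor(dtype 与原数组一致,下例中为 float64),再在 Session 中通过 eval() 取回 numpy 数组: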
>>> import numpy as np
>>> data = np.random.random([2,3])
>>> data
array([[0.02510445, 0.11428815, 0.11885889],
       [0.34326366, 0.69007324, 0.70780292]])
>>> data_tensor = tf.convert_to_tensor(data)
>>> data_tensor
<tf.Tensor 'Const_2:0' shape=(2, 3) dtype=float64>
>>> sess = tf.Session()
>>> data_tensor.eval(session=sess)
array([[0.02510445, 0.11428815, 0.11885889],
       [0.34326366, 0.69007324, 0.70780292]])
>>> sess.close()