参考 tf.py_func() - 云+社区 - 腾讯云
tensorflow中所有的tensor只是占位符,在没有用tf.Session().run接口填充值之前是没有实际值的,不能对其进行判值操作,如if ... else...等,在实际问题中,我们可能需要将一个tensor转换成numpy array 然后进行一些 np的运算,然后返回tensor这样可以加强tensorflow的灵活性。在目标检测算法Faster R-CNN中,需要计算各种ground truth,接口比较复杂。因此,使用tf.py_func是一个比较好的途径。对于tf.py_func的使用,可以参见计算RPN的ground truth和计算proposals的ground truth时的使用方法。可以看到,都是将tensor转化成numpy array,再使用np.操作完成复杂运算。封装一个python函数并将其用作TensorFlow op。
tf.py_func(
func,
inp,
Tout,
stateful=True,
name=None
)
参数:
返回值:
import numpy as np
import tensorflow as tf
def my_func(array1,array2):
return array1 + array2, array1 - array2
if __name__ =='__main__':
array1 = np.array([[1, 2], [3, 4]])
array2 = np.array([[1, 2], [3, 4]])
a1 = tf.placeholder(tf.float32,[2,2],name = 'array1')
a2 = tf.placeholder(tf.float32,[2,2],name = 'array2')
# 函数、输入、输出类型
y1,y2 = tf.py_func(my_func,[a1,a2],[tf.float32, tf.float32])
with tf.Session() as sess:
y1_,y2_ = sess.run([y1,y2],feed_dict={a1:array1,a2:array2})
print(y1_)
print('*'*10)
print(y2_)
Output:
-----------
[[2. 4.]
[6. 8.]]
**********
[[0. 0.]
[0. 0.]]
-----------
直接用array的方式操作:
import tensorflow as tf
import numpy as np
def my_func(array1,array2):
return array1 + array2, array1 - array2
with tf.Session() as sess:
array1 = np.array([[1, 2], [3, 4]])
array2 = np.array([[1, 2], [3, 4]])
y1, y2 = my_func(array1, array2)
print(y1)
print('*' * 10)
print(y2)
Output:
-----------
[[2 4]
[6 8]]
**********
[[0 0]
[0 0]]
-----------