Tensorflow深度学习之二十五:tf.py_func

一、简介

def py_func(func, inp, Tout, stateful=True, name=None)

  该函数重构一个python函数,并将其作为一个TensorFlow的op使用。
  给定一个输入和输出都是numpy数组的python函数’func’,py_func函数将func重构进TensorFlow的计算图中。

  例如:

def my_func(x):
    # x will be a numpy array with the contents of the placeholder below
    return np.sinh(x)
inp = tf.placeholder(tf.float32)
y = tf.py_func(my_func, [inp], tf.float32)



  参数如下:

参数 作用
func A Python function, which accepts a list of NumPy ndarray objects having element types that match the corresponding tf.Tensor objects in inp, and returns a list of ndarray objects (or a single ndarray) having element types that match the corresponding values in Tout.

一个python函数,它将一个Numpy数组组成的list作为输入,该list中的元素的数据类型和inp参数中的tf.Tensor对象的数据类型相对应,同时该函数返回一个Numpy数组组成的list或者单一的Numpy数组,其数据类型和参数Tout中的值相对应。
inp A list of Tensor objects.

Tensor队形组成的list。
Tout A list or tuple of tensorflow data types or a single tensorflow data type if there is only one, indicating what func returns.

一个tensorflow数据类型组成的list或者tuple,(如果只有一个返回值,可以是单独一个tensorflow数据类型),表明该函数的返回对象的数据类型。
stateful (Boolean.) If True, the function should be considered stateful. If a function is stateless, when given the same input it will return the same output and have no observable side effects. Optimizations such as common subexpression elimination are only performed on stateless operations.

布尔值,如果该值为True,该函数应被视为与状态有关的。如果一个函数与状态无关,则相同的输入会产生相同的输出,并不会产生明显的副作用。有些优化操作如common subexpression elimination只能在与状态无关的操作中进行。
name 操作的名称

二、代码示例

import tensorflow as tf
import numpy as np


# 定义一个函数,输入为两个array,返回他们的加,减,以及点乘和叉乘
def my_function(array1, array2):
    return array1 + array2, array1 - array2, array1 * array2, np.dot(array1, array2)


if __name__ == '__main__':
    array1 = np.array([[1, 2], [3, 4]])
    array2 = np.array([[5, 6], [7, 8]])

    a1 = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='array1')
    a2 = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='array2')

    # 重构函数
    y1, y2, y3, y4 = tf.py_func(my_function, [a1, a2], [tf.float32, tf.float32, tf.float32, tf.float32])

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        _y1, _y2, _y3, _y4 = sess.run([y1, y2, y3, y4], feed_dict={a1: array1, a2: array2})
        print(_y1)
        print('*' * 8)
        print(_y2)
        print('*' * 8)
        print(_y3)
        print('*' * 8)
        print(_y4)

  结果如下:

[[ 6.  8.]
 [10. 12.]]
********
[[-4. -4.]
 [-4. -4.]]
********
[[ 5. 12.]
 [21. 32.]]
********
[[19. 22.]
 [43. 50.]]

你可能感兴趣的:(深度学习,Tensorflow)