Tensorflow函数映射:py_func和map_fn

tf.map_fn

[tf.map_fn]:map on the list of tensors unpacked from elems on dimension 0. 接受一个函数对象,然后用该函数对象对集合(elems)中的每一个元素分别处理,

tf.map_fn(
    fn,
    elems,
    dtype=None,
    parallel_iterations=None,
    back_prop=True,
    swap_memory=False,
    infer_shape=True,
    name=None
)

参数解析

        dtype: (optional) The output type(s) of fn. If fn returns a structure of Tensors differing from the structure of elems, then dtype is not optional and must have the same structure as the output of fn. 如果输入输出类型不一样,这个一定要加上,否则输出类型自动和输入一样,报错。
-柚子皮-

 

 

tf.py_func

[tf.py_func]: Wraps a python function and uses it as a TensorFlow op. 用来将 一个 python 函数打包成一个 op。

Note: 杯具的是,如果在estimator,如predict阶段使用py_func,它是不能在后续server中使用的。即export_saved_model后,在predictor.from_saved_model时会出错,说ValueError: callback pyfunc_0 is not found,即estimator不会将py_func拉进静态图中?这时只能放弃py_func,类似下面示例中的解决方案2了。

        Tensorflow还是有不足的地方。第一体现在Tensorflow的数据机制,由于tensor只是占位符,在没有用tf.Session().run接口填充值之前是没有实际值的。因此,在网络搭建的时候,是不能对tensor进行判值操作的,即不能插入if...else...之类的代码。第二,相较于numpy array,Tensorflow中对tensor的操作接口灵活性并没有那么高,使得Tensorflow的灵活性减弱。扩展Tensorflow程序的灵活性,有一个重要的手段,就是使用tf.py_func接口。tf.py_func()运算符使您可以在TensorFlow图的中间运行任意Python代码。包装自定义NumPy运算符特别方便,因为没有等效的TensorFlow运算符(尚未存在)。添加tf.py_func()是在图中使用sess.run()调用的替代方法,这样就可以使用任意python函数操作tensor了(一般的py函数不能操作tensor的,只能数值)。

tf.py_func(
    func,
    inp,
    Tout,
    stateful=True,
    name=None
)

        tf.py_func的原理:首先,tf.py_func接收的是tensor,然后将其转化为numpy array送入func函数,最后再将func函数输出的numpy array转化为tensor返回。

参数解析

        func: 一个python函数,它将一个Numpy数组组成的list作为输入,该list中的元素的数据类型和inp参数中的tf.Tensor对象的数据类型相对应,同时该函数返回一个Numpy数组组成的list或者单一的Numpy数组,其数据类型和参数Tout中的值相对应

        inp: Tensor队形组成的list,即使只有一个tensor也需要使用[tensor]

        Tout: 该函数的返回对象的数据类型。一个tensorflow数据类型组成的list或者tuple(如[tf.string, tf.string]),(如果只有一个返回值,需要单独一个tensorflow数据类型,如tf.string,不要写成[tf.string],这样返回时也会多一维)。

        stateful:布尔值,如果该值为True,该函数应被视为与状态有关的。如果一个函数与状态无关,则相同的输入会产生相同的输出,并不会产生明显的副作用。有些优化操作如common subexpression elimination只能在与状态无关的操作中进行。

        注意:

1 func函数的返回值类型一定要和Tout指定的tensor类型一致。

2 The body of the function (i.e. func) will not be serialized in a GraphDef. Therefore, you should not use this function if you need to serialize your model and restore it in a different environment. tf.py_func中的func是脱离Graph的,在func中不能定义可训练的参数参与网络训练(反传),或者说无法求导。

3 如果python_func()函数有 string 参数的话,tf会把这个string参数 转换成 bytes 类型。

函数示例

        (可能可以使用np解决)一个不好的地方是,如果返回多个数据,有n个数据那必须指定Tout = [tf.string]*n。[Returning mutiple values in the input function for `tf.py_func`]如果输入m个数据,但是Tout = [tf.string]*n就会报错:InvalidArgumentError (see above for traceback): pyfunc returns m values, but expects to see n values.出现这种情况一般发生在tf batch训练时,因为最后一个batch_size是不固定且很可能不等于指定的params['batch_size'],而且这里的n不能直接设置成输入inp.shape[0],因为它不固定,也是返回一个None值,而不是实际的batch大小。

pred_strings = tf.py_func(mlb.inverse_transform, [pred_ids], [tf.string] * params['batch_size'])

pred_strings = tf.convert_to_tensor(pred_strings, dtype=tf.string)

        1 一种解决方案是将tf.py_func外加一层tf.map_fn,这样tf.py_func每次都只执行一个数据,Tout = tf.string就可以。

pred_strings = tf.map_fn(lambda x: tf.py_func(mlb.inverse_transform, [tf.expand_dims(x, 0)], tf.string), pred_ids, tf.string)

       2 还有一种是在estimator外层的sess中执行,这样pred_ids就不是tensor,而是数值,可以直接使用python函数操作。

from: -柚子皮-

ref:

 

你可能感兴趣的:(tensorflow)