[tf.map_fn]:map on the list of tensors unpacked from elems on dimension 0. 接受一个函数对象,然后用该函数对象对集合(elems)中的每一个元素分别处理,
tf.map_fn(
fn,
elems,
dtype=None,
parallel_iterations=None,
back_prop=True,
swap_memory=False,
infer_shape=True,
name=None
)
参数解析
dtype: (optional) The output type(s) of fn. If fn returns a structure of Tensors differing from the structure of elems, then dtype is not optional and must have the same structure as the output of fn. 如果输入输出类型不一样,这个一定要加上,否则输出类型自动和输入一样,报错。
-柚子皮-
[tf.py_func]: Wraps a python function and uses it as a TensorFlow op. 用来将 一个 python 函数打包成一个 op。
Note: 杯具的是,如果在estimator,如predict阶段使用py_func,它是不能在后续server中使用的。即export_saved_model后,在predictor.from_saved_model时会出错,说ValueError: callback pyfunc_0 is not found,即estimator不会将py_func拉进静态图中?这时只能放弃py_func,类似下面示例中的解决方案2了。
Tensorflow还是有不足的地方。第一体现在Tensorflow的数据机制,由于tensor只是占位符,在没有用tf.Session().run接口填充值之前是没有实际值的。因此,在网络搭建的时候,是不能对tensor进行判值操作的,即不能插入if...else...之类的代码。第二,相较于numpy array,Tensorflow中对tensor的操作接口灵活性并没有那么高,使得Tensorflow的灵活性减弱。扩展Tensorflow程序的灵活性,有一个重要的手段,就是使用tf.py_func接口。tf.py_func()运算符使您可以在TensorFlow图的中间运行任意Python代码。包装自定义NumPy运算符特别方便,因为没有等效的TensorFlow运算符(尚未存在)。添加tf.py_func()是在图中使用sess.run()调用的替代方法,这样就可以使用任意python函数操作tensor了(一般的py函数不能操作tensor的,只能数值)。
tf.py_func(
func,
inp,
Tout,
stateful=True,
name=None
)
tf.py_func的原理:首先,tf.py_func接收的是tensor,然后将其转化为numpy array送入func函数,最后再将func函数输出的numpy array转化为tensor返回。
func: 一个python函数,它将一个Numpy数组组成的list作为输入,该list中的元素的数据类型和inp参数中的tf.Tensor对象的数据类型相对应,同时该函数返回一个Numpy数组组成的list或者单一的Numpy数组,其数据类型和参数Tout中的值相对应
inp: Tensor队形组成的list,即使只有一个tensor也需要使用[tensor]。
Tout: 该函数的返回对象的数据类型。一个tensorflow数据类型组成的list或者tuple(如[tf.string, tf.string]),(如果只有一个返回值,需要单独一个tensorflow数据类型,如tf.string,不要写成[tf.string],这样返回时也会多一维)。
stateful:布尔值,如果该值为True,该函数应被视为与状态有关的。如果一个函数与状态无关,则相同的输入会产生相同的输出,并不会产生明显的副作用。有些优化操作如common subexpression elimination只能在与状态无关的操作中进行。
注意:
1 func函数的返回值类型一定要和Tout指定的tensor类型一致。
2 The body of the function (i.e. func) will not be serialized in a GraphDef. Therefore, you should not use this function if you need to serialize your model and restore it in a different environment. tf.py_func中的func是脱离Graph的,在func中不能定义可训练的参数参与网络训练(反传),或者说无法求导。
3 如果python_func()函数有 string 参数的话,tf会把这个string参数 转换成 bytes 类型。
(可能可以使用np解决)一个不好的地方是,如果返回多个数据,有n个数据那必须指定Tout = [tf.string]*n。[Returning mutiple values in the input function for `tf.py_func`]如果输入m个数据,但是Tout = [tf.string]*n就会报错:InvalidArgumentError (see above for traceback): pyfunc returns m values, but expects to see n values.出现这种情况一般发生在tf batch训练时,因为最后一个batch_size是不固定且很可能不等于指定的params['batch_size'],而且这里的n不能直接设置成输入inp.shape[0],因为它不固定,也是返回一个None值,而不是实际的batch大小。
pred_strings = tf.py_func(mlb.inverse_transform, [pred_ids], [tf.string] * params['batch_size'])
pred_strings = tf.convert_to_tensor(pred_strings, dtype=tf.string)
1 一种解决方案是将tf.py_func外加一层tf.map_fn,这样tf.py_func每次都只执行一个数据,Tout = tf.string就可以。
pred_strings = tf.map_fn(lambda x: tf.py_func(mlb.inverse_transform, [tf.expand_dims(x, 0)], tf.string), pred_ids, tf.string)
2 还有一种是在estimator外层的sess中执行,这样pred_ids就不是tensor,而是数值,可以直接使用python函数操作。
from: -柚子皮-
ref: