tf.map_fn()函数

tf.map_fn()函数定义如下:

tf.map_fn(

    fn,

    elems,

    dtype=None,

    parallel_iterations=10,

    back_prop=True,

    swap_memory=False,

    infer_shape=True,

    name=None

)

把函数当参数传进去,可以直接用lambda。将参数elems从第一维展开,进行map处理。一个简单的例子:

def fun1(a):

    if a.shape == (1,2):

        print("ok")

    else:

        raise Exception("shape error")

    return a * 2

var1 =  np.random.randint(10, size=(2,1,2))

var2 =  np.random.randint(10, size=(1,2))

print(var1)

# fun1(var1) 执行错误 shape error

#fun1(var2) 可以执行

# 执行

rtn = tf.map_fn(fun1, var1)

# 结果打印

with tf.Session() as sess:

    result = sess.run(rtn)

    print(result)

最后的执行结果:

[[[3 3]]

[[5 0]]]

ok

[[[ 6  6]]

[[10  0]]]

你可能感兴趣的:(tf.map_fn()函数)