tf.while_loop使用list参数

tensorflow使用计算图的模型,所以常规for循环在tensorflow其实是不起作用的。

所以tensorflow提供了while_loop函数:

tf.while_loop(
    cond,
    body,
    loop_vars,
    shape_invariants=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    name=None,
    maximum_iterations=None,
    return_same_structure=False
)

具体参数就不一一介绍了,可以通过api文档或者help查询了解。

这里提一下当参数中存在list类型的变量时会产生的问题,

当参数存在list时,很容易会得到一些问题,如:

ValueError: Number of inputs and outputs of body must match loop_vars: 1, 2

这是通常是因为在body里对list进行了一些append一类的增或删操作,导致参数shape不匹配。

下面给出一种解决方案:

import tensorflow as tf

out = tf.Variable([])
i = tf.constant(0)

def cond(i, _):
    return i < 10

def body(i, out):
    i = i + 1
    out = tf.concat([out, [1.0]], 0)
    return [i, out]

_, out = tf.while_loop(cond, body, [i, out], shape_invariants=[i.get_shape(), tf.TensorShape([None])])

sess = tf.Session()
sess.run(tf.global_variables_initializer())
res = sess.run([_, out])
print(res) #  [10, array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)]

这里指定了shape_invariants参数,因为在while_loop中不希望参数的shape发生变化,因此在这里指定好shape,给一个tf.TensorShape([None])即自动推断长度,而不是固定检查,这样就可以解决list的长度在变化的问题了。

 

参考内容:https://stackoverflow.com/questions/41233462/tensorflow-while-loop-dealing-with-lists/41240808#41240808

 

你可能感兴趣的:(python,tensorflow)