TensorFlow scan() function

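My previous post covered theano scan(). This one mainly compares the two. Both functions are used to express iteration over tensors, and the official API documentation describes them in detail, so neither is hard to pick up.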

Official documentation

tf.scan(
    fn,
    elems,
    initializer=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    infer_shape=True,
    reverse=False,
    name=None
)

scan on the list of tensors unpacked from elems on dimension 0.

In other words, scan unpacks elems along its first dimension. If elems = [[1, 2], [3, 4]], then fn receives [1, 2] as its element argument on the first iteration and [3, 4] on the second.
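To make "unpacked on dimension 0" concrete, here is a plain-Python sketch that does not use TensorFlow at all: iterating over a nested list already walks its first dimension. The function fn (an element-wise running sum) and the zero starting value are hypothetical, chosen only to show which slice fn receives on each call.

```python
# Plain-Python picture (no TensorFlow) of unpacking elems on dimension 0.
elems = [[1, 2], [3, 4]]
seen = []  # records which slice fn receives on each call

def fn(acc, x):
    seen.append(x)
    return [a + b for a, b in zip(acc, x)]  # hypothetical element-wise running sum

acc = [0, 0]  # plays the role of an explicit initializer
outputs = []
for x in elems:  # iterating a nested list walks its first dimension
    acc = fn(acc, x)
    outputs.append(acc)

print(seen)     # [[1, 2], [3, 4]]
print(outputs)  # [[1, 2], [4, 6]]
```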

The simplest version of scan repeatedly applies the callable fn to a sequence of elements from first to last. The elements are made of the tensors unpacked from elems on dimension 0. The callable fn takes two tensors as arguments. The first argument is the accumulated value computed from the preceding invocation of fn. If initializer is None, elems must contain at least one element, and its first element is used as the initializer.

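Put simply, scan calls fn(previous_result, current_element) once for each element of elems, walking along the first dimension. If initializer=None, elems must contain at least one element, and its first element serves as the initial accumulator.
This is one of the main differences between tf.scan and th.scan: in th.scan the recurrent value is declared through outputs_info, and fn receives the sequence elements before the previous outputs, whereas tf.scan hands the previous result straight back to fn as its first argument.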
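The calling convention can be modelled in a few lines of plain Python. py_scan below is a hypothetical helper, a sketch of the documented behaviour for a single sequence (no nested structures, no TensorFlow), not how TensorFlow implements scan: fn is called as fn(previous_result, current_element), and without an initializer the first element seeds the accumulator and also appears as the first output.

```python
def py_scan(fn, elems, initializer=None, reverse=False):
    """Hypothetical plain-Python model of tf.scan for a single sequence."""
    items = list(elems)
    if reverse:
        items.reverse()  # walk from last to first
    if initializer is None:
        if not items:
            raise ValueError("elems must contain at least one element")
        acc, rest = items[0], items[1:]
        outputs = [acc]  # the seed element is also the first output
    else:
        acc, rest, outputs = initializer, items, []
    for x in rest:
        acc = fn(acc, x)  # previous result first, current element second
        outputs.append(acc)
    if reverse:
        outputs.reverse()  # results line up with the original positions
    return outputs

add = lambda a, x: a + x
print(py_scan(add, [1, 2, 3, 4, 5, 6]))                   # [1, 3, 6, 10, 15, 21]
print(py_scan(add, [1, 2, 3, 4, 5, 6], reverse=True))     # [21, 20, 18, 15, 11, 6]
print(py_scan(add, [1, 2, 3, 4, 5, 6], initializer=100))  # [101, 103, 106, 110, 115, 121]
```

Note that with initializer=None nothing is computed for the first position: elems[0] is copied to the output unchanged, while an explicit initializer never appears in the output itself.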

Suppose that elems is unpacked into values, a list of tensors. The shape of the result tensor is [len(values)] + fn(initializer, values[0]).shape. If reverse=True, it’s fn(initializer, values[-1]).shape.

That is, scan stacks the results of fn along a new leading axis of length len(values); each step's shape is that of fn(initializer, values[0]), or of fn(initializer, values[-1]) when reverse=True.
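The shape rule can be checked with nested Python lists standing in for tensors. Both shape() and the two-component fn are hypothetical, and values plays the role of elems after unpacking along dimension 0.

```python
def shape(x):
    """Shape of a rectangular nested list; a scalar has shape []."""
    return [len(x)] + shape(x[0]) if isinstance(x, list) else []

def fn(acc, x):
    return [acc[0] + x, acc[1] * x]  # each step yields a length-2 vector

values = [1, 2, 3, 4]  # elems already unpacked along dimension 0
initializer = [0, 1]

acc, result = initializer, []
for x in values:
    acc = fn(acc, x)
    result.append(acc)

expected = [len(values)] + shape(fn(initializer, values[0]))
print(shape(result), expected)  # [4, 2] [4, 2]
```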

The first example makes this clear:

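import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

elems = np.array([1, 2, 3, 4, 5, 6])
# No initializer is given, so the first element of elems (1) becomes the
# initial value of the accumulator a, i.e. fn's "0th" output.
running_sum = tf.scan(lambda a, x: a + x, elems)
# running_sum == [1, 3, 6, 10, 15, 21]
running_sum = tf.scan(lambda a, x: a + x, elems, reverse=True)
# running_sum == [21, 20, 18, 15, 11, 6]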
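For a one-dimensional elems and no initializer, the forward case lines up with Python's itertools.accumulate, whose function is also called as func(accumulated, element). The reverse=True case can be sketched as the same accumulation run from the back, with the results put back in their original positions. This is an analogy in plain Python, not TensorFlow code.

```python
from itertools import accumulate
import operator

elems = [1, 2, 3, 4, 5, 6]

forward = list(accumulate(elems, operator.add))
backward = list(accumulate(reversed(elems), operator.add))[::-1]
# With an initial value, accumulate (Python 3.8+) emits that value first,
# whereas tf.scan does not, so drop it to match scan's output.
with_init = list(accumulate(elems, operator.add, initial=0))[1:]

print(forward)    # [1, 3, 6, 10, 15, 21]
print(backward)   # [21, 20, 18, 15, 11, 6]
print(with_init)  # [1, 3, 6, 10, 15, 21]
```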

This method also allows multi-arity elems and accumulator. If elems is a (possibly nested) list or tuple of tensors, then each of these tensors must have a matching first (unpack) dimension. The second argument of fn must match the structure of elems.

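In other words, every tensor in elems must have the same size along its first dimension, because they are unpacked in lockstep, and the element argument of fn arrives with the same (possibly nested) structure as elems.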
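A plain-Python sketch of the multi-arity rule, assuming elems is a flat tuple of equal-length sequences; unpack_dim0 is a hypothetical helper, not a TensorFlow function.

```python
def unpack_dim0(elems):
    """Hypothetical helper: unpack a flat tuple of sequences in lockstep.

    Mirrors the requirement that every tensor in elems has the same first
    dimension; nested structures are not handled in this sketch.
    """
    lengths = {len(t) for t in elems}
    if len(lengths) != 1:
        raise ValueError(f"first dimensions differ: {sorted(lengths)}")
    n = lengths.pop()
    return [tuple(t[i] for t in elems) for i in range(n)]

slices = unpack_dim0(([1, 2, 3], [10, 20, 30]))
print(slices)  # [(1, 10), (2, 20), (3, 30)]; fn's element argument is one of these
```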

If no initializer is provided, the output structure and dtypes of fn are assumed to be the same as its input; and in this case, the first argument of fn must match the structure of elems.

So without an initializer, the accumulator is seeded from the first slice of elems, and fn must both accept and return values with that same structure and dtypes.
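A hypothetical plain-Python model of the no-initializer case with a tuple elems: the first slice of elems becomes the accumulator, so fn receives and must return tuples shaped like one slice. The final regrouping into one list per component is an assumption meant to mirror how scan stacks each output tensor separately.

```python
def py_scan_no_init(fn, elems):
    """Hypothetical model of tf.scan(fn, elems) for a flat tuple elems, no initializer."""
    steps = list(zip(*elems))  # per-step tuples (a[i], b[i]); assumes equal lengths
    acc = steps[0]             # the first slice seeds the accumulator...
    outputs = [acc]            # ...and is also the first output
    for x in steps[1:]:
        acc = fn(acc, x)       # acc and x both look like one slice of elems
        outputs.append(acc)
    return tuple(list(col) for col in zip(*outputs))  # regroup per component

sums = py_scan_no_init(lambda acc, x: (acc[0] + x[0], acc[1] + x[1]),
                       ([1, 2, 3], [10, 20, 30]))
print(sums)  # ([1, 3, 6], [10, 30, 60])
```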

If an initializer is provided, then the output of fn must have the same structure as initializer; and the first argument of fn must match this structure.

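So with an initializer, the initializer (not elems) dictates the structure of the accumulator and of fn's return value. This is the second difference from th.scan: initializer plays roughly the role of th.scan's outputs_info, and since tf.scan has no non_sequences argument, constant inputs are normally captured by fn as a closure instead.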
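With an initializer, fn is free to carry a state whose structure differs from elems. In this plain-Python sketch (py_scan, step and scale are all hypothetical), elems is a flat list while the accumulator is a (total, count) pair. Because the signature above has no non_sequences parameter, a constant that th.scan would receive through non_sequences is simply captured by fn as a closure here.

```python
def py_scan(fn, elems, initializer):
    """Hypothetical plain-Python model of tf.scan with an explicit initializer."""
    acc, outputs = initializer, []
    for x in elems:
        acc = fn(acc, x)
        outputs.append(acc)
    # Regroup per-step tuples into one list per component of the initializer,
    # mirroring how scan stacks each output separately.
    return tuple(list(col) for col in zip(*outputs))

scale = 10  # a constant th.scan would pass via non_sequences; here fn closes over it

def step(acc, x):
    total, count = acc                      # same structure as the initializer
    return (total + scale * x, count + 1)   # must return that structure too

totals, counts = py_scan(step, [1, 2, 3], initializer=(0, 0))
print(totals, counts)  # [10, 30, 60] [1, 2, 3]
```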

For example, if elems is (t1, [t2, t3]) and initializer is [i1, i2] then an appropriate signature for fn in python2 is: fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]): and fn must return a list, [acc_n1, acc_n2]. An alternative correct signature for fn, and the one that works in python3, is: fn = lambda a, t:, where a and t correspond to the input tuples.

Python 3 removed tuple parameters from function signatures (PEP 3113), so the lambda a, t: form is the one to use there: unpack a and t inside fn's body instead.
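A sketch of the Python 3 form for the (t1, [t2, t3]) / [i1, i2] example, assuming plain Python numbers in place of tensors; the arithmetic inside fn is hypothetical and only shows how a and t unpack.

```python
# Python 3 has no tuple parameters, so fn takes two plain arguments
# and unpacks them in its body.
def fn(a, t):
    acc_p1, acc_p2 = a   # same structure as initializer [i1, i2]
    t1, (t2, t3) = t     # same structure as one slice of elems (t1, [t2, t3])
    return [acc_p1 + t1, acc_p2 + t2 * t3]  # must return [acc_n1, acc_n2]

# Equivalent one-liner in the lambda a, t: style
fn_lambda = lambda a, t: [a[0] + t[0], a[1] + t[1][0] * t[1][1]]

# Drive fn by hand over two hypothetical slices of elems:
acc = [0, 0]
for t in [(1, [2, 3]), (4, [5, 6])]:
    acc = fn(acc, t)
print(acc)  # [5, 36]
```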

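Multiple inputs
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

elems = np.array([1, 2, 3, 4, 5, 6])
initializer = np.array(0)
# elems is a tuple of two tensors, so x is a tuple (x[0], x[1]) at each step
sum_one = tf.scan(
    lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
# sum_one == [1, 2, 3, 4, 5, 6]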

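Multiple outputs

import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

elems = np.array([1, 0, 0, 0, 0, 0])
initializer = (np.array(0), np.array(1))
# The accumulator a is a pair, so scan returns a pair of stacked tensors
fibonaccis = tf.scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
# fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])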
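In the Fibonacci example fn ignores its second argument, so elems only determines the number of iterations (six). A plain-Python equivalent, assuming the same initializer (0, 1), reproduces both output sequences:

```python
# Plain-Python equivalent of the Fibonacci scan; only the step count matters.
a = (0, 1)  # initializer
firsts, seconds = [], []
for _ in range(6):  # len(elems) == 6
    a = (a[1], a[0] + a[1])
    firsts.append(a[0])
    seconds.append(a[1])
print(firsts)   # [1, 1, 2, 3, 5, 8]
print(seconds)  # [1, 2, 3, 5, 8, 13]
```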
