理解Theano的Scan函数

1 Scan是干什么的

函数scan是Theano中迭代的一般形式,所以可以用于类似循环(looping)的场景。
如果你熟悉Reduction和map两个函数,这两个都是scan的特殊形式,即将某函数依次作用一个序列的每个元素上。
函数scan的输入也是一些序列(一维数组,或者多维数组,以第一维为leading dimension),将某个函数作用于输入序列上,得到每一步输出的结果。
和Reduction和map两个函数不同之处在于,scan在计算的时候,可以访问以前n步的输出结果,所以比较适合RNN网络。

2 为什么要使用scan

看起来scan完全可以用for… loop来代替,然而scan有其自身的优点:

  • 由于Theano是使用符号代数的,迭代的次数就自然成为符号代数的一部分。也就是说迭代次数也会体现在构造符号代数的图中。
    (Theano用一个图来表示符号代数)

  • 由于上面一条,可以直接用Theano计算梯度。

  • 优化减少CPU和GPU之间的数据传输,比Python Loop稍微快一点。

  • 说不定以后还会有符号代数的其他优点,例如自动优化 y = x/x*x。

3 大概参数说明

函数scan调用的一般形式的一个例子大概是这样:

results, updates = theano.scan(fn = lambda y, p, x_tm2, x_tm1,A: y+p+x_tm2+xtm1+A,
sequences=[Y, P[::-1]], 
outputs_info=[dict(initial=X, taps=[-2, -1])]), 
non_sequences=A)

*参数fn是一个你需要计算的函数,一般用函数定义(比较简单的可以用lambda来定义),参数是有顺序要求的,先是sequences的参数(y,p),然后是output_info的参数(x_tm2,x_tm1),然后是no_sequences的参数(A)。

*sequences就是需要迭代的序列,序列的第一个维度(leading dimension)就是需要迭代的次数。所以,Y和P[::-1]的第一维大小应该相同,如果不同的话,就会取最小的。

*outputs_info描述了需要用到前几次迭代输出的结果,dict(initial=X, taps=[-2, -1])表示使用前一次和前两次输出的结果。如果当前迭代输出为x(t),则计算中使用了x(t-1)和x(t-2)。 所以output_info中的dictionary个数应该和fn的输出个数是对应的。

*non_sequences描述了非序列的输入,即A是一个固定的输入,每次迭代加的A都是相同的。如果Y是一个向量,A就是一个常数。总之,A比Y少一个维度。

4 举例

计算 Ak , 大材小用一下

k = T.iscalar("k")
A = T.vector("A")

# Symbolic description of the result
result, updates = theano.scan(fn=lambda prior_result, A: prior_result * A,
outputs_info=T.ones_like(A),
non_sequences=A,
n_steps=k)

# We only care about A**k, but scan has provided us with A**1 through A**k.
# Discard the values that we don't care about. Scan is smart enough to
# notice this and not waste memory saving them.
final_result = result[-1]
# compiled function that returns A**k
power = theano.function(inputs=[A,k], outputs=final_result, updates=updates)

print power(range(10),2)
print power(range(10),4)

输出:

[  0.   1.   4.   9.  16.  25.  36.  49.  64.  81.]
[  0.00000000e+00   1.00000000e+00   1.60000000e+01   8.10000000e+01
   2.56000000e+02   6.25000000e+02   1.29600000e+03   2.40100000e+03
   4.09600000e+03   6.56100000e+03]

计算 Computing tanh(x(t).dot(W)+b)

X = T.matrix("X")
W = T.matrix("W")
b_sym = T.vector("b_sym")

results, updates = theano.scan(lambda v: T.tanh(T.dot(v, W) + b_sym), sequences=X)
compute_elementwise = theano.function(inputs=[X, W, b_sym], outputs=[results])

# test values
x = np.eye(2, dtype=theano.config.floatX)
w = np.ones((2, 2), dtype=theano.config.floatX)
b = np.ones((2), dtype=theano.config.floatX)
b[1] = 2
print compute_elementwise(x, w, b)[0]
# comparison with numpy
print np.tanh(x.dot(w) + b)

输出:

[[ 0.96402758 0.99505475] [ 0.96402758 0.99505475]]
[[ 0.96402758 0.99505475] [ 0.96402758 0.99505475]]

你可能感兴趣的:(理解Theano的Scan函数)