本文假设读者已经熟悉了常规的神经网络,并且了解了BP算法,如果还不了解的,参见UFIDL的教程。
- 1.RNN结构
- 2.符号定义
- 3.网络unrolled及公式推导
- 4.BPTT
- 5.RTRL
- 6.Hybrid(FP/BPTT)
- 7.参考文献
如下图1是一个最简单的RNN:
其中集合 I 为 m 个外部输入节点,集合 U 为前一时刻的隐层输入节点,U中的节点数为 n ,并假定U中所有节点的输出都参与到下一时刻的输入。
定义:
xi(t) : t 时刻第 i 个输入节点的输出值,且 i∈I∪U
sk(t) : t 时刻第 k 个隐层节点的输出值,且 k∈U
yk(t) : t 时刻第 k 个隐层节点的输出值,且 k∈U
dk(t) : t 时刻隐层第 k 个节点的期望输出(即训练数据)
wli :第 i 个输入到第 l 个隐层节点的权重,其中 i∈I,l∈U
wlk :第 k 个输入到第 l 个隐层节点的权重,其中 k,l∈U
τ :假定网络的起始时刻为 t0 ,当前时刻为 t , t′∈[t0,t) , τ∈(t′,t]
y∗k(τ) : τ 时刻第 k 个输出节点的输出值,且 k∈U,且τ∈(t0,t] ,对于所有的 τ 而言,其实有 yk(τ)=y∗k(τ) ,这里之所以引入新符号,是为了避免求导运算时混淆1。
再来是一组等式定义:
sk(τ+1)=wx(τ)
ek(t)=dk(t)−yk(t)
J(τ)=∑k∈Uek(t)
Jtotal(t′,t)=∑τ=t′+1tJ(τ),t′∈[t0,t)
ϵk(τ;F)=∂F∂yk(τ)
ek(τ;F)=∂F∂y∗k(τ)
δk(τ;F)=∂F∂sk(τ)
pkij(τ)=∂yk(τ)∂wij
因为假定 F 只与 yk(τ),τ∈(t′,t] 显式相关,所以,当 τ≤t′ 时, ek(τ;F)=0 。
由于 F 是任意与 yk(t) 相关的函数,实际应用中,可以取
F=J(τ);F=Jtotal(t′,t) 或其它函数。
因为初始状态的输出 yk(t0) 为预设值,与 w 之间不存在函数关系,所以当 τ=t0 时, pkij(t0)=0 。
将网络按时间展开:
根据上图,下面两个式子成立:
sk(t+1)=∑l∈Uwklyl(t)+∑l∈Iwklxnetl(t)=∑l∈U∪Iwklxl(t)......(2)
yk(t+1)=fk(sk(t+1))......(3)
显然, y∗k(τ+1),y∗k(τ+2),...,y∗k(t) 可以表示成 s(τ+1) 的函数,因此,
F=F(y∗(t′),y∗(t′+1),...,yk(τ),s(τ+1))=F
下面对公式进行进一步的推导:
ϵk(τ;F)=∂F∂yk(τ)
=∂F(y∗(t′),y∗(t′+1),...,yk(τ),s(τ+1))∂yk(τ)
由复合函数求导法则,上式可进一步变为:
∂F∂y(t′)∂y(t′)∂yk(τ)+∂F∂y(t′+1)∂y(t′+1)∂yk(τ)+...+∂F∂y∗(τ)∂y∗(τ)∂yk(τ)+∂F∂s(τ+1)∂s(τ+1)∂yk(τ)
当 τ′<τ 时,显然 y(τ′) 与 y(τ) 无关,故上式的前半部分为0,即:
ϵk(τ;F)=∂F∂y∗(τ)∂y∗(τ)∂yk(τ)+∂F∂s(τ+1)∂s(τ+1)∂yk(τ)
这里:
∂F∂y∗(τ)=⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢∂F∂y∗1(τ)∂F∂y∗2(τ)...∂F∂y∗k(τ)...∂F∂y∗n(τ)⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥
∂y∗(τ)∂yk(τ)=⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢∂y∗1(τ)∂y∗k(τ)∂y∗2(τ)∂y∗k(τ)...∂y∗k(τ)∂y∗k(τ)...∂y∗n(τ)∂y∗k(τ)⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢00...1...0⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥
∂F∂s(τ+1)=⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢∂F∂s1(τ+1)∂F∂s2(τ+1)...∂F∂sl(τ+1)...∂F∂sn(τ+1)⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢δ1(τ+1;F)δ2(τ+1;F)...δl(τ+1;F)...δn(τ+1;F)⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥
∂s(τ+1)∂yk(τ)=⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢∂s∗1(τ+1)∂y∗k(τ)∂s∗2(τ+1)∂y∗k(τ)...∂s∗l(τ+1)∂y∗k(τ)...∂s∗n(τ+1)∂y∗k(τ)⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢w1kw2k...wlk...wnk⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥
代入,上式可以变为:
ϵk(τ;F)=⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢∂F∂y∗1(τ)∂F∂y∗2(τ)...∂F∂y∗k(τ)...∂F∂y∗n(τ)⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥T⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢00...1...0⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥+⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢δ1(τ+1;F)δ2(τ+1;F)...δl(τ+1;F)...δn(τ+1;F)⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥T⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢w1kw2k...wlk...wnk⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥=∂F∂y∗k(τ)+∑l∈Uwlkδl(τ+1;F)
所以就有:
ϵk(τ;F)=∂F∂y∗k(τ)+∑l∈Uwlkδl(τ+1;F)=ek(τ;F)+∑l∈Uwlkδl(τ+1;F)
因为当 τ=t 时, ϵk(t;F)=ek(t;F) ,所以有:
δk(τ;F)=∂F∂sk(τ)=∂F∂yk(τ)∂yk(τ)∂sk(τ)=ϵk(τ;F)f′k(sk(τ))
进一步推导:
ϵk(τ;F)=(ek(τ;F)+∑l∈Uwlkδl(τ+1;F))f′k(sk(τ))
先做如下定义:
wij :第 j 个输入到第 i 个隐层节点的权重(迭代更新之前),其中 i∈U,j∈U∪I
wij(τ) : τ 时刻第 j 个输入到第 i 个隐层节点的权重(迭代更新之前),其中 τ∈[t0,t),i∈U,j∈U∪I
∂F∂wij(τ)=∂F∂si(τ+1)∂si(τ+1)∂wij(τ)=δi(τ+1;F)xj(τ)
∂F∂wij=∑τ=t0t−1∂F∂wij(τ)∂wij(τ)∂wij=∑τ=t0t−1∂F∂wij(τ)=∑τ=t0t−1δi(τ+1;F)xj(τ)
算法描述:
令 τ∈(t0,t],k∈U ,
ϵk(t)=ek(t),
δk(τ)=f′k(sk(τ))ϵk(τ),
ϵk(τ−1)=∑l∈Uwlkδl(τ),
可以看出,算法的公式与BP算法非常相似,算法从t时刻开始,先用等式 ϵk(t)=ek(t) 求出 ϵk(t) ,然后再用后边两个等式继续向后迭代,直到 t0 。这里的第一步也被称为错误注入(injecting error),也说是在t时刻注入了 ek(t) 。
上图描述了Real-Time BPTT算法在每一个时刻t的存储和处理操作。历史缓存每经过一个时刻t,就会增加一层的数据(包括该t时刻所有的输入和输出值)。实线箭头表明了当前的输出值由和上一时刻的输入输出值确定。虚线表示反向传播,计算直到 t0+1 的 δ 。步骤①为injecting error操作,剩下的步骤为每一步的误差计算。
激活函数通常取logistics函数,此时的 f′k(sk(τ))=fk(sk(τ))(1−fk(sk(τ)))
最后,权值的梯度通过下式计算:
∂J(t)∂wij=∑τ=t0+1tδi(τ)xj(τ−1)
在每一个时刻t,算法的执行流程如下:
(1)将当前网络的状态和当前的输入值添加到历史缓存2;
(2)注入当前时刻 t 的 ek(t) ,然后在时间区间 (t0,t] 上进行反向传播,计算出所有的 ϵk(τ),δk(τ) ;
(3)计算所有的 ∂J(t)∂wij ;
(4)根据第(3)步的结果修改权值。
随着时间的增长,算法对历史缓存的需求将是无限的,因此,有时也用BPTT(∞)来表示这个算法,它在理论上的研究价值要远大于实用。接下来,我们将讨论更为实用的近似算法。
为了解决Real-Time BPTT对内存的无限制需求,我们采用一种近似的算法,即:Epochwise BPTT。
算法的目标是计算基于 Jtotal(t0,t1) 的梯度(即损失函数 F=Jtotal(t0,t1) ),其步骤跟前边类似。同样的,
令 τ∈(t0,t1],k∈U ,
ϵk(t1)=ek(t1),
δk(τ)=f′k(sk(τ))ϵk(τ),
ϵk(τ−1)=ek(τ−1)+∑l∈Uwlkδl(τ),
算法从最后的时刻 t1 开始,injecting error ek(t1) ,然后运用后边两个等式,迭代计算 δk(τ),ϵk(τ−1) ,直到 τ=t0+1 。此时权值的梯度按下式计算:
∂Jtotal(t0,t1)∂wij=∑τ=t0+1t1δi(τ)xj(τ−1)
对 [t0,t1] 中所有的输入输出以及目标值都被存储在历史缓存中。实线表示输出由上一时刻的输入和输出确定,当一次epoch完成后,执行BP操作(虚线箭头)。奇数索引的步骤表示error injection,偶数索引的步骤表示误差( δ )传播。一旦BP操作完成,每个权值的梯度就可以算出来了。
算法的执行流程如下:
(1)执行BP算法,计算所有的 ϵk(τ),δk(τ),τ∈(t0,t1] ;
(2)计算所有的 ∂Jtotal(t0,t1)∂wij ;
(3)使用(2)的结果更新权值,重复步骤(1)~(3);
与反向传播的BPTT算法不同的是,RTRL通过前向传播梯度来进行计算。
对任意的 k∈U,i∈U,j∈U∪I,以及t∈[t0,t1] ,定义:
pkij(t)=∂yk(t)∂wij
令 F=J(t) ,有:
∂J(t)∂wij=∑k∈Uek(t)pkij(t)
根据之前的关系等式:
sk(t+1)=∑l∈Uwklyl(t)+∑l∈Iwklxnetl(t)=∑l∈U∪Iwklxl(t)......(2)
yk(t+1)=fk(sk(t+1))......(3)
可以推出:
pkij(t+1)=∂yk(t+1)∂wij=∂yk(t+1)∂sk(t+1)∂sk(t+1)∂wij=f′k(sk(t+1))[∑l∈Uwklplij(t)+δikxj(t)] 3
此外, t0 时刻的输出为预设值,与连接权值无关,所以有:
pkij(t0)=∂yk(t0)∂wij=0
于是,整个计算过程将从 t=t0 开始迭代计算,直到 t=t1 。
对每一个时刻 t ,计算相应的 yk(t) 以及 ∂J(t)∂wij
∂F∂wij=∑τ=t0t′−1∂F∂wij(τ)+∑τ=t′t−1∂F∂wij(τ)
等式右边的第一部分可写为:
∑τ=t0t′−1∂F∂wij(τ)=∑τ=t0t′−1∑l∈U∂F∂yl(t′)∂yl(t′)∂wij(τ)=∑l∈U∂F∂yl(t′)∑τ=t0t′−1∂yl(t′)∂wij(τ)=∑l∈U∂F∂yl(t′)∂yl(t′)∂wij=∑l∈Uϵl(t′;F)plij(t′)
因此,最初的式子可变为:
∂F∂wij=∑l∈Uϵl(t′;F)plij(t′)+∑τ=t′t−1δi(τ+1;F)xj(τ)
令 F=Jtotal(t′,t)
∂Jtotal(t′,t)∂wij=∑l∈Uϵl(t′)plij(t′)+∑τ=t′t−1δi(τ+1)xj(τ)
首先计算BPTT:
然后,使用上边的计算结果执行:
prij(t)=∑l∈Uϵl(t′)plij(t′)+∑τ=t′t−1δl(τ+1)xj(τ)
上图是FP/BPTT(h)算法的简单描述。可以看到,算法包含两个连续的误差计算过程。一个在时刻 t ,另一个在时刻 t+h .从时刻 t−h 直到时刻 t 的输入、输出和目标值都存储在历史缓存中。
1.Gradient-Based Learning Algorithms for Recurrent Networks and Their Computational Complexity.Ronald J. Williams,David Zipser