一般处理单个的输入,前一个输入和后一个输入完全无关,但实际应用中,某些任务需要能够更好的处理序列的信息,即前面的输入和后面的输入是有关系的。比如:时间序列问题
在进一步了解RNN之前,先给出最基本的单层网络结构,输入是 x x x,经过变换 W x + b Wx+b Wx+b和激活函数 f f f得到输出 y y y:
RNN在单层网络结构的基础上引入了隐藏层 h h h, h h h可对序列数据提取特征,接着再转换为输出。
注:图中的圆圈表示向量,箭头表示对向量做变换。
RNN中,每个步骤权值共享,使用的参数 U , W , b U,W,b U,W,b相同, h 2 h_2 h2的计算方式和 h 1 h_1 h1类似,其计算结果如下:
接下来,计算RNN的输出 y 1 y_1 y1,采用 S o f t m a x Softmax Softmax作为激活函数,根据 y n = f ( W x + b ) y_n=f(Wx+b) yn=f(Wx+b),得 y 1 y_1 y1:
使用和 y 1 y_1 y1相同的参数 V , c V,c V,c,得到 y 1 , y 2 , y 3 , y 4 y_1,y_2,y_3,y_4 y1,y2,y3,y4的输出结构:
类别 | 特点描述 |
---|---|
相同点 | 1、传统神经网络的扩展。 2、前向计算产生结果,反向计算模型更新。 3、每层神经网络横向可以多个神经元共存,纵向可以有多层神经网络连接。 |
不同点 | 1、CNN空间扩展,神经元与特征卷积;RNN时间扩展,神经元与多个时间输出计算 2、RNN可以用于描述时间上连续状态的输出,有记忆功能,CNN用于静态输出 |
相同点:
不同点:
由于RNN特有的memory会影响后期其他的RNN的特点,梯度时大时小,learning rate没法个性化的调整,导致RNN在train的过程中,Loss会震荡起伏。为了解决RNN的这个问题,在训练的时候,可以设置临界值,当梯度大于某个临界值,直接截断,用这个临界值作为梯度的大小,防止大幅震荡。
以 x x x表示输入, h h h是隐层单元, o o o是输出, L L L为损失函数, y y y为训练集标签。 t t t表示 t t t时刻的状态, V , U , W V,U,W V,U,W是权值,同一类型的连接权值相同。以下图为例进行说明标准RNN的前向传播算法:
对于 t t t时刻:
h ( t ) = ϕ ( U x ( t ) + W h ( t − 1 ) + b ) h^{(t)}=\phi(Ux^{(t)}+Wh^{(t-1)}+b) h(t)=ϕ(Ux(t)+Wh(t−1)+b)
其中 ϕ ( ) \phi() ϕ()为激活函数,一般会选择tanh函数, b b b为偏置。
t t t时刻的输出为:
o ( t ) = V h ( t ) + c o^{(t)}=Vh^{(t)}+c o(t)=Vh(t)+c
模型的预测输出为:
y ^ ( t ) = σ ( o ( t ) ) \widehat{y}^{(t)}=\sigma(o^{(t)}) y (t)=σ(o(t))
其中 σ \sigma σ为激活函数,通常RNN用于分类,故这里一般用softmax函数。
原因
如何解决
LSTM 拥有三个门,分别是遗忘门,输入门和输出门,来保护和控制细胞状态。
忘记门
作用对象:细胞状态 。
作用:将细胞状态中的信息选择性的遗忘。
操作步骤:该门会读取 h t − 1 h_{t-1} ht−1和 x t x_t xt,输出一个在 0 到 1 之间的数值给每个在细胞状态 C t − 1 C_{t-1} Ct−1中的数字。1 表示“完全保留”,0 表示“完全舍弃”。
输入门
作用对象:细胞状态
作用:将新的信息选择性的记录到细胞状态中。
操作步骤:
sigmoid 层称 “输入门” 决定什么值我们将要更新。
tanh 层创建一个新的候选值向量 C ~ t \tilde{C}_t C~t加入到状态中。
输出层门
作用对象:隐层 h t h_t ht
作用:确定输出什么值
操作步骤:
通过sigmoid 层来确定细胞状态的哪个部分将输出。
把细胞状态通过 tanh 进行处理,并将它和 sigmoid 门的输出相乘,最终我们仅仅会输出我们确定输出的那部分。
LSTMs与GRUs的区别如图所示:
从上图可以看出,二者结构十分相似,不同在于:
new memory都是根据之前state及input进行计算,但是GRUs中有一个reset gate控制之前state的进入量,而在LSTMs里没有类似gate;
产生新的state的方式不同,LSTMs有两个不同的gate,分别是forget gate (f gate)和input gate(i gate),而GRUs只有一种update gate(z gate);
LSTMs对新产生的state可以通过output gate(o gate)进行调节,而GRUs对输出无任何调节。
BPTT(back-propagation through time)算法是常用的训练RNN的方法,其本质还是BP算法,只不过RNN处理时间序列数据,所以要基于时间反向传播,故叫随时间反向传播。BPTT的中心思想和BP算法相同,沿着需要优化的参数的负梯度方向不断寻找更优的点直至收敛。需要寻优的参数有三个,分别是 U 、 V 、 W U、V、W U、V、W。与BP算法不同的是,其中 W W W和 U U U两个参数的寻优过程需要追溯之前的历史数据,参数 V V V相对简单只需关注目前,先求解参数V的偏导数。
∂ L ( t ) ∂ V = ∂ L ( t ) ∂ o ( t ) ⋅ ∂ o ( t ) ∂ V \frac{\partial L^{(t)}}{\partial V}=\frac{\partial L^{(t)}}{\partial o^{(t)}}\cdot \frac{\partial o^{(t)}}{\partial V} ∂V∂L(t)=∂o(t)∂L(t)⋅∂V∂o(t)
RNN的损失也是会随着时间累加的,所以不能只求t时刻的偏导。
L = ∑ t = 1 n L ( t ) L=\sum_{t=1}^{n}L^{(t)} L=t=1∑nL(t)
∂ L ∂ V = ∑ t = 1 n ∂ L ( t ) ∂ o ( t ) ⋅ ∂ o ( t ) ∂ V \frac{\partial L}{\partial V}=\sum_{t=1}^{n}\frac{\partial L^{(t)}}{\partial o^{(t)}}\cdot \frac{\partial o^{(t)}}{\partial V} ∂V∂L=t=1∑n∂o(t)∂L(t)⋅∂V∂o(t)
W W W和 U U U的偏导的求解由于需要涉及到历史数据,其偏导求起来相对复杂。为了简化推导过程,假设只有三个时刻,那么在第三个时刻 L L L对 W W W, L L L对 U U U的偏导数分别为:
∂ L ( 3 ) ∂ W = ∂ L ( 3 ) ∂ o ( 3 ) ∂ o ( 3 ) ∂ h ( 3 ) ∂ h ( 3 ) ∂ W + ∂ L ( 3 ) ∂ o ( 3 ) ∂ o ( 3 ) ∂ h ( 3 ) ∂ h ( 3 ) ∂ h ( 2 ) ∂ h ( 2 ) ∂ W + ∂ L ( 3 ) ∂ o ( 3 ) ∂ o ( 3 ) ∂ h ( 3 ) ∂ h ( 3 ) ∂ h ( 2 ) ∂ h ( 2 ) ∂ h ( 1 ) ∂ h ( 1 ) ∂ W \frac{\partial L^{(3)}}{\partial W}=\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial W}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial W}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial h^{(1)}}\frac{\partial h^{(1)}}{\partial W} ∂W∂L(3)=∂o(3)∂L(3)∂h(3)∂o(3)∂W∂h(3)+∂o(3)∂L(3)∂h(3)∂o(3)∂h(2)∂h(3)∂W∂h(2)+∂o(3)∂L(3)∂h(3)∂o(3)∂h(2)∂h(3)∂h(1)∂h(2)∂W∂h(1)
∂ L ( 3 ) ∂ U = ∂ L ( 3 ) ∂ o ( 3 ) ∂ o ( 3 ) ∂ h ( 3 ) ∂ h ( 3 ) ∂ U + ∂ L ( 3 ) ∂ o ( 3 ) ∂ o ( 3 ) ∂ h ( 3 ) ∂ h ( 3 ) ∂ h ( 2 ) ∂ h ( 2 ) ∂ U + ∂ L ( 3 ) ∂ o ( 3 ) ∂ o ( 3 ) ∂ h ( 3 ) ∂ h ( 3 ) ∂ h ( 2 ) ∂ h ( 2 ) ∂ h ( 1 ) ∂ h ( 1 ) ∂ U \frac{\partial L^{(3)}}{\partial U}=\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial U}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial U}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial h^{(1)}}\frac{\partial h^{(1)}}{\partial U} ∂U∂L(3)=∂o(3)∂L(3)∂h(3)∂o(3)∂U∂h(3)+∂o(3)∂L(3)∂h(3)∂o(3)∂h(2)∂h(3)∂U∂h(2)+∂o(3)∂L(3)∂h(3)∂o(3)∂h(2)∂h(3)∂h(1)∂h(2)∂U∂h(1)
可以观察到,在某个时刻的对 W W W或是 U U U的偏导数,需要追溯这个时刻之前所有时刻的信息。根据上面两个式子得出L在 t t t时刻对 W W W和 U U U偏导数的通式:
∂ L ( t ) ∂ W = ∑ k = 0 t ∂ L ( t ) ∂ o ( t ) ∂ o ( t ) ∂ h ( t ) ( ∏ j = k + 1 t ∂ h ( j ) ∂ h ( j − 1 ) ) ∂ h ( k ) ∂ W \frac{\partial L^{(t)}}{\partial W}=\sum_{k=0}^{t}\frac{\partial L^{(t)}}{\partial o^{(t)}}\frac{\partial o^{(t)}}{\partial h^{(t)}}(\prod_{j=k+1}^{t}\frac{\partial h^{(j)}}{\partial h^{(j-1)}})\frac{\partial h^{(k)}}{\partial W} ∂W∂L(t)=k=0∑t∂o(t)∂L(t)∂h(t)∂o(t)(j=k+1∏t∂h(j−1)∂h(j))∂W∂h(k)
∂ L ( t ) ∂ U = ∑ k = 0 t ∂ L ( t ) ∂ o ( t ) ∂ o ( t ) ∂ h ( t ) ( ∏ j = k + 1 t ∂ h ( j ) ∂ h ( j − 1 ) ) ∂ h ( k ) ∂ U \frac{\partial L^{(t)}}{\partial U}=\sum_{k=0}^{t}\frac{\partial L^{(t)}}{\partial o^{(t)}}\frac{\partial o^{(t)}}{\partial h^{(t)}}(\prod_{j=k+1}^{t}\frac{\partial h^{(j)}}{\partial h^{(j-1)}})\frac{\partial h^{(k)}}{\partial U} ∂U∂L(t)=k=0∑t∂o(t)∂L(t)∂h(t)∂o(t)(j=k+1∏t∂h(j−1)∂h(j))∂U∂h(k)
整体的偏导公式就是将其按时刻再一一加起来。